├── .gitignore
├── LICENSE
├── Makefile
├── README.md
├── images
│   └── motivation.png
├── infosol
│   ├── alignment.py
│   ├── decoding.py
│   ├── env.py
│   ├── evaluate.py
│   ├── modeling_utils.py
│   ├── models
│   │   ├── __pycache__
│   │   │   └── word_edit_model.cpython-310.pyc
│   │   └── word_edit_model.py
│   ├── train.py
│   └── utils
│       ├── data
│       │   └── cnn_dailymail.py
│       ├── keywords.py
│       └── pointer_utils.py
├── jobs
│   └── interactive
│       └── cnn-bart-s2s-len
├── requirements.txt
├── scripts
│   ├── dowload_models.sh
│   ├── make_data.py
│   ├── make_eval_jobs.py
│   └── run_eval.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | infosol_models.tar.zst
3 | data/
4 | models/
5 | jobs/main/
6 | jobs/interactive/cnn-bart-editor
7 | jobs/interactive/cnn-bart-s2s
8 | jobs/interactive/cnn-bart_editor_large
9 | out/
10 | infosol.egg-info/
11 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | # Makefile for training
2 | CUDA_DEVICE ?= 0
3 |
4 | # CNN BASE
5 | base_run_dir := #/path/to/your/run_dir/
6 | base-cnn-bart_editor_large:
7 | python infosol/train.py --cuda_device $(CUDA_DEVICE) --run_dir $(base_run_dir) --run_name cnn-bart_editor_large base --n_epochs 6 --max_train_edits 1000000 --max_val_edits 2000 --model_name bart-large --track_gradient_norm True
8 | base-cnn-bart_editor_large-clip_grad_norm:
9 | python infosol/train.py --cuda_device $(CUDA_DEVICE) --run_dir $(base_run_dir) --run_name cnn-bart_editor_large base --n_epochs 6 --max_train_edits 1000000 --max_val_edits 2000 --model_name bart-large --track_gradient_norm True --clip_grad_norm True
10 |
11 |
12 | # CNN DAGGER
13 | run_dir := #/path/to/your/run_dir/
14 | dagger_args_ := --use_timestamp False --run_dir $(run_dir) --cuda_device $(CUDA_DEVICE) dagger --n_epochs 600 --max_train_edits 1000000 --max_val_edits 2000 --n_warmup_epochs 300 --dagger_sampling_rate 1.0 --sample_batch_size 10000 --val_sample_batch_size 1000
15 | dagger_args := $(dagger_args_) --sampling_annealing_rate 0.9
16 |
17 | # main
18 | cnn-bart_editor:
19 | python infosol/train.py --run_name cnn-bart_editor $(dagger_args)
20 | cnn-bart_editor_large:
21 | python infosol/train.py --run_name cnn-bart_editor_large $(dagger_args) --model_name bart-large --track_gradient_norm True --clip_grad_norm True
22 | cnn-bart_s2s:
23 | python infosol/train.py --run_name cnn-bart_s2s $(dagger_args) --model_name barts2s
24 | # different oracles
25 | cnn-bart_editor-adj_edits:
26 | python infosol/train.py --run_name cnn-bart_editor-adj_edits $(dagger_args) --adjacent_ops True
27 | cnn-bart_editor-contig_edits:
28 | python infosol/train.py --run_name cnn-bart_editor-contig_edits $(dagger_args) --contiguous_edits True
29 | cnn-bart_editor-adj_edits-contig_edits:
30 | python infosol/train.py --run_name cnn-bart_editor-adj_edits-contig_edits $(dagger_args) --adjacent_ops True --contiguous_edits True
31 |
32 | # noise fractions
33 | cnn-bart_editor-noise_0.0:
34 | python infosol/train.py --run_name cnn-bart_editor-noise_0.0 $(dagger_args) --noise_frac 0
35 | cnn-bart_editor-noise_0.1:
36 | python infosol/train.py --run_name cnn-bart_editor-noise_0.1 $(dagger_args) --noise_frac 0.1
37 | cnn-bart_editor-noise_0.2:
38 | python infosol/train.py --run_name cnn-bart_editor-noise_0.2 $(dagger_args) --noise_frac 0.2
39 | # annealing rate
40 | cnn-bart_editor-anneal_0.85:
41 | python infosol/train.py --run_name cnn-bart_editor-anneal_0.85 $(dagger_args_) --sampling_annealing_rate 0.85
42 | cnn-bart_editor-anneal_0.80:
43 | python infosol/train.py --run_name cnn-bart_editor-anneal_0.80 $(dagger_args_) --sampling_annealing_rate 0.80
44 |
45 |
46 |
47 |
48 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Overview
2 |
3 | This project offers an implementation of the paper:
4 |
5 | **Interactive Text Generation**\
6 | Felix Faltings, Michel Galley, Baolin Peng, Kianté Brantley, Weixin Cai, Yizhe Zhang, Jianfeng Gao, Bill Dolan\
7 | [arXiv](https://arxiv.org/abs/2303.00908)
8 |
9 |
10 |
11 | # Installation
12 |
13 | Install dependencies using requirements.txt:
14 | `pip install -r requirements.txt`
15 | Then, install the package itself. From the top-level directory (where `setup.py` is located), run:
16 | `pip install -e .`
17 | This installs the repository as an editable package named `infosol`.
18 |
19 | Download model files (requires wget and zstd):
20 | `sh ./scripts/dowload_models.sh`
21 |
22 | # Data
23 |
24 | You can regenerate the data used in the paper with the `make_data.py` script. You only need to specify the `data_dir` argument; the processed data is saved under `data/cnn_bart`. The script first downloads the raw data (CNN/DailyMail) from the Hugging Face hub, and it can easily be adapted to generate other textual datasets from the hub.
25 |
26 | `python scripts/make_data.py --data_dir=data`
27 |
28 | # Test
29 |
30 | To replicate the main experiments of the paper, run:
31 |
32 | `python scripts/make_eval_jobs.py --model_dir=models --data_dir=data/cnn --job_dir=jobs --out_dir=out`
33 |
34 | The above command creates job files in the 'jobs' directory, as well as the directory structure ('out') where test results will be stored. You can then pass any of the configuration files under 'jobs' as an argument to `scripts/run_eval.py`. For example, run the following to replicate the BART-large 'interactive' experiments (Table 4 of the paper):
35 |
36 | `python scripts/run_eval.py --args_path jobs/interactive/cnn-bart_editor_large --cuda_device 0`
37 |
38 | Note: The S2S experiments of the paper yielded generations that were inconsistent in length, which hurt S2S performance. We therefore tuned its `length_penalty` hyperparameter on a held-out set; the corresponding job files can be found in `jobs/interactive/cnn-bart-s2s-len`.
39 |
40 | # Train
41 |
42 | To train all the models presented in the paper, use the provided Makefile. Set the `run_dir` variable to the directory where model weights should be saved; see the example invocation below. You also need to set the `DATA_DIR` and `SAVE_DIR` paths in `train.py`.
43 |
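44 | For example, to train the base editor on GPU 0 (a sketch: `cnn-bart_editor` is one of the Makefile targets shown above, and the `run_dir` value is a placeholder):
45 | `make cnn-bart_editor CUDA_DEVICE=0 run_dir=/path/to/run_dir`
46 |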
44 | # Code walkthrough
45 |
46 | We suggest going through the code in the order alignment -> model -> train/evaluate. `alignment.py` defines the basic objects: alignments (used heavily by the oracle) and canvases (the objects that the model and oracle operate on). The other files are self-explanatory and commented; a minimal usage sketch follows the list below.
47 |
48 | The main files:
49 | - alignment.py
50 | - env.py
51 | - evaluate.py
52 | - models/word_edit_model.py
53 | - run_eval.py
54 | - train.py
55 | - (decoding.py)
56 |
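57 | A minimal sketch of the two core objects (assumptions: the package is installed and the `facebook/bart-base` tokenizer is used; any Hugging Face tokenizer with the same interface should work):
58 |
59 | ```python
60 | from transformers import AutoTokenizer
61 | from infosol.alignment import Alignment, Canvas
62 |
63 | tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')
64 | canvas = Canvas(tokenizer.tokenize('The cat sat'))
65 | canvas.operate(loc=len(canvas) - 1, operation=0, token='Ġdown')  # insertion op
66 | print(canvas.render(tokenizer))  # -> 'The cat sat down'
67 |
68 | # An alignment pairs source tokens with target tokens; '' marks an insertion/deletion
69 | alignment = Alignment([('The', 'The'), ('Ġcat', 'Ġcat'), ('', 'Ġsat')])
70 | alignment.push_forward(alignment.get_non_const_ops(), agent=1)  # oracle applies pending edits
71 | print(alignment.get_source_canvas().render(tokenizer))  # -> 'The cat sat'
72 | ```
73 |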
--------------------------------------------------------------------------------
/images/motivation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ffaltings/InteractiveTextGeneration/febd658a91227dd88fbc5355111382b732a91647/images/motivation.png
--------------------------------------------------------------------------------
/infosol/alignment.py:
--------------------------------------------------------------------------------
1 | import random
2 | import itertools
3 | import torch
4 | import numpy as np
5 |
6 | from transformers import AutoTokenizer, BertModel
7 |
8 | def find_max_score(alignment_matrix, score_matrix, i, j, baseline_score=0):
9 | """
10 | Maximum alignment score up to this position, given max alignment scores up to previous positions
11 | """
12 |     ins_score = alignment_matrix[i, j-1] + baseline_score # consumes a token of b (decoded as move 1, insertion)
13 |     del_score = alignment_matrix[i-1, j] + baseline_score # consumes a token of a (decoded as move 0, deletion)
14 |     match_score = alignment_matrix[i-1,j-1] + score_matrix[i-1, j-1] # move 2, match
15 |     scores = (del_score, ins_score, match_score) # this order matters: it fixes the move ids and priority b/w insert/del
16 |     return np.max(scores), np.argmax(scores)
17 |
18 | def compute_alignment_scores(score_matrix, baseline_score=0):
19 | """
20 |     Compute the maximum alignment score for each prefix pair using DP (Needleman-Wunsch-style; baseline_score acts as the gap score)
21 | """
22 | alignment_matrix = np.zeros((score_matrix.shape[0] + 1, score_matrix.shape[1] + 1))
23 | for i in range(1, alignment_matrix.shape[0]):
24 | alignment_matrix[i,0] = i * (baseline_score*10)/10 #avoids a weird bug with floating point precision
25 | for j in range(1, alignment_matrix.shape[1]):
26 | alignment_matrix[0,j] = j * (baseline_score*10)/10
27 | move_matrix = np.zeros(alignment_matrix.shape)
28 | for j in range(1,alignment_matrix.shape[1]):
29 | for i in range(1,alignment_matrix.shape[0]):
30 | alignment_matrix[i,j], move_matrix[i,j] = find_max_score(alignment_matrix, score_matrix, i, j, baseline_score)
31 | return alignment_matrix, move_matrix
32 |
33 | def get_alignment(move_matrix, tokens_a, tokens_b, score_matrix=None, baseline_score=0.):
34 | """
35 | Retrieve the actual alignment from the matrix of alignment scores,
36 | and matrix of "moves" thru the score matrix
37 | """
38 | i,j = move_matrix.shape
39 | i -= 1
40 | j -= 1
41 | alignment = []
42 | alignment_score = []
43 | while i >= 1 and j >= 1:
44 | move = move_matrix[i,j]
45 | if move == 2:
46 | op = (tokens_a[i-1], tokens_b[j-1])
47 | if score_matrix is not None:
48 | alignment_score.append(score_matrix[i-1,j-1])
49 | # op = op + (float(score_matrix[i-1, j-1]),)
50 | alignment.append(op)
51 | i -= 1
52 | j -= 1
53 | elif move == 0:
54 | op = (tokens_a[i-1], '')
55 | if score_matrix is not None:
56 | # op = op + (baseline_score,)
57 | alignment_score.append(baseline_score)
58 | alignment.append(op)
59 | i -= 1
60 | elif move == 1:
61 | op = ('',tokens_b[j-1])
62 | if score_matrix is not None:
63 | # op = op + (baseline_score,)
64 | alignment_score.append(baseline_score)
65 | alignment.append(op)
66 | j -= 1
67 | if i == 0:
68 | while j >= 1:
69 | op = ('', tokens_b[j-1])
70 | if score_matrix is not None:
71 | # op = op + (baseline_score,)
72 | alignment_score.append(baseline_score)
73 | alignment.append(op)
74 | j -= 1
75 | elif j == 0:
76 | while i >= 1:
77 | op = (tokens_a[i-1], '')
78 | if score_matrix is not None:
79 | # op = op + (baseline_score,)
80 | alignment_score.append(baseline_score)
81 | alignment.append(op)
82 | i -= 1
83 | return alignment[::-1], alignment_score[::-1]
84 |
85 | def batch_cos_sim(tokens_a, tokens_b, model, tokenizer, device=torch.device('cpu'), use_model=True):
86 | """
87 | Compute cosine similarities
88 | """
89 | if len(tokens_a) != len(tokens_b):
90 | raise ValueError('batch sizes cannot be different')
91 | for tokens in tokens_a + tokens_b:
92 | if len(tokens) == 0:
93 | raise ValueError('tokens cannot be empty')
94 |
95 | def batch_encodings(batch):
96 | # convert to ids
97 | batch = [torch.tensor(
98 | tokenizer.convert_tokens_to_ids(tokens),
99 | dtype=torch.long,
100 | device=device) for tokens in batch]
101 | max_len = np.max([ids.shape[0] for ids in batch])
102 | input_ids = torch.zeros((len(batch), max_len), dtype=torch.long,
103 | device=device)
104 | attention_mask = torch.zeros_like(input_ids, dtype=torch.int32)
105 | for i,b in enumerate(batch):
106 | input_ids[i, :len(b)] = b
107 | attention_mask[i, :len(b)] = 1
108 | with torch.no_grad():
109 | out = model(input_ids=input_ids, attention_mask=attention_mask)
110 |
111 | for i in range(len(batch)):
112 | canv_len = attention_mask[i,:].sum()
113 | yield out.last_hidden_state[i, :canv_len],\
114 | input_ids[i, :canv_len]
115 |
116 | batch_enc_a = batch_encodings(tokens_a)
117 | batch_enc_b = batch_encodings(tokens_b)
118 |
119 | for (enc_a, ids_a), (enc_b, ids_b) in zip(batch_enc_a,
120 | batch_enc_b):
121 | exact_matches = (ids_a.unsqueeze(-1).expand(-1, ids_b.shape[0])\
122 | == ids_b.unsqueeze(0).expand(ids_a.shape[0], -1)).type(torch.float)
123 |
124 |         if not use_model:
125 |             yield exact_matches
126 |             continue # guard: otherwise a second, model-based matrix would also be yielded for this pair
127 | dot_products = torch.matmul(enc_a, enc_b.transpose(0,1))
128 | try:
129 | norms = torch.matmul(torch.norm(enc_a, dim=-1).unsqueeze(1),
130 | torch.norm(enc_b, dim=-1).unsqueeze(0))
131 | except IndexError as e:
132 | print(enc_a.shape, enc_b.shape, ids_a, ids_b)
133 | raise e
134 | yield torch.max(dot_products/norms, exact_matches)
135 |
136 | def batch_align(tokens_a, tokens_b, model, tokenizer,
137 | baseline_score=0., device=torch.device('cpu'), **kwargs):
138 | """
139 | Align tokens
140 | :param tokens_a: list of lists of tokens
141 | :param tokens_b: list of lists of tokens
142 | :param model: encoder model
143 | :param tokenizer: tokenizer
144 | :param baseline_score: score assigned for non matches
145 | """
146 | if len(tokens_a) != len(tokens_b):
147 | raise ValueError('batch sizes cannot differ')
148 | non_empty_idxs = [i for i,(tok_a,tok_b) in enumerate(
149 | zip(tokens_a, tokens_b)) if len(tok_a) > 0 and len(tok_b) > 0]
150 | score_matrices = batch_cos_sim([tokens_a[i] for i in non_empty_idxs],
151 | [tokens_b[i] for i in non_empty_idxs], model, tokenizer, device=device, **kwargs)
152 |
153 | for j in range(len(tokens_a)):
154 | tok_a = tokens_a[j]
155 | tok_b = tokens_b[j]
156 | if len(tok_a) == 0:
157 | alignment = [('', b) for b in tok_b]
158 | scores = [baseline_score] * len(alignment)
159 | yield Alignment(alignment, scores=scores, baseline_score=baseline_score)
160 | elif len(tok_b) == 0:
161 | alignment = [(a, '') for a in tok_a]
162 | scores = [baseline_score] * len(alignment)
163 | yield Alignment(alignment, scores=scores, baseline_score=baseline_score)
164 | else:
165 | score_matrix = next(score_matrices).cpu().numpy()
166 | alignment_matrix, move_matrix = compute_alignment_scores(
167 | score_matrix, baseline_score=baseline_score)
168 | alignment, scores = get_alignment(move_matrix, tok_a, tok_b,
169 | score_matrix=score_matrix,
170 | baseline_score=baseline_score)
171 | yield Alignment(alignment, scores=scores, baseline_score=baseline_score)
172 |
173 | def align(tokens_a, tokens_b, model, tokenizer, **kwargs):
174 | return next(batch_align([tokens_a], [tokens_b], model, tokenizer, **kwargs))
175 |
176 | def batch_align_canvases(canvases, targets, model, tokenizer, **kwargs):
177 | """
178 | Align canvases to target tokens, differs from aligning tokens because it handles the type ids of the canvases
179 | """
180 | solid_tokens = [canvas.clean().tokens for canvas in canvases]
181 | for canvas,alignment in zip(canvases,
182 | batch_align(solid_tokens, targets, model, tokenizer, **kwargs)):
183 | recovered_alignment = []
184 | recovered_alignment_scores = []
185 | i,j = 0,0
186 | while i < len(canvas) or j < len(alignment):
187 | if i < len(canvas) and j < len(alignment) and \
188 | canvas.tokens[i] == alignment.alignment[j][0] and\
189 | not canvas.is_stricken(i) and not alignment.is_insertion(j):
190 | recovered_alignment.append(alignment.alignment[j])
191 | recovered_alignment_scores.append(alignment.scores[j])
192 | i += 1
193 | j += 1
194 | elif i < len(canvas) and canvas.is_stricken(i):
195 | recovered_alignment.append((canvas.tokens[i], ''))
196 | recovered_alignment_scores.append(1)
197 | i += 1
198 | elif j < len(alignment) and alignment.is_insertion(j):
199 | recovered_alignment.append(alignment.alignment[j])
200 | recovered_alignment_scores.append(alignment.scores[j])
201 | j += 1
202 | else:
203 | print(canvas.tokens, canvas.type_ids)
204 | print(alignment)
205 | print(i,j)
206 | print(recovered_alignment)
207 | print(canvas.tokens[i], canvas.type_ids[i], alignment.alignment[j])
208 | raise RuntimeError('Error with alignment')
209 | alignment = Alignment(recovered_alignment, recovered_alignment_scores)
210 | alignment.set_type_ids(canvas.type_ids)
211 | yield alignment
212 |
213 | def align_canvas(canvas, target, model, tokenizer, **kwargs):
214 | return next(batch_align_canvases([canvas], [target], model, tokenizer, **kwargs))
215 |
216 | class Canvas():
217 |
218 | """
219 | Canvas is a sequence of tokens with type ids
220 | """
221 |
222 | html_types = [
223 | "{}", # 0: plain text
224 | "{}", #1: agent inserted text
225 | "{}", #2 oracle inserted text
226 | "{}", #3: user deleted text
227 | "{}", #4 oracle deleted text
228 | ]
229 |
230 | @staticmethod
231 | def latex_type(tid, tok):
232 | """
233 | Utility for rendering to latex
234 | """
235 | if tid == 0:
236 | text = tok
237 | elif tid == 1:
238 | text = ''.join(('\\textcolor{seagreen}{',tok,'}'))
239 | elif tid == 2:
240 | text = ''.join(('\\textcolor{BrickRed}{',tok,'}'))
241 |         elif tid == 3:
242 |             text = ''.join(('\\sout{', tok, '}'))
243 |         elif tid == 4:
244 |             text = ''.join(('\\sout{\\textcolor{BrickRed}{', tok, '}}'))
245 | if tok[0] == ' ' and tid != 0:
246 | text = ' '+text
247 | return text
248 |
249 | def __init__(self, tokens, type_ids=None):
250 | if not type_ids is None and len(tokens) != len(type_ids):
251 | raise ValueError('length of tokens must match length of type ids')
252 | self.tokens = tokens
253 | if type_ids is None:
254 | type_ids = [0] * len(tokens)
255 | self.type_ids = type_ids
256 |
257 | def __len__(self):
258 | return len(self.tokens)
259 |
260 | def __eq__(self, other):
261 | return self.tokens == other.tokens and self.type_ids == other.type_ids
262 |
263 | def __lt__(self, other):
264 | return len(self.tokens) < len(other.tokens)
265 |
266 | def __gt__(self, other):
267 | return len(self.tokens) > len(other.tokens)
268 |
269 | def __le__(self, other):
270 | return len(self.tokens) <= len(other.tokens)
271 |
272 | def __ge__(self, other):
273 | return len(self.tokens) >= len(other.tokens)
274 |
275 | def __repr__(self):
276 | return str((self.tokens, self.type_ids))
277 |
278 | def is_stricken(self, i):
279 | return self.type_ids[i] >= 3
280 |
281 | def operate(self, loc, operation, token, agent=0):
282 | """
283 | Apply an edit to a canvas
284 | """
285 | if operation == 0: # insertion
286 | self.tokens.insert(loc+1, token)
287 | self.type_ids.insert(loc+1, agent+1)
288 | elif operation == 1: # substitution
289 | self.tokens[loc] = token
290 | self.type_ids[loc] = agent+1
291 | elif operation == 2: # deletion
292 | if self.type_ids[loc] != 0:
293 | del self.tokens[loc]
294 | del self.type_ids[loc]
295 | else:
296 | self.type_ids[loc] = agent + 3
297 | else:
298 | raise ValueError('Invalid Operation: {}'.format(operation))
299 |
300 | def render(self, tokenizer, clean=True):
301 | """
302 | Return string from the canvas
303 | """
304 |         canvas = self.clean() if clean else self
305 |         if len(canvas) == 0:
306 |             return ''
307 |         else:
308 |             return tokenizer.convert_tokens_to_string(canvas.tokens)
309 |
310 | def render_to_html(self, tokenizer):
311 | """
312 | Format into html
313 | """
314 | html_string = ''
315 | for tok, tid in zip(self.tokens, self.type_ids):
316 | html_string += Canvas.html_types[tid].format(
317 | tokenizer.convert_tokens_to_string(tok))
318 | return html_string
319 |
320 | def render_to_latex(self, tokenizer):
321 | """
322 | Format into LATEX
323 | """
324 | latex_string = ''
325 | for tok, tid in zip(self.tokens, self.type_ids):
326 | latex_string += Canvas.latex_type(
327 | tid, tokenizer.convert_tokens_to_string(tok))
328 | return latex_string
329 |
330 | def clean(self):
331 | """
332 | Remove tokens that are stricken out and remove type ids (except for oracle type ids)
333 | """
334 | tokens = [t for i,t in enumerate(self.tokens) if not self.is_stricken(i)]
335 | type_ids = [tid for i,tid in enumerate(self.type_ids) if not self.is_stricken(i)]
336 | # reset the type ids for model insertions
337 | type_ids = [0 if tid==1 else tid for tid in type_ids]
338 | return Canvas(tokens, type_ids)
339 |
340 | def copy(self):
341 | return Canvas(list(self.tokens), list(self.type_ids))
342 |
343 | class Alignment():
344 |
345 | """
346 | Alignment between two sets of tokens
347 | """
348 |
349 | def __init__(self, alignment, scores=None, type_ids=None,
350 | baseline_score=0.3, add_type_ids=True):
351 | self.alignment=alignment
352 |
353 | if not scores is None and len(scores) != len(alignment):
354 | raise ValueError('length of scores must match length of alignment')
355 | self.scores=scores
356 | self.baseline_score = baseline_score
357 |
358 | if not type_ids is None:
359 | self.set_type_ids(type_ids)
360 | elif add_type_ids:
361 | self.set_type_ids([0] * len(self.get_source_tokens()))
362 | else:
363 | self.type_ids = None
364 |
365 | def copy(self):
366 | return Alignment(list(self.alignment), list(self.scores))
367 |
368 | def __len__(self):
369 | return len(self.alignment)
370 |
371 | def __str__(self):
372 | print_str = ''
373 | if len(self.alignment) == 0: return print_str
374 | src_pad = max(1, np.max([len(tup[0]) for tup in self.alignment]))
375 | tgt_pad = max(1, np.max([len(tup[1]) for tup in self.alignment]))
376 | for i in range(len(self.alignment)):
377 | line = '{src:{src_pad}} - {tgt:{tgt_pad}}'.format(src=self.alignment[i][0],
378 | tgt=self.alignment[i][1], src_pad=src_pad, tgt_pad=tgt_pad)
379 | if not self.scores is None:
380 | line += ' {:.2f}'.format(self.scores[i])
381 | if not self.type_ids is None:
382 | line += ' {:2}'.format(self.type_ids[i])
383 | print_str += line + '\n'
384 | return print_str
385 |
386 | def is_match(self, idx):
387 | return self.alignment[idx][0] != '' and self.alignment[idx][1] != ''
388 |
389 | def is_exact_match(self, idx):
390 | return self.alignment[idx][0] == self.alignment[idx][1]
391 |
392 | def is_insertion(self, idx):
393 | return self.alignment[idx][0] == '' and self.alignment[idx][1] != ''
394 |
395 | def is_addition(self, idx):
396 | return self.alignment[idx][1] != '' and self.alignment[idx][0] != self.alignment[idx][1]
397 |
398 | def is_deletion(self, idx):
399 | return self.alignment[idx][0] != '' and self.alignment[idx][1] == ''
400 |
401 | def get_non_const_ops(self):
402 | """
403 | All positions in alignment that aren't exact matches
404 | """
405 | if self.type_ids is None:
406 | return [i for i in range(len(self.alignment))\
407 | if not self.is_exact_match(i)]
408 | else:
409 | return [i for i in range(len(self.alignment))\
410 | if not self.is_exact_match(i) and self.type_ids[i] < 3]
411 |
412 | def get_adjacent_ops(self, idxs, adj_range = 1):
413 | """
414 | Return idxs of adjacent ops in alignment. Adjacent idxs are idxs
415 | that are near a pair of matched tokens (not necessarily an exact match)
416 | """
417 | def is_adjacent(idx, adj_range):
418 | min_lim = max(0, idx - adj_range)
419 | max_lim = min(len(self.alignment), idx + adj_range + 1)
420 | for i in range(min_lim, max_lim):
421 | if self.is_match(i):
422 | return True
423 | return False
424 | ret_idxs = [i for i in idxs if is_adjacent(i, adj_range)]
425 | if not ret_idxs: return idxs
426 | return ret_idxs
427 |
428 | def get_actions(self, action_idxs, include_deleted_words=False):
429 | """
430 | Get actual action tuples for given idxs of an alignment
431 | """
432 | actions = {}
433 | j = -1
434 |         k = -1  # <- keep track of this for insertions, -1 is sentinel token
435 | for i in range(len(self.alignment)):
436 | if not self.is_insertion(i):
437 | j += 1
438 | if self.type_ids is None:
439 | k = j
440 | else:
441 | if self.type_ids[i] < 3:
442 | k = j
443 | if i in action_idxs:
444 | if self.is_insertion(i):
445 | actions[i] = (k, 0, self.alignment[i][1])
446 | elif self.is_deletion(i):
447 | if include_deleted_words:
448 | actions[i] = (j, 2, None, self.alignment[i][0])
449 | else:
450 | actions[i] = (j, 2, None)
451 | else: # substitution
452 | actions[i] = (j, 1, self.alignment[i][1])
453 | actions = [actions[i] for i in action_idxs]
454 | return actions
455 |
456 | def push_forward(self, idxs, agent = 0):
457 | """
458 | Push alignment forward at given idxs (i.e. replace the source token with the target token at the location
459 | giving an exact match)
460 | """
461 | if len(idxs) == 0: return
462 | if np.min(idxs) < 0 or np.max(idxs) >= len(self.alignment):
463 | raise ValueError('idxs out of range')
464 | delete_idxs = []
465 | for i in idxs:
466 | if self.type_ids is None:
467 | if not self.is_deletion(i):
468 | self.alignment[i] = (self.alignment[i][1], self.alignment[i][1])
469 | else:
470 | delete_idxs.append(i)
471 | else:
472 | if not self.is_deletion(i):
473 | self.alignment[i] = (self.alignment[i][1], self.alignment[i][1])
474 | self.type_ids[i] = agent+1
475 | elif self.type_ids[i] == 0:
476 | self.type_ids[i] = agent+3
477 | elif self.type_ids[i] < 3: #otherwise token was already deleted
478 | delete_idxs.append(i)
479 | self.alignment = [tup for i,tup in enumerate(self.alignment) if i not in delete_idxs]
480 | if not self.scores is None:
481 | self.scores = [s for i,s in enumerate(self.scores) if i not in delete_idxs]
482 | if not self.type_ids is None:
483 | self.type_ids = [tid for i,tid in enumerate(self.type_ids) if i not in delete_idxs]
484 |
485 | def canvas_location_to_alignment_location(self, canvas_location):
486 | """
487 | Map a location in the source canvas to the alignment
488 | """
489 | j = 0
490 | for i in range(len(self.alignment)):
491 | if not self.is_insertion(i):
492 | canvas_location -= 1
493 | if canvas_location < -1: break
494 | j = i
495 | return j
496 |
497 | def operate(self, canvas_location, operation, token, agent = 0):
498 | '''
499 | Operate on the source canvas while updating the alignment to the target tokens
500 | '''
501 | if canvas_location >= len(self.get_source_tokens()):
502 | raise ValueError('canvas location out of bounds')
503 | if operation == 0: #insertion
504 | if token is None: raise ValueError('token cannot be None when inserting')
505 | # insert on the left of the next token to keep contiguous insertions and deletions grouped
506 | location = self.canvas_location_to_alignment_location(canvas_location+1)
507 | self.alignment.insert(location, (token, ''))
508 | if not self.scores is None:
509 | self.scores.insert(location, self.baseline_score)
510 | if not self.type_ids is None:
511 | self.type_ids.insert(location, agent+1)
512 | elif operation == 1: #substitution
513 | if token is None: raise ValueError('token cannot be None when substituting')
514 | if canvas_location == -1: raise ValueError('cannot substitute bos token')
515 |
516 | location = self.canvas_location_to_alignment_location(canvas_location)
517 | if not self.type_ids is None and self.type_ids[location] >= 3:
518 | raise ValueError('cannot edit deleted token')
519 | if not self.type_ids is None: self.type_ids[location] = 1 # trick to make deletion operate correctly
520 | self.operate(canvas_location, 2, token, agent)
521 | self.operate(canvas_location-1, 0, token, agent)
522 | elif operation == 2: #deletion
523 | location = self.canvas_location_to_alignment_location(canvas_location)
524 | if not self.type_ids is None and self.type_ids[location] >= 3:
525 | raise ValueError('cannot edit deleted token')
526 | if self.is_match(location):
527 | deleted_token = self.alignment[location][0]
528 | self.alignment[location] = ('', self.alignment[location][1])
529 | if not self.scores is None:
530 | self.scores[location] = self.baseline_score
531 | if not self.type_ids is None and self.type_ids[location] == 0: #this needs to remain
532 | self.type_ids[location] = -1
533 | replacement_location = self.canvas_location_to_alignment_location(
534 | canvas_location)
535 | self.alignment.insert(replacement_location, (deleted_token, ''))
536 | self.type_ids.insert(replacement_location, agent+3)
537 | if not self.scores is None:
538 | self.scores.insert(replacement_location, self.baseline_score)
539 | elif not self.type_ids is None: self.type_ids[location] = -1
540 | else:
541 | if not self.type_ids is None and self.type_ids[location] == 0:
542 | self.type_ids[location] = agent+3
543 | else:
544 | del self.alignment[location]
545 | if not self.scores is None:
546 | del self.scores[location]
547 | if not self.type_ids is None:
548 | del self.type_ids[location]
549 |
550 | if len(self.get_type_ids()) != len(self.get_source_tokens()):
551 | print(canvas_location, operation, token)
552 | print(self.get_type_ids(), self.get_source_tokens())
553 | print(self)
554 |             raise RuntimeError('violated consistency')
555 |
556 | def set_type_ids(self, type_ids):
557 | '''
558 | Add type ids to alignment
559 | :param type_ids: type ids to add. Assumes they correspond to the type ids of the source tokens
560 | '''
561 | if len(type_ids) != len(self.get_source_tokens()):
562 | raise ValueError('length of type ids must match length of source tokens in alignment')
563 | self.type_ids = []
564 | j = 0
565 | for i in range(len(self.alignment)):
566 | if not self.is_insertion(i):
567 | tag = type_ids[j]
568 | j += 1
569 | else: tag = -1
570 | self.type_ids.append(tag)
571 |
572 | def get_type_ids(self):
573 | if self.type_ids is None: return None
574 | else: return [tid for tid in self.type_ids if tid != -1]
575 |
576 | def get_source_tokens(self):
577 | return [self.alignment[i][0] for i in range(len(self.alignment))\
578 | if not self.is_insertion(i)]
579 |
580 | def get_target_tokens(self):
581 | return [self.alignment[i][1] for i in range(len(self.alignment))\
582 | if not self.is_deletion(i)]
583 |
584 | def get_source_canvas(self):
585 | source_tokens = self.get_source_tokens()
586 | type_ids = self.get_type_ids()
587 | return Canvas(source_tokens, type_ids)
588 |
589 | def get_target_canvas(self):
590 | target_tokens = self.get_target_tokens()
591 | return Canvas(target_tokens)
592 |
593 | def print_alignment(alignment):
594 | for (a,b) in alignment:
595 | print("{} - {}".format(a,b))
596 |
597 | def get_frozen_idxs(alignment):
598 | """
599 | Determine idxs with exact matches. Opposite of get_non_const_ops
600 | """
601 | idxs = []
602 | i = 0
603 | for tup in alignment:
604 | if tup[0] == tup[1]:
605 | idxs.append(i)
606 | if tup[0] != '':
607 | i += 1
608 | return idxs
609 |
610 | def score_token_idf(token, idf_dict, max_score):
611 | return idf_dict[token] if token in idf_dict else max_score
612 |
613 | def get_non_const_ops(alignment):
614 | return [i for i,op in enumerate(alignment) if op[0] != op[1] and op[2] < 3]
615 |
616 | def get_adjacent_ops(idxs, alignment, adj_range = 1):
617 | """
618 | Return idxs of adjacent ops in alignment. Adjacent idxs are idxs
619 | that are near a pair of matched tokens (not necessarily an exact match)
620 | """
621 | def is_adjacent(idx, adj_range):
622 | if idx < adj_range: # "adjacent to start token"
623 | return True
624 | min_lim = max(0, idx - adj_range)
625 | max_lim = min(len(alignment), idx + adj_range + 1)
626 | for i in range(min_lim, max_lim):
627 | if alignment[i][0] != '' and alignment[i][1] != '': # matched tokens?
628 | return True
629 | return False
630 | ret_idxs = [i for i in idxs if is_adjacent(i, adj_range)]
631 | return ret_idxs
632 |
633 | def operate_canvas(canvas, token_type_ids, loc, operation, token, agent=0):
634 | """
635 | Apply an edit to a canvas
636 | """
637 | canvas = list(canvas)
638 | token_type_ids = list(token_type_ids)
639 | if operation == 0: # insertion
640 | canvas = canvas[:loc+1] + [token] + canvas[loc+1:]
641 | token_type_ids = token_type_ids[:loc+1] + [agent+1] + token_type_ids[loc+1:]
642 | elif operation == 1: # substitution
643 | canvas = canvas[:loc] + [token] + canvas[loc+1:]
644 | token_type_ids = token_type_ids[:loc] + [agent+1] + token_type_ids[loc+1:]
645 | elif operation == 2: # deletion
646 | if token_type_ids[loc] != 0:
647 | canvas = canvas[:loc] + canvas[loc+1:]
648 | token_type_ids = token_type_ids[:loc] + token_type_ids[loc+1:]
649 | else:
650 | token_type_ids[loc] = agent + 3
651 | else:
652 | raise ValueError('Invalid Operation: {}'.format(operation))
653 | return canvas, token_type_ids
654 |
655 | def operate_tagged_alignment(alignment, idxs, agent = 0):
656 | """
657 | Push alignment forward at given idxs
658 | :param alignment: tagged alignment
659 | """
660 | alignment = list(alignment) #so it doesn't do anything in place
661 | for i in idxs:
662 | if alignment[i][1] != '':
663 | alignment[i] = (alignment[i][1], alignment[i][1], agent + 1)
664 | elif alignment[i][-1] == 0:
665 | alignment[i] = (alignment[i][0], alignment[i][1], agent + 3)
666 | else:
667 | alignment[i] = (alignment[i][0], alignment[i][1], -2)
668 | alignment = [tup for tup in alignment if tup[-1] >= -1]
669 | return alignment
670 |
671 | def compress_alignment(alignment, alignment_scores):
672 | # TODO: this should assume that the alignment can contain other tags, such as scores
673 | new_alignment, new_alignment_scores = [], []
674 | for tup, score in zip(alignment, alignment_scores):
675 | if tup[0] == '' and tup[1] == '': continue
676 | new_alignment.append(tup)
677 | new_alignment_scores.append(score)
678 | return new_alignment, new_alignment_scores
679 |
680 | def update_idxs(idxs, loc, op, tok):
681 | """
682 | Update a set of idxs given an action that has been
683 | applied to a canvas
684 | """
685 | new_idxs = []
686 | for i in list(idxs):
687 | if i > loc: #unaffected otherwise
688 | if op == 0:
689 | i += 1
690 | elif op == 2:
691 | i -= 1
692 | new_idxs.append(i)
693 | return set(new_idxs)
694 |
695 | def get_actions_from_tagged_alignment(alignment, action_idxs, include_deleted_words=False):
696 | """
697 | Get actual action tuples for given idxs of an alignment
698 | """
699 | input_tokens = []
700 | for i,tup in enumerate(alignment):
701 | input_tokens.append(tup[0])
702 |
703 | actions = {}
704 | j = -1
705 |     k = -1  # <- keep track of this for insertions, -1 is sentinel token
706 | for i, (tok, op) in enumerate(zip(input_tokens, alignment)):
707 | if tok != '' and op[2] < 3:
708 | j += 1
709 | k = j
710 | elif tok != '' and op[2] >= 3:
711 | j += 1
712 | if i in action_idxs:
713 | if op[0] == '': #insertion
714 | actions[i] = (k, 0, op[1])
715 | elif op[1] == '': #deletion
716 | if include_deleted_words:
717 | actions[i] = (j, 2, None, op[0])
718 | else:
719 | actions[i] = (j, 2, None)
720 | else: # substitution
721 | actions[i] = (j, 1, op[1])
722 | actions = [actions[i] for i in action_idxs]
723 | return actions
724 |
725 | def get_correct_canvas_positions(alignment):
726 | correct_positions = []
727 | idx = 0
728 | for p in alignment:
729 | if p[0] == '': continue
730 | if p[0] == p[1]: correct_positions.append(idx)
731 | idx += 1
732 | return correct_positions
733 |
734 | def canvas_to_text(canvas, tokenizer, token_type_ids=None):
735 | if token_type_ids is None:
736 | token_type_ids = [0] * len(canvas)
737 | canvas, token_type_ids = clean_canvas(canvas, token_type_ids)
738 | if len(canvas) == 0: return ''
739 | else:
740 | return tokenizer.decode(tokenizer.convert_tokens_to_ids(canvas))
741 |
742 | def tag_alignment(alignment, token_type_ids):
743 | tagged_alignment = []
744 | j = 0
745 | for i,tup in enumerate(alignment):
746 | if tup[0] != '':
747 | tag = token_type_ids[j]
748 | j += 1
749 | else: tag = -1
750 | tagged_alignment.append([tup[0], tup[1], tag])
751 | return tagged_alignment
752 |
753 | def get_tags_from_alignment(alignment):
754 | return [tup[2] for tup in alignment if tup[0] != '']
755 |
756 | def get_source_canvas(alignment):
757 | source_tokens = []
758 | for i,tup in enumerate(alignment):
759 | if tup[0] != '': source_tokens.append(tup[0])
760 | return source_tokens
761 |
762 | def get_target_canvas(alignment):
763 | target_tokens = []
764 | for i,tup in enumerate(alignment):
765 | if tup[1] != '': target_tokens.append(tup[1])
766 | return target_tokens
767 |
768 | def clean_canvas(canvas, token_type_ids):
769 | canvas = [c for c,t in zip(canvas, token_type_ids) if t < 3]
770 | token_type_ids = [t for t in token_type_ids if t < 3]
771 | # reset the type ids for model insertions
772 | token_type_ids = [0 if t==1 else t for t in token_type_ids]
773 | return canvas, token_type_ids
774 |
775 | def get_token_type_ids(alignment, idxs):
776 | token_type_ids = []
777 | for i,tup in enumerate(alignment):
778 | if tup[0] != '':
779 | token_type_ids.append(1 if i in idxs else 0)
780 | return token_type_ids
781 |
782 | class VocabSampler():
783 |
784 | def __init__(self,vocab_weights, vocab_dict, batch_size=1000):
785 | self.weights = torch.tensor(vocab_weights)
786 | self.dict = vocab_dict
787 | self.batch_size=batch_size
788 | self.samples = torch.multinomial(self.weights, self.batch_size)
789 | self.sample_idx = 0
790 |
791 | def get(self):
792 | if self.sample_idx == self.batch_size:
793 | self.samples = torch.multinomial(self.weights, self.batch_size)
794 | self.sample_idx = 0
795 | self.sample_idx += 1
796 | return self.dict[self.samples[self.sample_idx-1].item()]
797 |
798 | ####
799 | #### VVVV Most of this is now deprecated, moved to models/word_edit_model.py
800 |
801 | def noise_alignment_(tagged_alignment, vocab_sampler):
802 | positions = [i for i,tup in enumerate(tagged_alignment) if tup[-1] < 3] + [-1]
803 | position = np.random.choice(positions)
804 | if position >= 0 and tagged_alignment[position][-1] == 0 and \
805 | tagged_alignment[position][0] == tagged_alignment[position][1]:
806 | operation = np.random.randint(0, 2)
807 | else:
808 | operation = 0
809 |
810 | if operation == 0: # insert
811 | # token = vocab_dict[torch.multinomial(torch.tensor(vocab_weights), 1).item()]
812 | token = vocab_sampler.get()
813 | tagged_alignment = tagged_alignment[:position+1] + [[token, '', 1]] + tagged_alignment[position+1:]
814 | elif operation == 1: # delete a token
815 | tup = tagged_alignment[position]
816 | tagged_alignment = tagged_alignment[:position] + [['', tup[1], -1]] + [[tup[0], '', 3]] + tagged_alignment[position+1:]
817 |
818 | return tagged_alignment
819 |
820 |
821 | def noise_alignment(tagged_alignment, vocab_weights, vocab_dict, max_noise=3, noise_frac=0.2):
822 |     """
823 |     Noise an alignment for training. This randomly edits the source canvas of the alignment, and updates
824 |     the alignment accordingly
825 |     """
826 |     vocab_sampler = VocabSampler(vocab_weights, vocab_dict) # noise_alignment_ expects a sampler, not raw weights
827 |     canv_len = len([tup[0] for tup in tagged_alignment if tup[-1] < 3])
828 |     n_false_edits = int(np.ceil(noise_frac * canv_len))
829 |     for i in range(n_false_edits):
830 |         if random.random() > noise_frac: break
831 |         tagged_alignment = noise_alignment_(tagged_alignment, vocab_sampler)
832 |     return tagged_alignment
833 |
834 | def sample_actions(alignment, tokenizer, token_type_ids, vocab_weights,
835 | vocab_dict, noise_frac=0.0):
836 | """
837 | Sample a set of actions for training. This randomly pushes the alignment forward a few steps
838 | and then returns the set of remaining actions to turn the source into the target
839 | """
840 | if token_type_ids is None:
841 | token_type_ids = [0 for tup in alignment if tup[0] != '']
842 | alignment = tag_alignment(alignment, token_type_ids)
843 |
844 | non_const_ops = get_non_const_ops(alignment)
845 | # print(non_const_ops)
846 | ops_length = len(non_const_ops) + 1 # +1 operation for stopping
847 |
848 |
849 | #sample length
850 | length = np.random.randint(0,ops_length) # allow sampling any state
851 |
852 | #sample operations
853 | rand_ops = np.random.permutation(non_const_ops)
854 | operations = rand_ops[:length]
855 | # return_ops = rand_ops[length:]
856 |
857 | alignment = operate_tagged_alignment(alignment, operations, agent=0)
858 | alignment = noise_alignment(alignment, vocab_weights, vocab_dict, noise_frac=noise_frac)
859 | return_ops = get_non_const_ops(alignment)
860 |
861 | actions = get_actions_from_tagged_alignment(alignment, return_ops)
862 | if len(actions) == 0: actions = [(1, None, None, None)]
863 | else: actions = [(0, a[0], a[1], tokenizer.convert_tokens_to_ids(a[2])) for a in actions]
864 |
865 | canvas = get_source_canvas(alignment)
866 | token_type_ids = get_tags_from_alignment(alignment)
867 |
868 | return canvas, token_type_ids, actions, ops_length
869 |
870 | def sample_trajectory(alignment, tokenizer, token_type_ids, vocab_sampler,
871 | noise_prob=0.2, max_traj_length=64):
872 | """
873 | A more principled approach to noising the canvas. Randomly push the alignment forward
874 | while randomly making errors (noise).
875 | """
876 |
877 | alignment = tag_alignment(alignment, token_type_ids)
878 | ops = get_non_const_ops(alignment)
879 | ops = np.random.permutation(ops)
880 | ops_len = len(ops)
881 | n_errors = 0
882 | n_ops = 0
883 | trajectory = []
884 | trajectory.append((n_errors, n_ops))
885 | for i in range(max_traj_length):
886 | if random.random() > noise_prob:
887 | if n_ops == ops_len and n_errors == 0: break
888 | if random.random() <= n_errors / (n_errors + ops_len - n_ops):
889 | n_errors -= 1
890 | else: n_ops += 1
891 | else:
892 | n_errors += 1
893 | trajectory.append((n_errors, n_ops))
894 | traj_length = len(trajectory)
895 | traj_idx = np.random.randint(0,traj_length) # allow sampling any state
896 | n_errors, n_ops = trajectory[traj_idx]
897 | alignment = operate_tagged_alignment(alignment, ops[:n_ops])
898 | # print(alignment)
899 | for i in range(n_errors):
900 | alignment = noise_alignment_(alignment, vocab_sampler)
901 | return_ops = get_non_const_ops(alignment)
902 | actions = get_actions_from_tagged_alignment(alignment, return_ops)
903 | if len(actions) == 0: actions = [(1, None, None, None)]
904 | else: actions = [(0, a[0], a[1], tokenizer.convert_tokens_to_ids(a[2])) for a in actions]
905 |
906 | canvas = get_source_canvas(alignment)
907 | token_type_ids = get_tags_from_alignment(alignment)
908 |
909 | return canvas, token_type_ids, actions, traj_length
910 |
911 |
912 |
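913 | # ---------------------------------------------------------------------------
914 | # Usage sketch (not part of the original module): align two token lists with a
915 | # toy identity score matrix standing in for model cosine similarities.
916 | # ---------------------------------------------------------------------------
917 | if __name__ == '__main__':
918 |     tokens_a = ['the', 'cat', 'sat']
919 |     tokens_b = ['the', 'black', 'cat', 'sat']
920 |     score_matrix = np.array([[1.0 if a == b else 0.0 for b in tokens_b]
921 |                              for a in tokens_a])
922 |     _, move_matrix = compute_alignment_scores(score_matrix, baseline_score=0.3)
923 |     alignment, _ = get_alignment(move_matrix, tokens_a, tokens_b,
924 |                                  score_matrix=score_matrix, baseline_score=0.3)
925 |     print_alignment(alignment)  # expect ('', 'black') as the lone non-match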
--------------------------------------------------------------------------------
/infosol/decoding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from copy import deepcopy
3 | from queue import PriorityQueue
4 |
5 | class ParallelDecodingMixin():
6 | """
7 | Mixin for parallel decoding
8 | """
9 |
10 | def forward_canvases(self, canvases, device=torch.device('cpu'), move_to_cpu=True):
11 |         raise NotImplementedError
12 |
13 | @staticmethod
14 | def canvas_len(canvas):
15 |         raise NotImplementedError
16 |
17 | def decode_loop(self, canvases, decode_func, state=None, max_batch_tokens=2048, device=torch.device('cpu'), queue_size=2000, return_idx=False):
18 |
19 | iter_idx = 0
20 | finished = False
21 | input_queue = PriorityQueue(maxsize=queue_size)
22 | input_iter = iter(canvases)
23 | start_state = state
24 |
25 | def enqueue(canvas, state, iter_idx):
26 | input_queue.put((canvas, iter_idx, state))
27 |
28 | def add_input(idx):
29 | c = next(input_iter, None)
30 | if c is None:
31 | return True
32 | else:
33 | enqueue(deepcopy(c), start_state, idx)
34 | return False
35 |
36 | def get_batch():
37 | batch = []
38 | batch_len = 0
39 | batch_n_tokens = 0
40 | while True:
41 | if input_queue.empty(): break
42 | # get input
43 | inp = input_queue.get()
44 | canvas, idx, state = inp
45 | n_tokens = self.canvas_len(canvas)
46 | # check if can fit in batch
47 | batch_len_ = batch_len + 1
48 | batch_n_tokens_ = max(batch_n_tokens, n_tokens)
49 | # if doesn't fit, requeue and break
50 | if batch_len_ * batch_n_tokens_ > max_batch_tokens:
51 |                     enqueue(canvas, state, idx) # TODO: switch this order of arguments, it's confusing
52 | break
53 | else:
54 | batch.append(inp)
55 | batch_len = batch_len_
56 | batch_n_tokens = batch_n_tokens_
57 | batch_size = len(batch)
58 | if batch_size == 0:
59 | raise RuntimeError('Unable to fit any inputs into batch. Try increasing max batch tokens')
60 | return batch
61 |
62 | # start by filling up the queue
63 | while not input_queue.full():
64 | finished = add_input(iter_idx)
65 | iter_idx += 1
66 | if finished: break
67 |
68 | while not input_queue.empty():
69 | batch = get_batch()
70 | batch_size = len(batch)
71 |
72 | canvases = [b[0] for b in batch]
73 | model_out = self.forward_canvases(canvases, device=device, move_to_cpu=True)
74 | for i in range(batch_size):
75 | m_out, (canvas, idx, state) = model_out[i], batch[i]
76 | canvas, state, stop = decode_func(m_out, canvas, state)
77 |
78 | if stop:
79 | if return_idx:
80 | yield canvas, idx
81 | else:
82 | yield canvas
83 | if not finished:
84 | finished = add_input(iter_idx)
85 | iter_idx += 1
86 | else:
87 | enqueue(canvas, state, idx)
88 | return
89 |
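90 | # ---------------------------------------------------------------------------
91 | # Usage sketch (not part of the original module): a toy subclass showing the
92 | # contract decode_loop expects. forward_canvases returns one output per canvas;
93 | # decode_func consumes it and signals completion via its `stop` flag.
94 | # ---------------------------------------------------------------------------
95 | class _ToyDecoder(ParallelDecodingMixin):
96 |
97 |     def forward_canvases(self, canvases, device=torch.device('cpu'), move_to_cpu=True):
98 |         return [len(c) for c in canvases]  # stand-in for a model forward pass
99 |
100 |     @staticmethod
101 |     def canvas_len(canvas):
102 |         return len(canvas)
103 |
104 | if __name__ == '__main__':
105 |     def grow(m_out, canvas, state):
106 |         if m_out >= 5:  # "decoding" finishes once the canvas has 5 tokens
107 |             return canvas, state, True
108 |         canvas.append('tok')
109 |         return canvas, state, False
110 |
111 |     for out in _ToyDecoder().decode_loop([['a'], ['a', 'b']], grow,
112 |                                          max_batch_tokens=16):
113 |         print(out)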
--------------------------------------------------------------------------------
/infosol/env.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import copy
3 | import torch.nn.functional as F
4 | import random
5 | import numpy as np
7 |
8 | from transformers import BertModel, AutoTokenizer, T5ForConditionalGeneration, BertConfig
9 | from nltk.translate.bleu_score import sentence_bleu
10 |
11 | from infosol.alignment import batch_align, align_canvas, score_token_idf, Canvas, batch_align_canvases
12 |
13 | class WordEditOracle(torch.nn.Module):
14 | """
15 | Oracle Model
16 | """
17 |
18 | def __init__(self, align_model, align_tokenizer, idf_dict, adjacent_ops=False, sort_ops='sort',
19 | baseline_score=0.3, adj_range=1, min_alignment_score=0.7,
20 | avoid_delete=True, complete_words=False, token_no_space=None,
21 | contiguous_edits=False,
22 | # n_contiguous_edits=1
23 | ):
24 | """
25 | :param idf_dict: dict of idf scores for ranking oracle actions
26 | :param adjacent_ops: whether to limit the model to making adjacent edits
27 | :param sort_ops: whether to sort the actions returned by the oracle, sorting is done according to idf scores. This is a string (!), can be 'sort', 'random', or 'l2r' (left2right)
28 | :param n_return_actions: number of actions the oracle returns. 0 means return all actions
29 | :param baseline_score: baseline score for computing alignment, 0.3 empirically gives "natural" alignments
30 | :param adj_range: adjacency range for limiting oracle to adjacent edits
31 | :param avoid_delete: whether to avoid deletion edits
32 | """
33 | super().__init__()
34 | self.idf_dict = idf_dict
35 | # self.n_contiguous_edits = n_contiguous_edits
36 | self.contiguous_edits = contiguous_edits
37 | self.max_idf_score = np.max(list(self.idf_dict.values()))
38 | self.adjacent_ops = adjacent_ops
39 | self.adj_range = adj_range
40 | self.complete_words = complete_words
41 | self.token_no_space = token_no_space
42 | if sort_ops not in ('sort', 'random', 'l2r'):
43 | raise ValueError('unrecognized value for sort_ops: {}'.format(sort_ops))
44 | self.sort_ops = sort_ops
45 | self.baseline_score = baseline_score
46 | self.avoid_delete = avoid_delete
47 |
48 | self.align_model = align_model
49 | self.align_tokenizer = align_tokenizer
50 |
51 | def edit(self, canvas, target, alignment=None, return_alignment=False,
52 | device=torch.device('cpu'), k=1):
53 | """
54 | Make an edit
55 | """
56 | if not alignment is None:
57 | if canvas != alignment.get_source_canvas():
58 | raise ValueError('canvas does not match alignment')
59 | if target != alignment.get_target_tokens():
60 | raise ValueError('target does not match alignment')
61 | else:
62 | alignment = align_canvas(canvas, target, self.align_model,
63 | self.align_tokenizer, baseline_score=self.baseline_score,
64 | device=device)
65 | self.edit_alignment(alignment, k=k)
66 | canvas = alignment.get_source_canvas()
67 |
68 | ret_tup = canvas
69 | if return_alignment: ret_tup = (canvas, alignment)
70 | return ret_tup
71 |
72 | def edit_alignment(self, alignment, k=1):
73 | """
74 | Make an edit to an alignment
75 | """
76 | op_idxs = self.get_action_idxs(alignment)
77 | op_idxs = self.sort_action_idxs(op_idxs, alignment)
78 | op_idxs = self.choose_action_idxs(op_idxs, alignment, k)
79 | op_idxs = list(set(op_idxs))
80 | alignment.push_forward(op_idxs, agent=1)
81 | return len(op_idxs) > 0
82 | #
83 | # if len(op_idxs) > 0:
84 | # idxs_to_push = [op_idxs[0]]
85 | # for op in op_idxs[1:]:
86 | # if len(idxs_to_push) >= self.n_contiguous_edits: break
87 | # else:
88 | # cond1 = op == max(idxs_to_push) + 1
89 | # cond2 = op == min(idxs_to_push) - 1
90 | # if cond1 or cond2:
91 | # idxs_to_push.append(op)
92 | # if not self.complete_words:
93 | # alignment.push_forward(idxs_to_push, agent=1)
94 | # else:
95 | # idxs_to_push = list(set(idxs_to_push))
96 | # idxs_to_push_ = []
97 | # for op in idxs_to_push:
98 | # for case, rel_idx in zip(
99 | # (alignment.is_insertion, alignment.is_deletion),
100 | # (1,0)):
101 | # if case(op):
102 | # for j in range(op, 0, -1):
103 | # if case(j-1)\
104 | # and self.token_no_space(alignment.alignment[j][rel_idx]):
105 | # # and alignment.alignment[j][rel_idx].startswith('#'):
106 | # idxs_to_push_.append(j-1)
107 | # else: break
108 | # for j in range(op+1, len(alignment)):
109 | # if case(j)\
110 | # and self.token_no_space(alignment.alignment[j][rel_idx]):
111 | # # and alignment.alignment[j][rel_idx].startswith('#'):
112 | # idxs_to_push_.append(j)
113 | # else: break
114 | # # print(idxs_to_push)
115 | # idxs_to_push = list(set(idxs_to_push + idxs_to_push_))
116 | # alignment.push_forward(list(idxs_to_push), agent=1)
117 | # return True
118 | # else: return False
119 |
120 | def get_action_idxs(self, alignment):
121 | """
122 | Choose actions from alignment
123 | """
124 |
125 | # retrieve idxs in the alignment that correspond to an edit
126 | op_idxs = alignment.get_non_const_ops()
127 | if len(op_idxs) == 0: return []
128 |
129 | # if only allowing adjacent edits, limit idxs to adjacent idxs
130 | if self.adjacent_ops:
131 | op_idxs = alignment.get_adjacent_ops(op_idxs, adj_range=self.adj_range)
132 | return op_idxs
133 |
134 | def sort_action_idxs(self, op_idxs, alignment):
135 | """
136 | Sort possible actions, based on alignment scores and idf scores
137 | """
138 | if self.sort_ops == 'sort':
139 | # score using idf score and alignment score
140 | def score_op_idx(idx, alignment, max_idf_score):
141 | score = 0
142 | if alignment.is_deletion(idx): #deletions get treated specially
143 | if self.avoid_delete:
144 | score = 0
145 | else: # note that the alignment score for a deletion is the baseline score, usually 0.3
146 | # should have a more principled way of scoring deletions?
147 | score = alignment.scores[idx] * score_token_idf(
148 | alignment.alignment[idx][0], self.idf_dict, max_idf_score)
149 | else: # For insertions, the alignment score is the baseline score, usually 0.3
150 | score = (1-alignment.scores[idx]) * score_token_idf(
151 | alignment.alignment[idx][1], self.idf_dict, max_idf_score)
152 | return score
153 | idx_scores = [score_op_idx(i, alignment, self.max_idf_score) for i in op_idxs]
154 | op_idxs = list(np.asarray(op_idxs)[np.argsort(idx_scores)[::-1]])
155 | elif self.sort_ops == 'random':
156 | op_idxs = [op_idxs[i] for i in np.random.choice(
157 | np.arange(len(op_idxs)), len(op_idxs), replace=False)]
158 | # if going left to right, we still want to prioritize a set of "misaligned" words
159 | elif self.sort_ops == 'l2r':
160 |             priority_ops = [i for i in op_idxs if alignment.scores[i] < 0.8]
161 | if len(priority_ops) > 0:
162 | op_idxs = priority_ops
163 | return op_idxs
164 |
165 | def choose_action_idxs(self, op_idxs, alignment, k):
166 |
167 | """
168 | Choose which actions to take based on some heuristics
169 | """
170 |
171 | def complete_word(op):
172 | """
173 | Completes actions so they operate on whole words, not tokens
174 | """
175 | complete_idxs = [op]
176 | for case, rel_idx in zip(
177 | (alignment.is_addition, alignment.is_deletion),
178 | (1,0)
179 | ):
180 | try:
181 | if case(op):
182 | for j in range(op, 0, -1): #search backwards for word boundary
183 | if case(j-1) and self.token_no_space(alignment.alignment[j][rel_idx]):
184 | complete_idxs.append(j-1)
185 | else: break
186 | for j in range(op+1, len(alignment)): #search forwards for word boundary
187 | if case(j) and self.token_no_space(alignment.alignment[j][rel_idx]):
188 | complete_idxs.append(j)
189 | else: break
190 | except TypeError as e:
191 | print(alignment)
192 | raise e
193 | return complete_idxs
194 |
195 | def add_op(op, l):
196 | if self.complete_words:
197 | l.extend(complete_word(op))
198 | else:
199 | l.append(op)
200 | return l
201 |
202 | idxs_to_push = []
203 | for i,op in enumerate(op_idxs):
204 | if len(idxs_to_push) >= k: break
205 | if op in idxs_to_push: continue
206 |
207 | if not self.contiguous_edits:
208 | idxs_to_push = add_op(op, idxs_to_push)
209 | else:
210 | # start building set of contiguous edits
211 | contiguous_idxs = add_op(op, [])
212 | for opp in op_idxs[i+1:]:
213 | if opp in idxs_to_push: continue
214 | if opp in contiguous_idxs: continue
215 | if len(idxs_to_push) + len(contiguous_idxs) >= k: break
216 |
217 | cond1 = opp == max(contiguous_idxs) + 1
218 | cond2 = opp == min(contiguous_idxs) - 1
219 | if cond1 or cond2:
220 | contiguous_idxs = add_op(opp, contiguous_idxs)
221 | idxs_to_push.extend(contiguous_idxs)
222 |
223 | return idxs_to_push
224 |
225 | # if len(op_idxs) > 0:
226 | # idxs_to_push = [op_idxs[0]]
227 | # for op in op_idxs[1:]:
228 | # if len(idxs_to_push) >= self.n_contiguous_edits: break
229 | # else:
230 | # cond1 = op == max(idxs_to_push) + 1
231 | # cond2 = op == min(idxs_to_push) - 1
232 | # if cond1 or cond2:
233 | # idxs_to_push.append(op)
234 | # if not self.complete_words:
235 | # alignment.push_forward(idxs_to_push, agent=1)
236 | # else:
237 | # idxs_to_push = list(set(idxs_to_push))
238 | # idxs_to_push_ = []
239 | # for op in idxs_to_push:
240 | # for case, rel_idx in zip(
241 | # (alignment.is_insertion, alignment.is_deletion),
242 | # (1,0)):
243 | # if case(op):
244 | # for j in range(op, 0, -1):
245 | # if case(j-1)\
246 | # and self.token_no_space(alignment.alignment[j][rel_idx]):
247 | # # and alignment.alignment[j][rel_idx].startswith('#'):
248 | # idxs_to_push_.append(j-1)
249 | # else: break
250 | # for j in range(op+1, len(alignment)):
251 | # if case(j)\
252 | # and self.token_no_space(alignment.alignment[j][rel_idx]):
253 | # # and alignment.alignment[j][rel_idx].startswith('#'):
254 | # idxs_to_push_.append(j)
255 | # else: break
256 | # # print(idxs_to_push)
257 | # idxs_to_push = list(set(idxs_to_push + idxs_to_push_))
258 | # alignment.push_forward(list(idxs_to_push), agent=1)
259 | #
260 |
261 | class EditingEnvironment():
262 |
263 | def __init__(self, oracle, oracle_stop_p = 0.5, n_oracle_edits = -1, bleu_weights=(0.25,0.25,0.25,0.25)):
264 | """
265 | :param oracle_stop_p: the oracle makes edits until it stops with probability oracle_stop_p
266 | :param bleu_weights: used for the reward, not important since we don't use rewards
267 | """
268 | self.oracle = oracle
269 | self.oracle_stop_p = oracle_stop_p
270 | self.n_oracle_edits = n_oracle_edits
271 | self.bleu_weights=bleu_weights
272 |
273 | self.alignment = None
274 |
275 | def check_input_(self, canvas=None, target_tokens=None, alignment=None, device=torch.device('cpu')):
276 | if all((alignment is None, canvas is None, target_tokens is None)):
277 | raise ValueError('alignment, canvas and target cannot all be None')
278 | elif alignment is None and (canvas is None or target_tokens is None):
279 |             raise ValueError('canvas and target must both be provided when alignment is None')
280 |
281 | if not alignment is None:
282 | if canvas is None:
283 | canvas = alignment.get_source_canvas()
284 | if target_tokens is None:
285 | target_tokens = alignment.get_target_tokens()
286 |
287 | if canvas != alignment.get_source_canvas():
288 | raise ValueError('canvas does not match alignment')
289 | if target_tokens != alignment.get_target_tokens():
290 | raise ValueError('target does not match alignment')
291 | alignment = copy.deepcopy(alignment)
292 | else:
293 | alignment = align_canvas(canvas, target_tokens, self.oracle.align_model,
294 | self.oracle.align_tokenizer, baseline_score=self.oracle.baseline_score,
295 | device=device)
296 | return canvas, target_tokens, alignment
297 |
298 | def oracle_edit(self, canvas=None, target_tokens=None, alignment=None, device=torch.device('cpu'), return_alignment=False):
299 |         """
300 |         Let the oracle make an edit
301 |         """
302 |         canvas, target_tokens, alignment = self.check_input_(canvas, target_tokens, alignment, device=device)
303 | if self.n_oracle_edits >= 0:
304 | self.oracle.edit_alignment(alignment, k=self.n_oracle_edits)
307 | else:
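            # k ~ Geometric(oracle_stop_p): equivalent to editing until a stop event occurs with probability oracle_stop_p after each edit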
308 | k = np.random.geometric(self.oracle_stop_p)
309 | self.oracle.edit_alignment(alignment, k=k)
314 | if return_alignment:
315 | return alignment
316 | else:
317 | return alignment.get_source_canvas()
318 |
319 | def reset(self, target_tokens=None, source_canvas=None, alignment=None, device=torch.device('cpu'),
320 | return_alignment=False):
321 | """
322 |         Reset the environment. Specify either the target and source, or the alignment (which saves the cost of
323 |         aligning the source to the target).
324 | """
325 |         self.alignment = self.oracle_edit(source_canvas, target_tokens, alignment, device=device, return_alignment=True)
326 | if return_alignment:
327 | return copy.deepcopy(self.alignment)
328 | else:
329 | return copy.deepcopy(self.alignment.get_source_canvas())
330 |
331 | def compute_reward(self, prev_canvas, agent_canvas, target_tokens):
332 | tokenizer = self.oracle.align_tokenizer
333 |
334 | prev_text = prev_canvas.render(tokenizer)
335 | target_text = Canvas(target_tokens).render(tokenizer)
336 | agent_text = agent_canvas.render(tokenizer)
337 |
338 | prev_bleu_score = sentence_bleu([target_text],
339 | prev_text, weights=self.bleu_weights)
340 | agent_bleu_score = sentence_bleu([target_text],
341 | agent_text, weights=self.bleu_weights)
342 | delta = agent_bleu_score - prev_bleu_score
343 | return delta, agent_bleu_score, prev_bleu_score
344 |
345 | def step(self, canvas, return_alignment=False, device=torch.device('cpu')):
346 | """
347 | One step of the environment. Consists of oracle edit, computing rewards etc.
348 | """
349 | prev_canvas, target_tokens = self.alignment.get_source_canvas(), self.alignment.get_target_tokens()
350 | reward, agent_score, _ = self.compute_reward(prev_canvas, canvas, target_tokens)
351 |
352 | self.alignment = self.oracle_edit(canvas=canvas, target_tokens=target_tokens, device=device, return_alignment=True)
353 | if return_alignment:
354 | return copy.deepcopy(self.alignment), reward
355 | else:
356 | return copy.deepcopy(self.alignment.get_source_canvas()), reward
357 |
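# --- Editor's note: a minimal usage sketch of EditingEnvironment (hypothetical
# names: `my_oracle` would be a WordEditOracle and `my_model_edit` any function
# returning an edited Canvas; neither is defined in this file):
#
#   env = EditingEnvironment(my_oracle, oracle_stop_p=0.5, n_oracle_edits=3)
#   canvas = env.reset(target_tokens=target_tokens, source_canvas=source_canvas)
#   for _ in range(4):                      # a few editing rounds
#       canvas = my_model_edit(canvas)      # agent proposes an edited canvas
#       canvas, reward = env.step(canvas)   # oracle responds; reward is a BLEU delta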
--------------------------------------------------------------------------------
/infosol/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pickle
3 | import itertools
4 | import json
5 | import torch
6 | import tqdm
7 |
8 | from distutils.util import strtobool
9 | from nltk.translate.bleu_score import sentence_bleu
10 |
11 | from transformers import AutoTokenizer, BartModel, BertModel
12 | from datasets import load_from_disk
13 |
14 | from infosol.alignment import Alignment, Canvas, batch_align_canvases
15 | from infosol.models.word_edit_model import BertEditor, BartS2SEditor
16 | from infosol.env import WordEditOracle, EditingEnvironment
17 |
18 | class EvaluationInstance():
19 |
20 | def __init__(self, datum):
21 | self.done = False
22 | self.rewards = []
23 | self.history = []
24 | self.actions = []
25 | self.consistency = []
26 | self.bleu_scores = []
27 | self.canvas_history = []
28 | self.canvas_actions = []
29 | self.target_tokens = datum['target_tokens']
30 | self.canvas = Canvas(datum['source_tokens'], datum['token_type_ids'])
31 | self.source_text = datum['source_text']
32 | self.target_text = datum['target_text']
33 |
34 | def to_dict(self):
35 | return {
36 | 'source': self.source_text,
37 | 'target': self.target_text,
38 | 'history': self.history,
39 | 'rewards': self.rewards,
40 | 'actions': self.actions,
41 | 'consistency': self.consistency,
42 | 'bleu_scores': self.bleu_scores,
43 | 'canvas history': self.canvas_history,
44 | 'canvas actions': self.canvas_actions
45 | }
46 |
47 |
48 | class Evaluate():
49 |
50 | """
51 | Evaluation job
52 | """
53 |
54 | @classmethod
55 | def add_args(cls, parser):
56 |         parser.add_argument('--model_path', type=str,
57 |             default='') # /path/to/WEIGHTS.bin
58 |         parser.add_argument('--out_path', type=str,
59 |             default='') # /path/to/save_file.pickle
60 |         parser.add_argument('--cuda_device', type=int, default=0)
61 |         parser.add_argument('--data_path', type=str, default='') # /path/to/data/
62 |         parser.add_argument('--max_data', type=int, default=1000)
63 |         parser.add_argument('--idf_path', type=str, default='') # /path/to/idf.pickle
64 | parser.add_argument('--n_oracle_edits', type=int, default=3)
65 | parser.add_argument('--oracle_stop_p', type=float, default=0.5)
66 | parser.add_argument('--n_episodes', type=int, default=4)
67 | parser.add_argument('--BLEU_threshold', type=float, default=1.0)
68 | parser.add_argument('--adjacent_ops', type=lambda x:bool(strtobool(x)), default=True)
69 | parser.add_argument('--complete_words', type=lambda x:bool(strtobool(x)), default=True)
70 | parser.add_argument('--contiguous_edits', type=lambda x:bool(strtobool(x)), default=True)
71 | parser.add_argument('--baseline_alignment_score', type=float, default=0.3)
72 | parser.add_argument('--sort_ops', type=str, default='sort')
73 | parser.add_argument('--bleu_ngrams', type=int, default=1)
74 | parser.add_argument('--keyword_gen', type=lambda x:bool(strtobool(x)), default=False)
75 |
76 | def setup(self, args):
77 | self.args = args
78 | self.device = torch.device('cuda:{}'.format(args.cuda_device))
79 | print('Loading Environment')
80 | self.tokenizer, self.align_model, self.oracle, self.env = self.setup_env(args)
81 | print('Loading Model')
82 | self.model = self.setup_model(args)
83 |
84 | def setup_model(self, args):
85 | raise NotImplementedError
86 |
87 | def setup_env(self, args):
88 | raise NotImplementedError
89 |
90 | def create_instance(self, datum):
91 | inst = EvaluationInstance(datum)
92 | alignment = Alignment(alignment=datum['alignment'], scores=datum['alignment_scores'])
93 | inst.canvas = self.env.reset(alignment=alignment)
94 | return inst
95 |
96 | def step_instance(self, inst, gen, alignment):
97 | """
98 |         Advance the instance by one step. This includes letting the oracle make changes, tracking metrics, etc.
99 | """
100 | prev_canvas, target_tokens = inst.canvas, inst.target_tokens
101 | agent_canvas = gen
104 |
105 | reward, agent_score, prev_score = self.env.compute_reward(prev_canvas, agent_canvas, target_tokens)
106 |
107 | if not self.args.keyword_gen:
108 | inst.canvas = self.env.oracle_edit(alignment=alignment)
109 | else:
110 | oracle_canvas = self.env.oracle_edit(alignment=alignment)
111 | tokens_, type_ids_ = oracle_canvas.tokens, oracle_canvas.type_ids
112 | keywords = [t for t,tid in zip(tokens_, type_ids_) if tid == 2]
113 | kw_type_ids = [2] * len(keywords)
114 | inst.canvas = Canvas(keywords, type_ids=kw_type_ids)
115 |
116 | tokenizer = self.env.oracle.align_tokenizer
117 | inst.history.append(prev_canvas.render(tokenizer))
118 | inst.canvas_history.append(prev_canvas.copy())
119 | inst.bleu_scores.append((prev_score, agent_score))
120 | inst.actions.append(agent_canvas.render(tokenizer))
121 | inst.canvas_actions.append(agent_canvas.copy())
122 | inst.rewards.append(reward)
123 |
124 | prev_text = prev_canvas.render(tokenizer)
125 | agent_text = agent_canvas.render(tokenizer)
126 | consistency = sentence_bleu([agent_text], prev_text, weights=(1,0,0,0))
127 | inst.consistency.append(consistency)
128 |
129 | if len(inst.history) >= self.args.n_episodes:
130 | inst.done = True
131 | if max(inst.bleu_scores[-1]) >= self.args.BLEU_threshold:
132 | inst.done = True
133 |
134 | return inst
135 |
136 | def episode(self, instances):
137 | """
138 |         One editing episode: the model makes changes, the model outputs are aligned to the targets, then the oracle makes changes.
139 |         Note that the oracle already makes changes when the instances are initialized, which is why the order here is
140 |         model-then-oracle rather than oracle-then-model as described in the paper.
141 | """
142 | canvases = [inst.canvas for inst in instances]
143 | generations = self.gen_model(canvases)
144 | clean_generations = [g.clean() for g in generations]
145 | align_model, align_tokenizer = self.env.oracle.align_model, self.env.oracle.align_tokenizer
146 | targets = [inst.target_tokens for inst in instances]
147 | alignments = batch_align_canvases(clean_generations, targets, align_model, align_tokenizer, device=self.device)
149 | instances = [self.step_instance(inst, gen, a) for inst,gen,a in zip(instances, generations, alignments)]
150 | finished_instances = [inst for inst in instances if inst.done]
151 | instances = [inst for inst in instances if not inst.done]
152 | return instances, finished_instances
153 |
154 | def evaluate_instance(self, datum):
155 | rewards = []
156 | history = []
157 | actions = []
158 | consistency = []
159 | bleu_scores = []
160 | canvas_history = []
161 | canvas_actions = []
162 | canvas = self.env.reset(alignment=Alignment(
163 | alignment=datum['alignment'],
164 | scores=datum['alignment_scores']))
165 | for e in range(self.args.n_episodes):
166 | input_text = canvas.render(self.tokenizer)
167 | target_text = self.env.alignment.get_target_canvas().render(self.tokenizer)
168 | history.append(input_text)
169 | canvas_history.append(canvas.copy())
170 | inp_score = sentence_bleu([target_text], input_text, weights=self.env.bleu_weights)
171 |
172 | canvas = self.gen_model(canvas)
173 |
174 | output_text = canvas.render(self.tokenizer)
175 | out_score = sentence_bleu([target_text], output_text, weights=self.env.bleu_weights)
176 | bleu_scores.append((inp_score, out_score))
177 | actions.append(output_text)
178 | canvas_actions.append(canvas.copy())
179 |
180 | canvas = canvas.clean()
181 | canvas, r = self.env.step(canvas, device=self.device)
182 | rewards.append(r)
183 | consistency.append(sentence_bleu([output_text], input_text, weights=(1,0,0,0)))
184 | return {
185 | 'source': datum['source_text'],
186 | 'target': datum['target_text'],
188 | 'history': history,
189 | 'rewards': rewards,
190 | 'actions': actions,
191 | 'consistency': consistency,
192 | 'bleu_scores': bleu_scores,
193 | 'canvas history': canvas_history,
194 | 'canvas actions': canvas_actions
195 | }
196 |
197 | def gen_model(self, canvases):
198 | raise NotImplementedError
199 |
214 |
215 | def run(self):
216 | self.model = self.model.eval().to(self.device)
217 | self.oracle = self.oracle.to(self.device)
218 |
219 | data = load_from_disk(self.args.data_path)['test']
220 | instances = [self.create_instance(datum) for datum in itertools.islice(data, 0, self.args.max_data)]
221 | finished_instances = []
222 | while len(instances) > 0:
223 | instances, finst_ = self.episode(instances)
224 | finished_instances.extend(finst_)
228 | results = [inst.to_dict() for inst in finished_instances]
233 | with open(self.args.out_path, 'wb') as f:
234 | pickle.dump(results, f)
235 |
236 | class EvaluateEditor(Evaluate):
237 |
238 | @classmethod
239 | def add_args(cls, parser):
240 | super().add_args(parser)
241 | parser.add_argument('--top_k', type=int, default=10)
242 | parser.add_argument('--stop_threshold', type=float, default=0.95)
243 |
244 | def gen_model(self, canvases):
245 | canvases = list(tqdm.tqdm(self.model.batch_depth_decode(
246 | canvases,
247 | device=self.device,
248 | top_k=self.args.top_k,
249 | stop_threshold=self.args.stop_threshold,
250 | return_idx=True
251 | ), total=len(canvases)))
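        # batch_depth_decode may yield results out of input order, so re-sort by the returned index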
252 | canvases = [(i,c) for c,i in canvases]
253 | canvases = [c for i,c in sorted(canvases)]
254 | return canvases
255 |
256 | class EvaluateS2S(Evaluate):
257 |
258 | @classmethod
259 | def add_args(cls, parser):
260 | super().add_args(parser)
261 | parser.add_argument('--do_sample', type=lambda x:bool(strtobool(x)), default=False)
262 | parser.add_argument('--top_p', type=float, default=0.95)
263 | parser.add_argument('--max_length', type=int, default=64)
264 | parser.add_argument('--num_beams', type=int, default=10)
265 | parser.add_argument('--length_penalty', type=float, default=1.0)
266 |
267 | def gen_model(self, canvases):
268 | return self.model.batch_generate(canvases, device=self.device,
269 | do_sample=self.args.do_sample, top_p=self.args.top_p, max_length=self.args.max_length,
270 | num_beams=self.args.num_beams, length_penalty=self.args.length_penalty)
271 |
272 | class EvaluateBart(Evaluate):
273 |
274 | def setup_env(self, args):
275 | tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')
276 | align_model = BartModel.from_pretrained('facebook/bart-base').encoder
277 | with open(self.args.idf_path, 'rb') as f:
278 | idf_dict = pickle.load(f)
279 | oracle = WordEditOracle(align_model, tokenizer, idf_dict=idf_dict,
280 | sort_ops='sort', adjacent_ops=args.adjacent_ops,
281 | baseline_score=args.baseline_alignment_score,
282 | avoid_delete=False, complete_words=args.complete_words,
283 | token_no_space=lambda x: not x.startswith('Ġ'),
284 | contiguous_edits=args.contiguous_edits)
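        # uniform weights over the first bleu_ngrams n-gram orders, e.g. bleu_ngrams=2 -> [0.5, 0.5, 0, 0]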
285 | bleu_weights = [1/args.bleu_ngrams if i < args.bleu_ngrams else 0 for i in range(4)]
286 | env = EditingEnvironment(oracle, n_oracle_edits=args.n_oracle_edits,
287 | oracle_stop_p=args.oracle_stop_p,
288 | bleu_weights = bleu_weights)
289 |
290 | return tokenizer, align_model, oracle, env
291 |
292 | class EvaluateBert(Evaluate):
293 |
294 | def setup_env(self, args):
295 | tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
296 | align_model = BertModel.from_pretrained('bert-base-uncased')
297 | with open(self.args.idf_path, 'rb') as f:
298 | idf_dict = pickle.load(f)
299 | oracle = WordEditOracle(align_model, tokenizer, idf_dict=idf_dict,
300 | sort_ops='sort', adjacent_ops=args.adjacent_ops,
301 | baseline_score=args.baseline_alignment_score,
302 | avoid_delete=False, complete_words=args.complete_words,
303 | token_no_space=lambda x: x.startswith('#'),
304 | contiguous_edits=args.contiguous_edits)
305 | bleu_weights = [1/args.bleu_ngrams if i < args.bleu_ngrams else 0 for i in range(4)]
306 | env = EditingEnvironment(oracle, n_oracle_edits=args.n_oracle_edits,
307 | oracle_stop_p=args.oracle_stop_p,
308 | bleu_weights = bleu_weights)
309 |
310 | return tokenizer, align_model, oracle, env
311 |
312 | class EvaluateBartEditor(EvaluateEditor, EvaluateBart):
313 |
314 | def setup_model(self, args):
315 | model = BertEditor(model_type='bart',
316 | tokenizer=self.tokenizer,
317 | model_file='facebook/bart-base')
318 | model.load_state_dict(torch.load(self.args.model_path))
319 | return model
320 |
321 | class EvaluateBertEditor(EvaluateEditor, EvaluateBert):
322 |
323 | def setup_model(self, args):
324 | model = BertEditor(model_type='bert',
325 | tokenizer=self.tokenizer)
326 | model.load_state_dict(torch.load(self.args.model_path))
327 | return model
328 |
329 |
330 | class EvaluateBartS2S(EvaluateS2S, EvaluateBart):
331 |
332 | def setup_model(self, args):
333 | model = BartS2SEditor(self.align_model, self.tokenizer)
334 | model.load_state_dict(torch.load(self.args.model_path))
335 | return model
336 |
337 | class EvaluateBartLarge(EvaluateEditor, EvaluateBart):
338 |
339 | def setup_model(self, args):
340 | model = BertEditor(model_type='bart-large',
341 | tokenizer=self.tokenizer,
342 | model_file='facebook/bart-large')
343 | model.load_state_dict(torch.load(self.args.model_path))
344 | return model
345 |
346 | class EvaluateNoModel(EvaluateBartEditor):
347 |
348 | def gen_model(self, canvases):
349 | return canvases
350 |
351 |
352 | if __name__ == '__main__':
353 |
354 | parser = argparse.ArgumentParser()
355 | subparsers = parser.add_subparsers()
356 |
357 | parser_barteditor = subparsers.add_parser('BartEditor')
358 | parser_barteditor.set_defaults(func=EvaluateBartEditor)
359 | EvaluateBartEditor.add_args(parser_barteditor)
360 |
361 | parser_barts2s = subparsers.add_parser('BartS2S')
362 | parser_barts2s.set_defaults(func=EvaluateBartS2S)
363 | EvaluateBartS2S.add_args(parser_barts2s)
364 |
365 | parser_berteditor = subparsers.add_parser('BertEditor')
366 | parser_berteditor.set_defaults(func=EvaluateBertEditor)
367 | EvaluateBertEditor.add_args(parser_berteditor)
368 |
369 | parser_bartlargeeditor = subparsers.add_parser('BartLargeEditor')
370 | parser_bartlargeeditor.set_defaults(func=EvaluateBartLarge)
371 | EvaluateBartLarge.add_args(parser_bartlargeeditor)
372 |
373 | parser_baseline = subparsers.add_parser('Baseline')
374 | parser_baseline.set_defaults(func=EvaluateNoModel)
375 | EvaluateNoModel.add_args(parser_baseline)
376 |
377 | args = parser.parse_args()
378 | eval_instance = args.func()
379 | eval_instance.setup(args)
380 | eval_instance.run()
381 |
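# --- Editor's note: an illustrative invocation (paths are placeholders, not
# shipped with the repo):
#
#   python infosol/evaluate.py BartEditor \
#       --model_path /path/to/WEIGHTS.bin --data_path /path/to/data \
#       --idf_path /path/to/idf.pickle --out_path /path/to/results.pickle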
--------------------------------------------------------------------------------
/infosol/modeling_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def compute_entropy(probs):
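    # row-wise Shannon entropy: for each row p of `probs`, computes -sum_i p_i * log p_i;
    # the matmul forms all pairwise cross terms, and the diagonal keeps each row paired with itself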
4 | log_probs = torch.log(probs)
5 | entropies = -torch.matmul(probs, log_probs.transpose(0,1)).diagonal()
6 | return entropies
7 |
8 | def top_p_warp(scores, top_p=0.95, filter_value=-float("Inf")):
9 | sorted_logits, sorted_indices = torch.sort(scores, descending=True)
10 | cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
11 |
12 |     # Remove tokens with cumulative probability above the threshold (the first token above it is kept)
13 | sorted_indices_to_remove = cumulative_probs > top_p
14 | # Shift the indices to the right to keep also the first token above the threshold
15 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
16 | sorted_indices_to_remove[..., 0] = 0
17 |
18 | # scatter sorted tensors to original indexing
19 | indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
20 | scores = scores.masked_fill(indices_to_remove, filter_value)
21 | return scores
22 |
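# --- Editor's note: a tiny sketch of top_p_warp (illustrative values only):
#
#   import torch
#   logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])  # (batch=1, vocab=4)
#   warped = top_p_warp(logits, top_p=0.9)
#   # tail entries become -inf, so softmax(warped) renormalizes over the
#   # smallest prefix of sorted tokens whose cumulative probability reaches 0.9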
--------------------------------------------------------------------------------
/infosol/models/__pycache__/word_edit_model.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ffaltings/InteractiveTextGeneration/febd658a91227dd88fbc5355111382b732a91647/infosol/models/__pycache__/word_edit_model.cpython-310.pyc
--------------------------------------------------------------------------------
/infosol/models/word_edit_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import tqdm
3 | import os
4 | import time
5 | import pdb
6 | import copy
7 | import torch.nn.functional as F
8 | import random
9 | import operator
10 | import numpy as np
12 | import torch.multiprocessing as mp
13 |
14 | from queue import PriorityQueue
15 | from infosol.alignment import *
16 | from infosol.decoding import ParallelDecodingMixin
17 | from transformers import BertModel, AutoTokenizer, T5ForConditionalGeneration, BertConfig, AutoConfig, AutoModel
18 | from transformers.models.bart.modeling_bart import BartEncoder, BartModel, BartForConditionalGeneration, BartDecoder
19 | from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
20 |
21 | MP_TIMEOUT = 2 #2 seconds timeout
22 |
23 | def log_factorial(n):
24 | if n <= 1: return 0.
25 | return np.log(np.arange(2,n+1)).sum()
26 |
27 | class BeamSearchNode():
28 |
29 | """
30 | Class for beam search
31 | """
32 |
33 | def __init__(self, prevNode, canvas, stop, logp, length):
34 |         self.prevNode = prevNode
35 | self.canvas = canvas
36 | self.stop = stop
37 | self.logp = logp
38 | self.len = length
39 |
40 | def eval(self, len_normalize=False):
41 | if len_normalize:
42 | return self.logp + log_factorial(self.len-1)
43 | else: return self.logp
44 |
45 | def __eq__(self, other):
46 | return self.eval() == other.eval()
47 |
48 | def __lt__(self, other):
49 | return self.eval() < other.eval()
50 |
51 | def __gt__(self, other):
52 | return self.eval() > other.eval()
53 |
54 | def __le__(self, other):
55 | return self.eval() <= other.eval()
56 |
57 | def __ge__(self, other):
58 | return self.eval() >= other.eval()
59 |
60 | def noise_alignment(alignment, vocab_sampler):
61 | """
62 | Make a random edit to the source canvas and update the alignment
63 | """
64 | canvas = alignment.get_source_canvas()
65 | positions = [-1] + [i for i in range(len(canvas)) if canvas.type_ids[i] < 3]
66 | position = np.random.choice(positions)
67 | if position >= 0:
68 | operation = np.random.randint(0, 3)
69 | else:
70 | operation = 0
71 |
72 | if operation < 2: # insert and substitute
73 | token = vocab_sampler.get()
74 | alignment.operate(position, operation, token)
75 | elif operation == 2: # delete a token
76 | alignment.operate(position, operation, None)
77 |
78 | def sample_trajectory(alignment, tokenizer, vocab_sampler, noise_prob=0.2,
79 | max_traj_length=64):
80 |
81 | """
82 | Sample a trajectory for training. Take the alignment and randomly push it forward
83 | a few steps (make the source more like the target). Inject noise by simulating a noisy
84 | process where you make errors when pushing the alignment forward.
85 | """
86 |
87 | ops = alignment.get_non_const_ops()
88 | ops = np.random.permutation(ops)
89 | ops_len = len(ops)
90 | n_errors = 0
91 | n_ops = 0
92 | trajectory = []
93 | trajectory.append((n_errors, n_ops))
94 | for i in range(max_traj_length):
95 | if random.random() > noise_prob:
96 | if n_ops == ops_len and n_errors == 0: break
97 | if random.random() <= n_errors / (n_errors + ops_len - n_ops): # you might fix an error
98 | n_errors -= 1
99 | else: n_ops += 1
100 | else:
101 | n_errors += 1
102 | trajectory.append((n_errors, n_ops))
103 | traj_length = len(trajectory)
104 | traj_idx = np.random.randint(0,traj_length) # sample state from trajectory
105 | n_errors, n_ops = trajectory[traj_idx]
106 | alignment.push_forward(ops[:n_ops])
107 | for i in range(n_errors):
108 | noise_alignment(alignment, vocab_sampler)
109 | return_ops = alignment.get_non_const_ops()
110 | actions = alignment.get_actions(return_ops)
111 | if len(actions) == 0: actions = [(1, None, None, None)]
112 | else: actions = [(0, a[0], a[1], tokenizer.convert_tokens_to_ids(a[2])) for a in actions]
113 |
114 | canvas = alignment.get_source_canvas()
115 |
116 | return canvas, actions, traj_length
117 |
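# --- Editor's note: a small worked example of the trajectory walk above
# (illustrative numbers). With ops_len = 3 and noise_prob = 0.2, one possible
# sequence of (n_errors, n_ops) states is
#   (0,0) -> (0,1) -> (1,1) -> (1,2) -> (0,2) -> (0,3)
# i.e. one clean op, one injected error, another clean op, the error fixed,
# then the last op. A state is drawn uniformly from this list; the alignment is
# pushed forward n_ops steps, and n_errors random edits are injected on top.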
118 | class mlp_head(torch.nn.Module):
119 | """
120 | MLP head to put on top of BERT
121 | """
122 |
123 | def __init__(self, hidden_size, out_size, layer_norm_eps):
124 | super().__init__()
125 | self.head = torch.nn.Sequential(
126 | torch.nn.Linear(hidden_size, hidden_size),
127 | torch.nn.GELU(),
128 | torch.nn.LayerNorm(hidden_size, eps=layer_norm_eps),
129 | torch.nn.Linear(hidden_size, out_size, bias=False)
130 | )
131 |
132 | def forward(self, inp):
133 | out = self.head(inp)
134 | return out
135 |
136 | def draw_multi(p_values):
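    # draw one index from the categorical distribution given by p_values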
137 | return np.where(np.random.multinomial(1, p_values) == 1)[0][0]
138 |
139 | class BertEditor(torch.nn.Module, ParallelDecodingMixin): #TODO: clean up, split out Bart model
140 | """
141 | Main editor model
142 | """
143 |
144 | def __init__(self, model_type = 'bert', model_file='bert-base-uncased', is_decoder=False,
145 | type_vocab_size=2, use_type_ids=True, training_noise=0.2,
146 | tokenizer=None, vocab_sampler=None):
147 | """
148 |         :param model_file: checkpoint used to initialize the base model, defaults to a pretrained huggingface model
149 |         :param use_type_ids: whether to feed canvas token type ids to the base model
150 |         :param is_decoder: whether to use a decoder architecture for BERT
151 |         :param training_noise: probability of injecting a noisy edit at each step when sampling training trajectories
152 | """
153 | super().__init__()
154 | self.model_type = model_type
155 | self.use_type_ids = use_type_ids
156 | self.training_noise = training_noise
157 |
158 | self.tokenizer = tokenizer
159 | vocab_size = len(self.tokenizer)
160 | self.vocab_sampler = vocab_sampler
161 | self.delete_type_id = 3
162 |
163 | if self.model_type == 'bert':
164 | self.base_model = BertModel.from_pretrained(model_file)
165 | self.base_model.resize_token_embeddings(vocab_size)
166 | self.base_model.embeddings.token_type_embeddings = self.base_model._get_resized_embeddings(
167 | self.base_model.embeddings.token_type_embeddings, 5) # resize token type embeddings
168 | hidden_size = self.base_model.config.hidden_size
169 | layer_norm_eps = self.base_model.config.layer_norm_eps
170 | elif self.model_type in ['bart', 'bart-large']:
171 | config = AutoConfig.from_pretrained(model_file)
172 | config.n_type_ids = 5
173 | self.base_model = BartForConditionalGenerationTType.from_pretrained(
174 | model_file, config=config).model.encoder
175 | hidden_size = config.hidden_size
176 | layer_norm_eps = 1e-12
177 |
178 | self.emission_head = mlp_head(hidden_size, 1, layer_norm_eps) # decides when to "emit" a document/stop
179 | self.location_head = mlp_head(hidden_size, 1, layer_norm_eps) # decides where to edit
180 | self.operation_head = mlp_head(hidden_size, 3, layer_norm_eps) # decides what operation to do
181 | self.vocabulary_head = mlp_head(hidden_size, vocab_size * 2, layer_norm_eps) # decides what token to insert/substitute for
182 |
183 | def forward(self, input_ids, attention_mask, encoder_states=None, encoder_mask=None, token_type_ids=None,
184 | return_hidden_states=False):
185 |
186 | N, L = input_ids.shape
187 |
188 | if encoder_states is None:
189 | encoder_states = torch.zeros((N, 1, self.base_model.config.hidden_size)).to(input_ids.device)
190 | encoder_mask = torch.zeros((N, 1), dtype=torch.int32).to(input_ids.device)
191 |
192 | if not self.use_type_ids:
193 | token_type_ids = None
194 |
195 | if token_type_ids is None:
196 | token_type_ids = torch.zeros_like(attention_mask)
197 | token_type_ids = torch.cat([
198 | torch.zeros_like(encoder_mask),
199 | token_type_ids],
200 | dim=1)
201 | attention_mask = torch.cat(
202 | [encoder_mask, attention_mask], dim=1)
203 |
204 | # TODO: split into separate model classes
205 | if self.model_type == 'bert':
206 | inputs_embeds = self.base_model.embeddings.word_embeddings(input_ids) #embed token ids
207 | inputs_embeds = torch.cat([encoder_states, inputs_embeds], dim=1) #concatenate with encoder states
208 | inputs_embeds = self.base_model.embeddings(inputs_embeds=inputs_embeds,
209 | token_type_ids=token_type_ids)
210 | out = self.base_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
211 | elif self.model_type == 'bart' or self.model_type == 'bart-large':
212 | inputs_embeds = self.base_model.embed_tokens(input_ids)
213 | inputs_embeds = torch.cat([encoder_states, inputs_embeds], dim=1) #concatenate with encoder states
214 | out = self.base_model(inputs_embeds=inputs_embeds, token_type_ids=token_type_ids,
215 | attention_mask=attention_mask)
216 | prefix_len = encoder_states.shape[1]
217 |
218 | hidden_states = out[0][:, prefix_len:, :]
219 | emission_out = self.emission_head(hidden_states[:,0,:]).squeeze(-1)
220 | location_out = self.location_head(hidden_states).squeeze(-1)
221 | operation_out = self.operation_head(hidden_states)
222 | vocabulary_out = self.vocabulary_head(hidden_states)
223 | vocabulary_out = vocabulary_out.reshape(N, L, 2, -1)
224 | V = vocabulary_out.shape[-1]
225 |
228 |
229 | # Location mask
230 | # Mask out padding
231 | padding_mask = 1 - attention_mask[:, prefix_len:]
232 | # Mask out stricken tokens
233 | stricken_mask = (token_type_ids[:, prefix_len:] >= self.delete_type_id) # cannot insert, delete or sub stricken-out tokens <-- may want to allow substitutions?
234 | # Mask out EOS tokens
235 | eos_mask = torch.zeros_like(attention_mask[:, prefix_len:])
236 | attention_lengths = attention_mask[:, prefix_len:].sum(dim=1, dtype=torch.int32)
237 | for i in range(attention_lengths.shape[0]):
238 | eos_mask[i, attention_lengths[i].item()-1] = 1
239 | location_mask = torch.clamp(padding_mask + stricken_mask + eos_mask, max=1)
240 | location_out += location_mask * -1e20
241 |
242 | # Operation mask
243 | # BOS mask
244 | bos_mask = torch.zeros_like(operation_out)
245 | bos_mask[:, 0, 1:] = 1
246 | operation_mask = torch.clamp(bos_mask, max=1)
247 | operation_out += operation_mask * -1e20
248 |
249 | # Vocabulary mask
250 | # Can't sub token for same one
251 | sub_mask = torch.zeros_like(vocabulary_out)
252 | sub_mask[:, :, 1, :] = F.one_hot(input_ids, num_classes=V)
253 | vocabulary_mask = torch.clamp(sub_mask, max=1)
254 | vocabulary_out += vocabulary_mask * -1e20
255 |
256 | return_out = (emission_out, location_out, operation_out, vocabulary_out)
257 | if return_hidden_states:
258 | return_out = return_out + (hidden_states,)
259 | return return_out
260 |
261 | def compute_instance_loss(self, emission_out, location_out, operation_out, vocabulary_out,
262 | emission_idx, location_idxs, operation_idxs, vocabulary_idxs, ops_len):
263 | """
264 | Calculate cross entropy loss based on model outputs and gold targets
265 | """
266 | n_actions = max(emission_out.shape[0] - 1, 1)
267 |
268 | emission_loss = F.binary_cross_entropy(torch.sigmoid(emission_out), emission_idx.float())
269 | emission_mask = (1-emission_idx).type(torch.int32)
270 |
271 | location_loss = F.cross_entropy(location_out, location_idxs, reduction='none')
272 | location_loss = (location_loss * emission_mask).sum()
273 |
274 |         # TODO: should technically mask here for bos, only viable operation is insertion
275 | operation_loss = F.cross_entropy(
276 | operation_out.gather(1,
277 | location_idxs.reshape(-1,1,1).expand(-1,-1,operation_out.size(-1))).squeeze(1),
278 | operation_idxs, reduction='none')
279 | # multiply by a mask so that don't count the loss in cases where the model should stop
280 | operation_loss = (operation_loss * emission_mask).sum()
281 |
282 | # TODO: should also technically mask here for sub identical token, since not allowed
283 | vocabulary_loss = F.cross_entropy(
284 | vocabulary_out.gather(1,
285 | location_idxs.reshape(-1,1,1,1).expand(-1,-1,2,vocabulary_out.size(-1))).squeeze(1).gather(1,
286 | torch.clamp(operation_idxs, max=1).reshape(-1,1,1).expand(-1,1,vocabulary_out.size(-1))).squeeze(1),
287 | vocabulary_idxs, reduction='none') # <- NOT averaged
288 |         # only use vocabulary loss for substitution and insertion operations (deletion mask)
289 | deletion_mask = 1 - (operation_idxs == 2).type(torch.long) # <- 2 = deletion operation
290 | vocabulary_loss = (vocabulary_loss * deletion_mask * emission_mask).sum()
291 |
292 | # scaling
293 | total_loss = emission_loss + location_loss + operation_loss + vocabulary_loss # this is already averaged by n-t
294 | total_loss = -log_factorial(ops_len) + ops_len * total_loss / n_actions
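        # (editor's note) a plausible reading of this normalization: the averaged per-action loss is
        # rescaled to the ops_len remaining edits, and log(ops_len!) credits the ops_len! equivalent
        # orderings in which those edits could be applied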
295 |
296 | return total_loss, emission_loss, location_loss, operation_loss, vocabulary_loss
297 |
298 | def compute_loss(self, input_ids, attention_mask, actions, prefix_lengths, ops_lengths,
299 | encoder_states=None, encoder_mask=None, token_type_ids=None):
300 |         #TODO: get rid of prefix_lengths (outdated)
301 | """
302 | Computes the loss for a batch. Expands the batch when there are multiple actions to compute loss over all actions.
303 | Thus, for each entry in the batch, creates N entry for each action from the oracle that the model should predict
304 | """
305 | emission_out, location_out, operation_out, vocabulary_out = self.forward(input_ids, attention_mask,
306 | encoder_states=encoder_states, encoder_mask=encoder_mask, token_type_ids=token_type_ids)[:4]
307 |
310 |
311 | batch_size = location_out.size(0)
312 | batch_loss, emission_loss, location_loss, operation_loss, vocabulary_loss = [
313 | torch.tensor(0, dtype=torch.double, device=location_out.device)]*5 # create tensors to track loss
314 | # the number of actions that the oracle returns for each state could differ, so we have to iterate over the batch
315 | for e_out, l_out, o_out, v_out, (e_idxs, l_idxs, o_idxs, v_idxs), p_len, o_len in zip(
316 | emission_out, location_out, operation_out, vocabulary_out, actions, prefix_lengths, ops_lengths):
317 | n_ops = l_idxs.size(0)
318 | instance_loss, e_loss, l_loss, o_loss, v_loss = self.compute_instance_loss(
319 | # expand by number of operations that the oracle returned for these inputs
320 | e_out.unsqueeze(0).expand(n_ops),
321 | # prefix here is because we are still including a source canvas in the model inputs, which differs for each instance
322 | l_out[p_len:-1,...].unsqueeze(0).expand(n_ops, -1),
323 | o_out[p_len:-1,...].unsqueeze(0).expand(n_ops, -1, -1),
324 | v_out[p_len:-1,...].unsqueeze(0).expand(n_ops, -1, -1, -1),
325 | e_idxs, l_idxs, o_idxs, v_idxs, o_len
326 | )
327 |
328 | batch_loss, emission_loss, location_loss, operation_loss, vocabulary_loss = [
329 | loss + l / batch_size for loss,l in zip(
330 | [batch_loss, emission_loss, location_loss, operation_loss, vocabulary_loss],
331 | [instance_loss, e_loss, l_loss, o_loss, v_loss])
332 | ] # aggregate loss
333 |
334 | metrics = {
335 | 'loss': batch_loss.item(),
336 | 'em_loss': emission_loss.item(),
337 | 'loc_loss': location_loss.item(),
338 | 'op_loss': operation_loss.item(),
339 | 'voc_loss': vocabulary_loss.item()
340 | }
341 |
342 | return batch_loss, metrics
343 |
344 | def prep_canvas(self, canvas, type_ids=None):
345 | """
346 |         Prepares a canvas to feed into the model
347 | """
348 | ids = self.tokenizer.convert_tokens_to_ids(canvas.tokens) # convert tokens to ids
349 | ids = [self.tokenizer.cls_token_id] + ids + [self.tokenizer.sep_token_id]
350 | type_ids = [0] + canvas.type_ids + [0]
351 | return ids, type_ids, 0
352 |
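    # --- Editor's note: e.g. a canvas with tokens ['a', 'b'] and type ids [0, 0] becomes
    # ids [CLS, a, b, SEP] with type ids [0, 0, 0, 0]; the trailing 0 is the (unused) prefix length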
353 | def prep_canvases(self, canvases, device=torch.device('cpu')):
354 | """
355 | Preps canvases into a batch to feed into the model
356 | """
357 | canvases, type_ids, prefix_lengths = zip(*[self.prep_canvas(c) for c in canvases])
358 | max_length = np.max([len(c) for c in canvases])
359 | input_ids = torch.zeros((len(canvases), max_length), dtype=torch.long, device=device)
360 | attention_mask = torch.zeros(input_ids.shape, dtype=torch.int32, device=device)
361 | token_type_ids = torch.zeros_like(attention_mask, device=device)
362 |
363 | for i,(c, tids) in enumerate(zip(canvases, type_ids)):
364 | input_ids[i, :len(c)] = torch.tensor(c)
365 | token_type_ids[i, :len(c)] = torch.tensor(tids)
366 | attention_mask[i, :len(c)] = 1
367 |
368 | return input_ids, token_type_ids, attention_mask, prefix_lengths
369 |
370 | def prep_action(self, action):
371 | """
372 | Prep action for model
373 | """
374 | emission_idx = action[0]
375 | if emission_idx == 0:
376 | location_idx = action[1] + 1 # + 1 for sentinel token
377 | operation_idx = action[2]
378 | token_idx = action[3] if action[3] is not None else 0
379 | else:
380 | location_idx, operation_idx, token_idx = 0,0,0 # dummy values
381 | return emission_idx, location_idx, operation_idx, token_idx
382 |
383 | def prep_actions(self, actions, device=torch.device('cpu')): # preps sampled actions for a SINGLE INSTANCE
384 | """
385 |         Prep the set of actions returned by the oracle for the model. This is for a single state, not a batch.
386 | """
387 | return [torch.tensor(a, dtype=torch.long, device=device) for a in zip(
388 | *[self.prep_action(a) for a in actions])]
389 |
390 | def prep_batch(self, batch, device=torch.device('cpu'), **kwargs):
391 | """
392 | Prepares whole batch to feed into the model
393 | """
394 | if self.training_noise > 0 and self.vocab_sampler is None:
395 | raise ValueError('vocab sampler cannot be None when using training noise')
396 | canvases, actions, ops_lengths = zip(*[sample_trajectory(
397 | alignment, self.tokenizer, self.vocab_sampler, noise_prob = self.training_noise, **kwargs
398 | ) for alignment in batch])
399 |
400 | input_ids, token_type_ids, attention_mask, prefix_lengths = self.prep_canvases(
401 | canvases, device=device)
402 |
403 | actions = [self.prep_actions(a, device) for a in actions]
404 | # no encoder input in this version
405 | encoder_states, encoder_mask = None, None
406 |
407 | return input_ids, attention_mask, actions, prefix_lengths, ops_lengths, encoder_states, encoder_mask, token_type_ids
408 |
409 | def move_batch(self, batch, device):
410 | input_ids, attention_mask, actions, prefix_lengths, ops_lengths, encoder_states, encoder_mask, token_type_ids = batch
411 | input_ids, attention_mask, token_type_ids = input_ids.to(device), attention_mask.to(device), token_type_ids.to(device)
412 | #encoder_states, encoder_mask = encoder_states.to(device), encoder_mask.to(device)
413 | actions = [(a[0].to(device), a[1].to(device), a[2].to(device), a[3].to(device)) for a in actions]
414 | return input_ids, attention_mask, actions, prefix_lengths, ops_lengths, encoder_states, encoder_mask, token_type_ids
415 |
416 | ### Decoding
417 |
418 | def convert_action(self, action):
419 | """
420 | Get action from model representation of action
421 | """
422 | stop, ref_idx, op_idx, tok_idx = action
423 |         if ref_idx is not None: ref_idx -= 1
424 |         if tok_idx is not None:
425 | tok_idx = self.tokenizer.convert_ids_to_tokens(int(tok_idx))
426 | return stop, ref_idx, op_idx, tok_idx
427 |
428 | @staticmethod
429 | def sample_from_logits(actions, logps):
430 | action_ps = np.exp(np.asarray(logps))
431 | action_ps = action_ps / np.sum(action_ps)
432 | action_idx = draw_multi(action_ps)
433 | return actions[action_idx], logps[action_idx]
434 |
435 | def forward_canvases(self, canvases, device=torch.device('cpu'), move_to_cpu=False):
436 | """
437 | Forward pass taking canvases as inputs
438 | """
439 | input_ids, token_type_ids, attention_mask, prefix_lengths = self.prep_canvases(
440 | canvases, device=device)
441 | with torch.no_grad():
442 | emission_out, location_out, operation_out, vocabulary_out = self.forward(
443 | input_ids, attention_mask, token_type_ids = token_type_ids)[:4]
444 | if move_to_cpu:
445 | emission_out, location_out, operation_out, vocabulary_out =\
446 | emission_out.cpu(), location_out.cpu(), operation_out.cpu(), vocabulary_out.cpu()
447 | B = emission_out.shape[0]
448 | out = []
449 | for i in range(B):
450 | plen = prefix_lengths[i]
451 | e_out, l_out, op_out, v_out = emission_out[i], location_out[i], operation_out[i], vocabulary_out[i]
453 | out.append((e_out, l_out[plen:], op_out[plen:], v_out[plen:]))
454 | return out
455 |
456 | def forward_canvas(self, canvas, device=torch.device('cpu')):
458 | return self.forward_canvases(canvas, device=device)
459 |
460 | def get_topk_actions(self, canvas, device=torch.device('cpu'), top_k=1, **kwargs):
461 | """
462 | Get top actions from model predictions
463 | """
464 | input_ids, token_type_ids, attention_mask, prefix_lengths = self.prep_canvases(
465 | [canvas], device=device)
466 | plen = prefix_lengths[0]
467 | with torch.no_grad():
468 | emission_out, location_out, operation_out, vocabulary_out = self.forward(
469 | input_ids, attention_mask, token_type_ids = token_type_ids)[:4]
470 | actions, logps, stop_lp = get_topk_action_logps(
471 | emission_out, location_out[:, plen:], operation_out[:, plen:],
472 | vocabulary_out[:, plen:], top_k=top_k, **kwargs)
473 | return actions, logps, stop_lp
474 |
475 | @staticmethod
476 | def ancestral_sample(emission_out, location_out, operation_out, vocabulary_out):
477 | """
478 | Ancestral sampling: sample next action based on previous predictions
479 | """
480 | stop_p = torch.sigmoid(emission_out)
481 | if np.random.random() <= stop_p:
482 | return (1, None, None, None)
483 |
484 | def sample_from_out(out):
485 | prob = torch.softmax(out, dim=-1)
486 | return torch.multinomial(prob, 1).item()
487 |
488 | location = sample_from_out(location_out)
489 | op = sample_from_out(operation_out[location])
490 | if op == 2: return (0, location, op, None)
491 | token = sample_from_out(vocabulary_out[location,op])
492 | return (0, location, op, token)
493 |
494 | @staticmethod
495 | def get_topk_action_logps(stop_out, location_out, action_out, vocab_out, top_k=10, exclude_stop=False): # warning: this will modify the outputs from the model!
496 | """
497 | Searches the model's output distribution for the top k actions. The actions the model can take
498 | are organized in a tree. The first level corresponds to deciding whether or not to edit, then where to
499 | edit, then what operation to take, and finally what token to use for insertions/substitutions.
500 |         We search the tree maintaining a list of the most likely actions, and since the probability of a node is never
501 |         higher than that of its parent (the probabilities multiply and are all at most 1), we can easily determine
502 | which nodes to explore by exploring them in sorted order.
503 | """
504 | class ActionBuffer:
505 |
506 | def __init__(self, top_k):
507 | self.top_k = top_k
508 | self.actions = [None] * top_k
509 | self.logps = [-1e20] * top_k
510 |
511 | def insert(self, action, logp):
512 | # find insertion index
513 | for i in range(self.top_k-1, -2, -1):
514 | if i == -1:
515 | break
516 | if logp < self.logps[i]:
517 | break
518 |
519 | self.logps = self.logps[:i+1] + [logp] + self.logps[i+1:]
520 | self.actions = self.actions[:i+1] + [action] + self.actions[i+1:]
521 |
522 | self.logps, self.actions = self.logps[:self.top_k], self.actions[:self.top_k]
523 |
524 | def min_lp(self):
525 | return self.logps[-1]
526 |
537 | def iter_np_sort(sort_obj, descending=True):
538 | sorted_indices = np.argsort(sort_obj)
539 | if descending: sorted_indices = reversed(sorted_indices)
540 |             for i in sorted_indices:
541 | yield sort_obj[i], i
542 |
543 |         # buffer for the top-k actions and their logps
545 | action_buffer = ActionBuffer(top_k)
546 |
547 | levels = (stop_out, location_out, action_out, vocab_out)
548 | level_logps = np.array([0,0,0,0], dtype=np.float32)
549 | # stopping probabilities
550 | stop_p = torch.sigmoid(stop_out)
551 | stop_lp = torch.log(stop_p)
552 | if not exclude_stop:
553 | action_buffer.insert((1, None, None, None), stop_lp.item())
554 | level_logps[0] = np.log(1-stop_p)
555 |
556 | def search_level(level, idxs):
557 | if level > 3: return
558 | out = levels[level].unsqueeze(0)[idxs]
559 | logprobs = F.log_softmax(out, dim=-1)
560 | for lprob, i in iter_np_sort(logprobs.numpy()):
561 | # for lprob, i in iter_torch_sort(torch.sort(logprobs, descending=True)):
562 | level_logps[level] = lprob
563 | cumlp = level_logps[:level+1].sum()
564 | if cumlp <= action_buffer.min_lp():
565 | return
566 | if level == 3:
567 | action_buffer.insert(idxs + (i,), cumlp)
568 | elif level == 2 and i==2: #deletion action is special because it doesn't specify a token
569 | action_buffer.insert(idxs + (i, None,), cumlp)
570 | else:
571 | search_level(level+1, idxs + (i,))
572 |
573 | search_level(1, (0,))
574 |
575 | return action_buffer.actions, action_buffer.logps, stop_lp.item()
576 |
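    # --- Editor's note: a sketch of the best-first search above (hypothetical
    # numbers). With top_k=2, the buffer starts with (stop, log p(stop)); then
    # locations are visited in sorted order, and within each location the
    # operations and tokens likewise. A branch is abandoned as soon as its
    # cumulative log-prob drops below the current k-th best in the buffer, so
    # most of the vocabulary level is never explored.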
577 | @staticmethod
578 | def sample(model_out, top_k=None):
579 | if top_k is None:
580 | action = BertEditor.ancestral_sample(*model_out)
581 | else:
582 | actions, logps, stop_lp = BertEditor.get_topk_action_logps(*model_out, top_k=top_k)
583 | action, _ = BertEditor.sample_from_logits(actions, logps)
584 | return action
585 |
586 | def edit(self, canvas, top_k=None, device=torch.device('cpu')):
587 | """
588 | Make an edit to a canvas
589 | """
590 |         model_out = self.forward_canvases([canvas], device=device, move_to_cpu=True)[0]
591 | action = BertEditor.sample(model_out, top_k=top_k)
592 | action = self.convert_action(action)
593 | if action[0]: return True
594 | canvas.operate(*action[1:], agent=0)
595 | return False
596 |
713 |
714 | @staticmethod
715 | def canvas_len(canvas):
716 | return len(canvas)
717 |
718 | def decode_(self, model_out, canvas, state):
719 | """
720 | Regular decoding with sampling
721 | """
725 | top_k, max_iter, cur_iter = state
726 | state = top_k, max_iter, cur_iter + 1
727 | action = BertEditor.sample(model_out, top_k=top_k)
728 | if action[0] or cur_iter > max_iter:
730 | return canvas, state, True
731 | else:
732 | action = self.convert_action(action)
733 | canvas.operate(*action[1:], agent=0)
735 | return canvas, state, False
736 |
737 | def batch_decode(self, canvases, top_k=10, max_iter=32, **kwargs):
738 | state = top_k, max_iter, 0
739 | for out in self.decode_loop(canvases, self.decode_, state=state, **kwargs):
740 | yield out
741 |
742 | def decode(self, canvas, **kwargs):
743 | return next(self.batch_decode([canvas], **kwargs))
744 |
745 | def depth_decode_(self, model_out, canvas, state):
746 | """
747 |         Depth decoding: ignores stop actions and continues editing until a stopping condition is reached, then
748 |         returns whichever canvas had the highest stopping probability.
749 | """
751 | top_k, stop_threshold, max_iter, cur_iter, top_canvas, top_lp = state
752 |
753 | actions, logps, stop_lp = BertEditor.get_topk_action_logps(*model_out, top_k=top_k, exclude_stop=True)
754 | if stop_lp > top_lp:
755 | top_lp = stop_lp
756 | top_canvas = canvas.copy()
757 | state = (top_k, stop_threshold, max_iter, cur_iter + 1, top_canvas, top_lp)
758 | if np.exp(stop_lp) >= stop_threshold or cur_iter > max_iter:
759 | return top_canvas, state, True
761 | else:
762 | action, _ = BertEditor.sample_from_logits(actions, logps)
763 | action = self.convert_action(action)
764 | canvas.operate(*action[1:], agent=0)
765 | return canvas, state, False
767 |
768 | def batch_depth_decode(self, canvases, top_k=10, stop_threshold=0.95, max_iter=64, **kwargs):
769 | state = top_k, stop_threshold, max_iter, 0, None, -1e20
770 | for out in self.decode_loop(canvases, self.depth_decode_, state=state, **kwargs):
771 | yield out
772 |
773 | def depth_decode(self, canvas, **kwargs):
774 | return next(self.batch_depth_decode([canvas], **kwargs))
800 |
801 | def beam_decode(self, canvas, top_k=1, beam_width=10, max_length=32, len_normalize=False, device=torch.device('cpu')):
802 | """
803 | Beam decoding, not actually used
804 | """
805 |
806 | if top_k > beam_width:
807 | raise ValueError('top_k cannot be greater than beam_width')
808 |
809 | decoded_batch = []
810 | endnodes = PriorityQueue()
811 |
812 | node = BeamSearchNode(None, canvas, False, 0, 0)
813 | nodes = PriorityQueue()
814 |
815 | nodes.put((-node.eval(len_normalize=len_normalize), node))
816 | qsize = 1
817 |
818 | for i in range(max_length):
819 | if nodes.empty(): break
820 | nextnodes = PriorityQueue()
821 | # go through nodes at current depth
822 | assert nodes.qsize() <= beam_width, nodes.qsize()
823 | while not nodes.empty():
824 | score, n = nodes.get()
825 | canvas = n.canvas
826 |
827 | if n.stop:
828 | endnodes.put((score, n))
829 | continue
830 |
831 | actions, logps, _ = self.get_topk_actions(canvas, device, top_k=beam_width)
832 | for a,lp in zip(actions, logps):
833 | edited_canvas = canvas.copy()
834 | a = self.convert_action(a)
835 | if not a[0]: edited_canvas.operate(*a[1:], agent=0)
836 |
837 |
838 | node = BeamSearchNode(n, edited_canvas, a[0], n.logp + lp, n.len + 1)
839 | score = -node.eval(len_normalize=len_normalize)
840 | nextnodes.put((score, node))
841 |
842 | for _ in range(beam_width):
843 | if nextnodes.empty(): break
844 | score, nn = nextnodes.get()
845 | nodes.put((score, nn))
846 |
847 | if endnodes.empty():
848 | for _ in range(top_k):
849 | score,n = nodes.get()
850 | endnodes.put((score, n))
851 |
852 | for _ in range(top_k):
853 | if endnodes.empty(): break
854 | score, n = endnodes.get()
855 | decoded_batch.append(n.canvas)
856 |
857 | return decoded_batch
858 |
859 | class BartEncoderTType(BartEncoder):
860 |
861 | """
862 | Token type ids for BART (not supported natively)
863 | """
864 |
865 | def __init__(self, config, embed_tokens):
866 | super().__init__(config, embed_tokens)
867 |
868 | self.token_type_embeddings = torch.nn.Embedding(config.n_type_ids, config.d_model)
869 |
870 | def forward(self, input_ids=None, inputs_embeds=None, token_type_ids=None, **kwargs):
871 |
872 | if input_ids is not None and inputs_embeds is not None:
873 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
874 | elif input_ids is not None:
875 | input_shape = input_ids.size()
876 | input_ids = input_ids.view(-1, input_shape[-1])
877 | elif inputs_embeds is not None:
878 | input_shape = inputs_embeds.size()[:-1]
879 | else:
880 | raise ValueError("You have to specify either input_ids or inputs_embeds")
881 |
882 | if token_type_ids is None:
883 | token_type_ids = torch.zeros_like(input_ids)
884 |
885 | if inputs_embeds is None:
886 | inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
887 |
888 | token_type_embeds = self.token_type_embeddings(token_type_ids)
889 | inputs_embeds += token_type_embeds
890 |
891 | return super().forward(inputs_embeds=inputs_embeds, **kwargs)
892 |
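# --- Editor's note: a minimal sketch of how the token-type variant is built
# (this mirrors the construction in BertEditor.__init__ above):
#
#   from transformers import AutoConfig
#   config = AutoConfig.from_pretrained('facebook/bart-base')
#   config.n_type_ids = 5                 # size of the canvas token-type vocabulary
#   model = BartForConditionalGenerationTType.from_pretrained(
#       'facebook/bart-base', config=config)
#   encoder = model.model.encoder         # a BartEncoderTType accepting token_type_ids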
893 | class BartModelTType(BartModel):
894 |
895 | def __init__(self, config):
896 | super(BartModel, self).__init__(config)
897 |
898 | padding_idx, vocab_size = config.pad_token_id, config.vocab_size
899 | self.shared = torch.nn.Embedding(vocab_size, config.d_model, padding_idx)
900 |
901 | self.encoder = BartEncoderTType(config, self.shared)
902 | self.decoder = BartDecoder(config, self.shared)
903 |
904 | def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None,
905 | output_attentions=None, output_hidden_states=None, return_dict=None,
906 | token_type_ids=None, encoder_outputs = None, **kwargs):
907 |
908 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
909 | output_hidden_states = (
910 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
911 | )
912 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
913 |
914 | if encoder_outputs is None:
915 | encoder_outputs = self.encoder(
916 | input_ids=input_ids,
917 | attention_mask=attention_mask,
918 | head_mask=head_mask,
919 | inputs_embeds=inputs_embeds,
920 | output_attentions=output_attentions,
921 | output_hidden_states=output_hidden_states,
922 | return_dict=return_dict,
923 | token_type_ids=token_type_ids
924 | )
925 | elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
926 | encoder_outputs = BaseModelOutput(
927 | last_hidden_state=encoder_outputs[0],
928 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
929 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
930 | )
931 |
932 | return super().forward(input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask,
933 | inputs_embeds=inputs_embeds, output_attentions=output_attentions,
934 | output_hidden_states=output_hidden_states, return_dict=return_dict,
935 | encoder_outputs=encoder_outputs, **kwargs)
936 |
937 | def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
938 | """
939 | Shift input ids one token to the right.
940 | """
941 | shifted_input_ids = input_ids.new_zeros(input_ids.shape)
942 | shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
943 | shifted_input_ids[:, 0] = decoder_start_token_id
944 |
945 | if pad_token_id is None:
946 | raise ValueError("self.model.config.pad_token_id has to be defined.")
947 | # replace possible -100 values in labels by `pad_token_id`
948 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
949 |
950 | return shifted_input_ids
951 |
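# Worked example for shift_tokens_right (a sketch, with pad_token_id=1 and
# decoder_start_token_id=2):
#
#   labels            = [[5, 6, 7, -100]]
#   shifted_input_ids = [[2, 5, 6, 7]]   # start token prepended, last label dropped
#
# Any -100 entries that remain after shifting are replaced with the pad id, so
# the decoder never sees the loss-masking value.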
952 | class BartForConditionalGenerationTType(BartForConditionalGeneration):
953 |
954 | """
955 | Derived class to add in the token type ids
956 | """
957 |
958 | def __init__(self, config):
959 | super().__init__(config)
960 | self.model = BartModelTType(config)
961 | self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
962 | self.lm_head = torch.nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
963 |
964 | self.init_weights()
965 |
966 | def forward(
967 | self,
968 | input_ids=None,
969 | attention_mask=None,
970 | decoder_input_ids=None,
971 | decoder_attention_mask=None,
972 | head_mask=None,
973 | decoder_head_mask=None,
974 | encoder_outputs=None,
975 | past_key_values=None,
976 | inputs_embeds=None,
977 | decoder_inputs_embeds=None,
978 | labels=None,
979 | use_cache=None,
980 | output_attentions=None,
981 | output_hidden_states=None,
982 | return_dict=None,
983 | token_type_ids=None,
984 | **kwargs # why not absorb all these arguments in kwargs??
985 | ):
986 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
987 |
988 | if labels is not None:
989 | if decoder_input_ids is None and decoder_inputs_embeds is None:
990 | decoder_input_ids = shift_tokens_right(
991 | labels, self.config.pad_token_id, self.config.decoder_start_token_id
992 | )
993 |
994 | outputs = self.model(
995 | input_ids,
996 | attention_mask=attention_mask,
997 | decoder_input_ids=decoder_input_ids,
998 | encoder_outputs=encoder_outputs,
999 | decoder_attention_mask=decoder_attention_mask,
1000 | head_mask=head_mask,
1001 | decoder_head_mask=decoder_head_mask,
1002 | past_key_values=past_key_values,
1003 | inputs_embeds=inputs_embeds,
1004 | decoder_inputs_embeds=decoder_inputs_embeds,
1005 | use_cache=use_cache,
1006 | output_attentions=output_attentions,
1007 | output_hidden_states=output_hidden_states,
1008 | return_dict=return_dict,
1009 | token_type_ids=token_type_ids,
1010 | **kwargs
1011 | )
1012 | lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
1013 |
1014 | masked_lm_loss = None
1015 | if labels is not None:
1016 | loss_fct = torch.nn.CrossEntropyLoss()
1017 | masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
1018 |
1019 | if not return_dict:
1020 | output = (lm_logits,) + outputs[1:]
1021 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1022 |
1023 | return Seq2SeqLMOutput(
1024 | loss=masked_lm_loss,
1025 | logits=lm_logits,
1026 | past_key_values=outputs.past_key_values,
1027 | decoder_hidden_states=outputs.decoder_hidden_states,
1028 | decoder_attentions=outputs.decoder_attentions,
1029 | cross_attentions=outputs.cross_attentions,
1030 | encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1031 | encoder_hidden_states=outputs.encoder_hidden_states,
1032 | encoder_attentions=outputs.encoder_attentions,
1033 | )
1034 |
1035 | class BartS2SEditor(torch.nn.Module, ParallelDecodingMixin):
1036 | """
1037 | S2S model
1038 | """
1039 |
1040 | def __init__(self, align_model, tokenizer, model_file='facebook/bart-base', from_pretrained=True,
1041 | alignment_baseline_score=0.2):
1042 | super().__init__()
1043 |
1044 | self.align_model = align_model
1045 | for p in self.align_model.parameters():
1046 | p.requires_grad = False
1047 | self.alignment_baseline_score=alignment_baseline_score
1048 | self.tokenizer = tokenizer
1049 | vocab_size = len(self.tokenizer)
1050 |
1051 | config = AutoConfig.from_pretrained(model_file)
1052 | config.n_type_ids = 5
1053 |
1054 | if from_pretrained:
1055 | self.bart_model = BartForConditionalGenerationTType.from_pretrained(
1056 | model_file, config=config)
1057 | else:
1058 |             self.bart_model = BartForConditionalGenerationTType(config)
1059 | self.bart_model.resize_token_embeddings(vocab_size)
1060 |
1061 | def forward(self, **kwargs):
1062 | return self.bart_model(**kwargs)
1063 |
1064 | def compute_loss(self, input_ids, token_type_ids, attention_mask, labels):
1065 | model_loss = self.bart_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels,
1066 | token_type_ids=token_type_ids).loss
1067 | metrics = {'loss': model_loss.item()}
1068 | return model_loss, metrics
1069 |
1070 | def prep_canvas(self, canvas):
1071 | tokens, type_ids = canvas.tokens, canvas.type_ids
1072 | token_ids = [self.tokenizer.bos_token_id] + \
1073 | self.tokenizer.convert_tokens_to_ids(tokens) +\
1074 | [self.tokenizer.eos_token_id]
1075 | type_ids = [0] + type_ids + [0]
1076 | assert len(token_ids) == len(type_ids)
1077 | return token_ids, type_ids
1078 |
1079 | def prep_canvases(self, canvases, device=torch.device('cpu')):
1080 | token_ids, type_ids = zip(*[self.prep_canvas(c) for c in canvases])
1081 | max_length = np.max([len(ids) for ids in token_ids])
1082 | input_ids = torch.ones((len(canvases), max_length), dtype=torch.long,
1083 | device=device) * self.tokenizer.pad_token_id
1084 | attention_mask = torch.zeros(input_ids.shape, dtype=torch.int32, device=device)
1085 | token_type_ids = torch.zeros_like(attention_mask)
1086 | for i in range(len(canvases)):
1087 | input_ids[i, :len(token_ids[i])] = torch.tensor(token_ids[i])
1088 | token_type_ids[i, :len(token_ids[i])] = torch.tensor(type_ids[i])
1089 | attention_mask[i, :len(token_ids[i])] = 1
1090 | return input_ids, attention_mask, token_type_ids
1091 |
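# Padding sketch for prep_canvases: canvases are right-padded to the longest
# one in the batch. For two canvases with 3 and 5 tokens (5 and 7 ids once
# bos/eos are added):
#
#   input_ids.shape   == (2, 7)   # row 0 ends in two pad ids
#   attention_mask[0] == [1, 1, 1, 1, 1, 0, 0]
#   token_type_ids    == 0 at bos/eos and padding positions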
1092 | def prep_target(self, target):
1093 | target = [self.tokenizer.bos_token_id] +\
1094 | self.tokenizer.convert_tokens_to_ids(target)+\
1095 | [self.tokenizer.eos_token_id]
1096 | return target
1097 |
1098 | def prep_targets(self, targets, device=torch.device('cpu')):
1099 | token_ids = [self.prep_target(t) for t in targets]
1100 | max_length = np.max([len(ids) for ids in token_ids])
1101 | input_ids = torch.ones((len(targets), max_length), dtype=torch.long,
1102 | device=device) * -100
1103 | for i in range(len(targets)):
1104 | input_ids[i, :len(token_ids[i])] = torch.tensor(token_ids[i])
1105 | return input_ids
1106 |
1107 | def prep_gen(self, gen):
1108 | return self.tokenizer.convert_tokens_to_ids(gen)
1109 |
1110 | def prep_generations(self, generations, device=torch.device('cpu')):
1111 | token_ids = [self.prep_gen(g) for g in generations]
1112 | max_length = np.max([len(ids) for ids in token_ids])
1113 | input_ids = torch.ones((len(generations), max_length), dtype=torch.long,
1114 | device=device) * self.tokenizer.pad_token_id
1115 | input_lengths = []
1116 | for i in range(len(generations)):
1117 | input_ids[i, :len(token_ids[i])] = torch.tensor(token_ids[i])
1118 | input_lengths.append(len(token_ids[i]))
1119 | return input_ids, input_lengths
1120 |
1121 | def prep_batch(self, batch, device=torch.device('cpu'), **kwargs):
1122 | canvases = [a.get_source_canvas() for a in batch]
1123 | input_ids, attention_mask, type_ids = self.prep_canvases(canvases, device=device)
1124 | targets = [a.get_target_tokens() for a in batch]
1125 | labels = self.prep_targets(targets, device=device)
1126 |
1127 | return input_ids, type_ids, attention_mask, labels
1128 |
1129 | def move_batch(self, batch, device):
1130 | input_ids, type_ids, attention_mask, labels = batch
1131 | return input_ids.to(device), type_ids.to(device), attention_mask.to(device), labels.to(device)
1132 |
1133 | # ====== DECODING ============
1134 |
1135 | # @staticmethod
1136 | # def canvas_len(canvas):
1137 | # canvas, gen = canvas
1138 | # return len(canvas)
1139 | #
1140 | # def forward_canvases(self, canvases, device=torch.device('cpu'), move_to_cpu=True):
1141 | # batch_size = len(canvases)
1142 | # canvases, gen = zip(*canvases)
1143 | # input_ids, attention_mask, type_ids = self.prep_canvases(canvases, device=device)
1144 | # gen_ids, gen_lengths = self.prep_generations(gen, device=device)
1145 | # with torch.no_grad():
1146 | # model_out = self.bart_model(input_ids=input_ids, attention_mask=attention_mask,
1147 | # token_type_ids=type_ids, decoder_input_ids=gen_ids)
1148 | # out = []
1149 | # for i in range(batch_size):
1150 |
1151 |
1152 | def batch_generate(self, canvases, device=torch.device('cpu'), **kwargs):
1153 | generations = []
1154 | for c in tqdm.tqdm(canvases):
1155 | input_ids, attention_mask, type_ids = self.prep_canvases([c], device=device)
1156 | # print(input_ids.shape, attention_mask.shape)
1157 | gen = self.bart_model.generate(input_ids=input_ids,
1158 | attention_mask=attention_mask, token_type_ids=type_ids, early_stopping=False,**kwargs)
1159 | # print(gen.shape)
1160 | # for g in gen.cpu().numpy():
1161 | # print(self.tokenizer.decode(g))
1162 | tokens = self.tokenizer.convert_ids_to_tokens(gen[0].cpu().numpy())
1163 | tokens = [t for t in tokens if t not in ('<s>', '</s>', '<pad>')] # get rid of extraneous special tokens (bos/eos/pad)
1164 | # generations.append(self.tokenizer.convert_ids_to_tokens(gen[0][2:-1].cpu().numpy()))
1165 | generations.append(tokens)
1166 | generated_canvases = []
1167 | # print(generations)
1168 | for alignment in batch_align_canvases(canvases, generations, self.align_model, self.tokenizer,
1169 | baseline_score=self.alignment_baseline_score, device=device):
1170 | alignment.push_forward(alignment.get_non_const_ops())
1171 | generated_canvases.append(alignment.get_source_canvas())
1172 | return generated_canvases
1173 |
1174 | def generate(self, canvas, **kwargs):
1175 | return self.batch_generate([canvas], **kwargs)[0]
1176 |
1177 | #def sample_action(stop_out, location_out, action_out, vocab_out, input_ids, sample=False):
1178 | # if sample:
1179 | # stop = int(np.random.random() <= stop_out.item())
1180 | # else:
1181 | # stop = int(stop_out.item() > 0.5)
1182 | # if stop == 1:
1183 | # return stop, None, None, None
1184 | #
1185 | # if sample:
1186 | # refinement_idx = torch.multinomial(F.softmax(location_out, dim=-1), 1).item()
1187 | # else:
1188 | # refinement_idx = torch.argmax(F.softmax(location_out, dim=-1)).item()
1189 | # if refinement_idx > 0:
1190 | # if sample:
1191 | # action_idx = torch.multinomial(F.softmax(action_out[:, refinement_idx, :], dim=-1), 1).item()
1192 | # else:
1193 | # action_idx = torch.argmax(F.softmax(action_out[:, refinement_idx, :], dim=-1)).item()
1194 | # else: # can only insert when selecting the sentinel
1195 | # action_idx = 0
1196 | #
1197 | # if action_idx != 2:
1198 | # if sample:
1199 | # token = torch.multinomial(F.softmax(vocab_out[:, refinement_idx, action_idx, :], dim=-1), 1).squeeze().item()
1200 | # else:
1201 | # top_tokens = torch.argsort(F.softmax(vocab_out[:, refinement_idx, action_idx, :], dim=-1)).squeeze()
1202 | # token = top_tokens[-1].item()
1203 | # if token == input_ids[:, refinement_idx] and action_idx == 1: token = top_tokens[-2].item()
1204 | # else: token = None
1205 | #
1206 | # return stop, refinement_idx, action_idx, token
1207 | #
1208 |
1209 |
1210 | def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
1211 | """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
1212 | Args:
1213 | logits: logits distribution shape (vocabulary size)
1214 | top_k >0: keep only top k tokens with highest probability (top-k filtering).
1215 | top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
1216 | Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
1217 | """
1218 | assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
1219 | top_k = min(top_k, logits.size(-1)) # Safety check
1220 | if top_k > 0:
1221 | # Remove all tokens with a probability less than the last token of the top-k
1222 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
1223 | logits[indices_to_remove] = filter_value
1224 |
1225 | if top_p > 0.0:
1226 | sorted_logits, sorted_indices = torch.sort(logits, descending=True)
1227 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
1228 |
1229 | # Remove tokens with cumulative probability above the threshold
1230 | sorted_indices_to_remove = cumulative_probs > top_p
1231 | # Shift the indices to the right to keep also the first token above the threshold
1232 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
1233 | sorted_indices_to_remove[..., 0] = 0
1234 |
1235 | indices_to_remove = sorted_indices[sorted_indices_to_remove]
1236 | logits[indices_to_remove] = filter_value
1237 | return logits
1238 |
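# Usage sketch for top_k_top_p_filtering (variable names are illustrative; the
# function edits `logits` in place, hence the clone):
#
#   logits = torch.randn(50265)                               # 1-D, vocab-sized
#   filtered = top_k_top_p_filtering(logits.clone(), top_k=50, top_p=0.95)
#   probs = F.softmax(filtered, dim=-1)                       # -inf entries get p=0
#   next_token = torch.multinomial(probs, num_samples=1).item()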
1239 |
1240 |
--------------------------------------------------------------------------------
/infosol/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import wandb
4 | import pickle
5 | import json
6 | import random
7 | import tqdm
8 | import torch
9 | import itertools
10 | import numpy as np
11 | import argparse
12 | import datetime
13 |
14 | from distutils.util import strtobool
15 | from torch.utils.data import DataLoader
16 | from torch.utils.tensorboard import SummaryWriter
17 | from datasets import DatasetDict, Dataset, load_from_disk, concatenate_datasets, load_dataset
18 | from transformers import BertModel, AutoTokenizer, BartModel
19 | from infosol.models.word_edit_model import BertEditor, sample_trajectory, BartS2SEditor
20 | from infosol.env import WordEditOracle, EditingEnvironment
21 | from infosol.alignment import *
22 |
23 | DATA_DIR = #/path/to/data_dir/
24 | SAVE_DIR = #/path/to/default/save_dir/
25 | IDF_PATH = os.path.join(DATA_DIR, 'misc', 'cnn_bart_idfs.pickle')
26 | TF_PATH = os.path.join(DATA_DIR, 'misc', 'cnn_bart_tfs.pickle')
27 | DATA_PATH = os.path.join(DATA_DIR, 'cnn', 'filtered_bart_64')
28 |
29 | def custom_cat_datasets(data1, data2):
30 | """
31 | Utility to concatenate two huggingface datasets
32 | """
33 | l1, l2 = len(data1), len(data2)
34 | data1 = data1.to_dict()
35 | data2 = data2.to_dict()
36 | keys = set(data1.keys()).union(set(data2.keys()))
37 | cat_data= {}
38 | for k in list(keys):
39 | d1 = data1.get(k)
40 | if d1 is None:
41 | d1 = [None] * l1
42 | d2 = data2.get(k)
43 | if d2 is None:
44 | d2 = [None] * l2
45 | cat_data[k] = d1 + d2
46 | return Dataset.from_dict(cat_data)
47 |
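# Usage sketch for custom_cat_datasets: columns missing from one side are
# filled with None, so the two schemas don't need to match exactly.
#
#   d1 = Dataset.from_dict({'text': ['a', 'b']})
#   d2 = Dataset.from_dict({'text': ['c'], 'score': [0.5]})
#   cat = custom_cat_datasets(d1, d2)
#   # cat['text']  == ['a', 'b', 'c']
#   # cat['score'] == [None, None, 0.5]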
48 | class WBLogger():
49 |
50 | def log(self, metrics):
51 | wandb.log(metrics)
52 |
53 | class Train():
54 |
55 | """
56 | Base training job. Doesn't use any generations from model
57 | """
58 |
59 | def __init__(
60 | self,
61 | config=None,
62 | log_dir=None,
63 | save_dir=SAVE_DIR,
64 | rng=None,
65 | project_name=None,
66 | run_name=None,
67 | accumulation_steps=8,
68 | learning_rate=1e-4,
69 | n_epochs=1,
70 | report_every=50,
71 | val_every=1000,
72 | device=torch.device('cpu'),
73 | resume_from_epoch=0,
74 | keep_token_type_ids=True,
75 | track_gradient_norm=False,
76 | clip_grad_norm=False,
77 | max_grad_norm=2000,
78 | **kwargs):
79 |
80 | self.config = config
81 | self.project_name = project_name
82 | self.run_name = run_name
83 | self.rng = rng
84 | self.save_dir=save_dir
85 | self.model_save_path = os.path.join(self.save_dir, 'WEIGHTS.bin')
86 | self.log_dir=log_dir
87 | self.accumulation_steps = accumulation_steps
88 | self.n_epochs = n_epochs
89 | self.report_every = report_every
90 | self.val_every = val_every
91 | self.device = device
92 | self.resume_from_epoch = resume_from_epoch
93 | self.keep_token_type_ids = keep_token_type_ids
94 | self.track_gradient_norm = track_gradient_norm
95 | self.clip_grad_norm = clip_grad_norm
96 | self.max_grad_norm = max_grad_norm
97 |
98 | # self.logger = SummaryWriter(log_dir = self.log_dir)
99 | self.logger = WBLogger()
100 |
101 | self.kwargs = kwargs
102 |
103 | print("Loading environment")
104 | self.load_env(**kwargs)
105 | print("Loading data")
106 | self.load_data(**kwargs)
107 | print("Loading model")
108 | self.load_model(**kwargs)
109 | print("Loading optimizer")
110 | self.load_optimizer(**kwargs)
111 |
112 | def log(self, logdict, prefix, i, iter_name='iter'):
113 | metrics = {
114 | '/'.join((prefix, n)): np.mean(logdict[n]) for n in logdict
115 | }
116 | metrics[iter_name] = i
117 | self.logger.log(metrics)
118 | # self.logger.add_scalar('/'.join((prefix, n)),
119 | # np.mean(logdict[n]),
120 | # i)
121 |
122 | def load_env(self,
123 | env_type='bart',
124 | idf_path=IDF_PATH,
125 | sort_ops='sort',
126 | adjacent_ops=False,
127 | avoid_delete=False,
128 | contiguous_edits=False,
129 | complete_words=True,
130 | baseline_score=0.3,
131 | oracle_stop_p=0.25,
132 | n_oracle_hints=-1,
133 | **kwargs):
134 | print(f'n_oracle_hints={n_oracle_hints}')
135 |
136 | with open(idf_path, 'rb') as f:
137 | idf_dict = pickle.load(f)
138 |
139 | if env_type == 'bert':
140 | self.align_model = BertModel.from_pretrained('bert-base-uncased')
141 | self.align_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
142 | token_no_space = lambda x: x.startswith('#')
143 | elif env_type == 'bart':
144 | self.align_model = BartModel.from_pretrained('facebook/bart-base').encoder
145 | self.align_tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')
146 | token_no_space=lambda x: not x.startswith('Ġ')
147 | else:
148 | raise ValueError('Unknown env type: {}'.format(env_type))
149 |
150 | self.oracle = WordEditOracle(self.align_model, self.align_tokenizer, idf_dict,
151 | sort_ops=sort_ops, adjacent_ops=adjacent_ops,
152 | avoid_delete=avoid_delete, baseline_score=baseline_score,
153 | contiguous_edits=contiguous_edits, complete_words=complete_words,
154 | token_no_space=token_no_space)
155 | self.env = EditingEnvironment(self.oracle, oracle_stop_p, n_oracle_edits=n_oracle_hints)
156 |
157 | def load_data(self,
158 | data_path=DATA_PATH,
159 | batch_size=16,
160 | max_train_edits=0,
161 | max_val_edits=1000,
162 | **kwargs):
163 | self.data = load_from_disk(data_path)
164 | self.batch_size = batch_size
165 |
166 | max_train_edits = len(self.data['train']) if max_train_edits == 0 else max_train_edits
167 | max_val_edits = len(self.data['val']) if max_val_edits == 0 else max_val_edits
168 |
169 | self.data['train'] = self.data['train'].shuffle(generator=self.rng).select(list(range(max_train_edits)))
170 | self.data['val'] = self.data['val'].shuffle(generator=self.rng).select(list(range(max_val_edits)))
171 |
172 | def load_model(self,
173 | tf_path=TF_PATH,
174 | model_name='bart',
175 | noise_frac=0.0,
176 | resume_from_ckpt=None,
177 | **kwargs):
178 | with open(tf_path, 'rb') as f:
179 | tf_dict = pickle.load(f)
180 | tf_map, tf_weights = {}, []
181 | for i,k in enumerate(tf_dict):
182 | tf_map[i] = k
183 | tf_weights.append(tf_dict[k])
184 | vocab_sampler = VocabSampler(tf_weights, tf_map)
185 |
186 | self.model_name = model_name
187 | if model_name == 'bert':
188 | self.model = BertEditor(
189 | tokenizer=self.align_tokenizer,
190 | vocab_sampler=vocab_sampler,
191 | training_noise=noise_frac)
192 | elif model_name == 'bart':
193 | self.model = BertEditor(
194 | tokenizer=self.align_tokenizer,
195 | vocab_sampler=vocab_sampler,
196 | training_noise=noise_frac,
197 | model_type='bart',
198 | model_file='facebook/bart-base')
199 | elif model_name == 'bart-large':
200 | self.model = BertEditor(
201 | tokenizer=self.align_tokenizer,
202 | vocab_sampler=vocab_sampler,
203 | training_noise=noise_frac,
204 | model_type='bart',
205 | model_file='facebook/bart-large')
206 | elif model_name == 'barts2s':
207 | self.model = BartS2SEditor(self.align_model, self.align_tokenizer)
208 | elif model_name == 'barts2s-large':
209 | self.model = BartS2SEditor(self.align_model, self.align_tokenizer, model_file='facebook/bart-large')
210 | else:
211 | raise NotImplementedError(f'Unknown model name: {model_name}')
212 |
213 | if not resume_from_ckpt is None:
214 | print(f'loading model weights from ckpt {resume_from_ckpt}')
215 | self.model.load_state_dict(torch.load(resume_from_ckpt))
216 |
217 | self.model = self.model.to(self.device)
218 | self.model = self.model.train()
219 |
220 | def load_optimizer(self, learning_rate=1e-4, **kwargs):
221 | if self.model_name in ('bert', 'bart', 'bart-large'):
222 | self.optimizer = torch.optim.Adam(params=self.model.parameters(), lr=learning_rate)
223 | elif self.model_name in ('barts2s', 'barts2s-large'):
224 | self.optimizer = torch.optim.Adam(params=self.model.bart_model.parameters(), lr=learning_rate)
225 | else:
226 | raise NotImplementedError(f'Unknown model name: {self.model_name}')
227 |
228 | def pre_train(self):
229 | self.train_loader = DataLoader(self.data['train'], collate_fn = self.prep_batch, batch_size=self.batch_size,
230 | shuffle=True, drop_last=True)
231 | self.val_loader = DataLoader(self.data['val'], collate_fn = self.prep_batch, batch_size=self.batch_size,
232 | shuffle=False, drop_last=False)
233 |
234 | def train(self):
235 |
236 | def val_(min_val_loss):
237 | val_loss, val_metrics = self.validate(cur_iter)
238 | self.log(val_metrics, 'val', cur_iter+1)
239 | if val_loss <= min_val_loss:
240 | torch.save(self.model.state_dict(), self.model_save_path)
241 | min_val_loss = val_loss
242 | return min_val_loss
243 |
244 | self.pre_train()
245 | min_val_loss = 1e10
246 | print("Training")
247 | cur_iter = 0
248 | batch_num = 0
249 | with wandb.init(config=self.config, project=self.project_name, name=self.run_name, dir=self.log_dir) as wandb_run:
250 | wandb.watch(self.model, log='all')
251 | for e in range(self.resume_from_epoch, self.n_epochs):
252 | print("Starting epoch: {}".format(e))
253 | self.optimizer.zero_grad()
254 | metrics = {}
255 | for i,batch in tqdm.tqdm(enumerate(self.train_loader), total=len(self.train_loader)):
256 | batch_num += 1
257 | metrics = self.train_step(batch, metrics)
258 | if (i+1) % self.accumulation_steps == 0:
259 | cur_iter += 1
260 | if self.clip_grad_norm:
261 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm = self.max_grad_norm)
262 | if self.track_gradient_norm:
263 | gradient_norm = 0
264 | for p in self.model.parameters():
265 | if not p.requires_grad: continue
266 | gradient_norm += torch.square(p.grad).sum().item()
267 | gradient_norm = np.sqrt(gradient_norm)
268 | self.log({'grad_norm': gradient_norm}, 'train', cur_iter)
269 | self.optimizer.step()
270 | self.optimizer.zero_grad()
271 | if (cur_iter+1) % self.report_every == 0:
272 | self.log(metrics, 'train', cur_iter+1)
273 | if (cur_iter+1) % self.val_every == 0:
274 | min_val_loss = val_(min_val_loss)
275 | # val_loss, val_metrics = self.validate(cur_iter)
276 | # self.log(val_metrics, 'val', cur_iter+1)
277 | # if val_loss <= min_val_loss:
278 | # min_val_loss = val_loss
279 | # torch.save(self.model.state_dict(), self.model_save_path)
280 | metrics = {}
281 | self.post_epoch(e)
282 |
283 | _ = val_(min_val_loss)
284 |
285 | def post_epoch(self, e):
286 | return
287 |
288 | def train_step(self, batch, metrics):
289 | loss, batch_metrics = self.model.compute_loss(*batch) #TODO: logging
290 | for n in batch_metrics:
291 | if n not in metrics:
292 | metrics[n] = []
293 | metrics[n].append(batch_metrics[n])
294 | loss /= self.accumulation_steps
295 | loss.backward()
296 | return metrics
297 |
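# Note on gradient accumulation: train_step only backpropagates (the loss is
# divided by accumulation_steps first); optimizer.step() in train() fires once
# every accumulation_steps batches, so the effective batch size is
# batch_size * accumulation_steps (16 * 8 = 128 with the defaults).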
298 | def validate(self, cur_iter):
299 | self.model.eval()
300 | running_loss = 0
301 | n_batches = 0
302 | metrics = {}
303 | for batch in self.val_loader: # TODO: <= make dataloader
304 | batch = self.model.move_batch(batch, self.device)
305 | with torch.no_grad():
306 | loss, batch_metrics = self.model.compute_loss(*batch)
307 | for n in batch_metrics:
308 | if n not in metrics:
309 | metrics[n] = []
310 | metrics[n].append(batch_metrics[n])
311 | running_loss += loss.item()
312 | n_batches += 1
313 | # for n,l in zip(val_metrics, (loss, e_loss, l_loss, o_loss, v_loss)):
314 | # val_metrics[n].append(l.item())
315 | # batch = self.model.move_batch(batch, torch.device('cpu'))
316 |
317 | print("============> Val loss at iteration: {}".format(cur_iter))
318 | print(running_loss/n_batches)
319 | # for n in val_metrics:
320 | # print("{}: {}".format(n, np.mean(val_metrics[n])))
321 | # tb_writer.add_scalar('/'.join(('Val', n)), np.mean(val_metrics[n]), cur_iter)
322 | self.model.train()
323 | return running_loss/n_batches, metrics
324 |
325 |
326 | def prep_batch(self, batch):
327 | alignments = []
328 | for b in batch:
329 | if self.keep_token_type_ids:
330 | token_type_ids = b.get('token_type_ids')
331 | else:
332 | token_type_ids = None
333 | alignment = Alignment(b['alignment'], b['alignment_scores'], token_type_ids)
334 | alignment = self.env.oracle_edit(alignment=alignment, device=self.device, return_alignment=True) # TODO: move to sampling
335 | alignments.append(alignment)
336 | batch = self.model.prep_batch(alignments, device=self.device)
337 | return batch
338 |
339 | @classmethod
340 | def add_args(cls, parser):
341 | parser.add_argument('--accumulation_steps', type=int, default=8)
342 | parser.add_argument('--n_epochs', type=int, default=20)
343 | parser.add_argument('--report_every', type=int, default=50)
344 | parser.add_argument('--val_every', type=int, default=1000)
345 | parser.add_argument('--resume_from_ckpt', type=str)
346 | parser.add_argument('--resume_from_epoch', type=int, default=0)
347 | parser.add_argument('--track_gradient_norm', type=lambda x:bool(strtobool(x)), default=False)
348 | parser.add_argument('--clip_grad_norm', type=lambda x:bool(strtobool(x)), default=False)
349 | parser.add_argument('--max_grad_norm', type=float, default=1000)
350 |
351 | data_group = parser.add_argument_group('data')
352 | data_group.add_argument('--batch_size', type=int, default=16)
353 | data_group.add_argument('--gen_batch_size', type=int, default=16)
354 | data_group.add_argument('--max_train_edits', type=int, default=500)
355 | data_group.add_argument('--max_val_edits', type=int, default=1000)
356 | data_group.add_argument('--data_path', type=str, default=
357 | '/data/scratch/faltings/data/infosol/cnn/filtered_bart_64')
358 | data_group.add_argument('--keep_token_type_ids', type=lambda x:bool(strtobool(x)), default=True)
359 |
360 | env_group = parser.add_argument_group('env')
361 | env_group.add_argument('--env_type', type=str, default='bart')
362 | env_group.add_argument('--oracle_stop_p', type=float, default=0.25)
363 | env_group.add_argument('--n_oracle_hints', type=int, default=-1)
364 | env_group.add_argument('--idf_path', type=str, default=
365 | '/data/scratch/faltings/data/infosol/misc/cnn_bart_idfs.pickle')
366 | env_group.add_argument('--n_return_actions', type=int, default=1)
367 | env_group.add_argument('--sort_ops', type=str, default='sort')
368 | env_group.add_argument('--avoid_delete', type=lambda x:bool(strtobool(x)), default=False)
369 | env_group.add_argument('--adjacent_ops', type=lambda x:bool(strtobool(x)), default=False)
370 | env_group.add_argument('--contiguous_edits', type=lambda x:bool(strtobool(x)), default=False)
371 | env_group.add_argument('--complete_words', type=lambda x:bool(strtobool(x)), default=True)
372 |
373 | model_group = parser.add_argument_group('model')
374 | model_group.add_argument('--model_name', type=str, default='bart')
375 | model_group.add_argument('--noise_frac', type=float, default=0.3)
376 | model_group.add_argument('--max_traj_length', type=int, default=64)
377 | model_group.add_argument('--tf_path', type=str, default=
378 | '/data/scratch/faltings/data/infosol/misc/cnn_bart_tfs.pickle')
379 |
380 | opt_group = parser.add_argument_group('optimizer')
381 | opt_group.add_argument('--learning_rate', type=float, default=1e-4)
382 |
383 | # parser.add_argument('--n_forward_iterations', type=int, default=1) TODO: ForwarTrain
384 |
385 | class GenerationInstance():
386 |
387 | """
388 | Used for training jobs that need generation
389 | """
390 |
391 | def __init__(self, datum, env):
392 | self.done = False
393 | self.target = datum['target_tokens']
394 | self.target_text = datum['target_text']
395 | canvas = Canvas(datum['source_tokens'], datum['token_type_ids'])
396 | assert (not datum['alignment'] is None) and (not datum['alignment_scores'] is None)
397 | alignment = Alignment(datum['alignment'], datum['alignment_scores'])
398 | self.history = [(canvas.copy(), alignment.copy())]
399 | self.oracle_canvas = env.reset(alignment=alignment)
400 |
401 | def make_data(self, tokenizer):
402 | instances = []
403 | for canvas, alignment in self.history:
404 | inst = {
405 | 'source_text': canvas.render(tokenizer),
406 | 'source_tokens': canvas.tokens,
407 | 'token_type_ids': canvas.type_ids,
408 | 'target_text': self.target_text,
409 | 'target_tokens': self.target,
410 | 'alignment': alignment.alignment,
411 | 'alignment_scores': alignment.scores
412 | }
413 | instances.append(inst)
414 | return instances
415 |
416 | class DaggerTrain(Train):
417 |
418 | """
419 | Dagger training job. Split into epochs that run through small batches of data generated
420 | from the model or sampled from the dataset.
421 | """
422 |
423 | def __init__(
424 | self,
425 | top_k=10,
426 | top_p=1.0,
427 | stop_threshold=0.9,
428 | max_iter=16,
429 | do_sample=True,
430 | max_length=64,
431 | parallel_decode=False,
432 | n_processes=10,
433 | n_warmup_epochs=0,
434 | sampling_annealing_rate=1.0,
435 | dagger_sampling_rate=1.0,
436 | max_trajectory_length=2,
437 | sample_batch_size=4096,
438 | val_sample_batch_size=1024,
439 | sample_val_every_n_epoch=25,
440 | sample_train_every_n_epoch=10,
441 | **kwargs):
442 |
443 | super().__init__(**kwargs)
444 |
445 | self.top_k = top_k
446 | self.top_p = top_p
447 | self.stop_threshold = stop_threshold
448 | self.max_iter = max_iter
449 | self.do_sample = do_sample
450 | self.max_length = max_length
451 | self.parallel_decode = parallel_decode
452 | self.n_processes = n_processes
453 |
454 | self.n_warmup_epochs = n_warmup_epochs
455 | self.sampling_annealing_rate = sampling_annealing_rate
456 | self.sample_expert_p = dagger_sampling_rate
457 | self.dagger_sampling_rate = dagger_sampling_rate # base rate
458 | self.max_trajectory_length = max_trajectory_length
459 | self.sample_batch_size = sample_batch_size
460 | self.sample_val_every_n_epoch = sample_val_every_n_epoch
461 | self.sample_train_every_n_epoch = sample_train_every_n_epoch
462 | self.val_sample_batch_size = val_sample_batch_size
463 |
464 | self.model_save_path = os.path.join(self.save_dir, 'WEIGHTS.bin')
465 |
466 | def gen_episode(self, instances):
467 | """
468 | Generate one episode of a trajectory
469 | """
470 | canvases = [inst.oracle_canvas for inst in instances]
471 | generations = self.gen_model(canvases)
472 | clean_generations = [g.clean() for g in generations]
473 | align_model, align_tokenizer = self.env.oracle.align_model, self.env.oracle.align_tokenizer
474 | targets = [inst.target for inst in instances]
475 | alignments = list(batch_align_canvases(clean_generations, targets, align_model, align_tokenizer, device=self.device))
476 | for i,inst in enumerate(instances):
477 | inst.history.append((clean_generations[i].copy(), alignments[i].copy()))
478 | inst.oracle_canvas = self.env.oracle_edit(alignment=alignments[i])
479 | return instances
480 |
481 | def sample_batch(self, data):
482 |
483 | """
484 | Sample a new batch of data, i.e. generated trajectories from the current model
485 | """
486 |
487 | instances = [GenerationInstance(d, self.env) for d in data]
488 | # for inst in instances:
489 | # if np.random.random() <= self.sample_expert_p:
490 | # inst.done=True
491 |
492 | self.env.oracle = self.env.oracle.to(self.device)
493 | self.model = self.model.eval()
494 | finished_instances = []
495 | n_sampled_states = 0
496 | for i in range(self.max_trajectory_length):
497 | for inst in instances:
498 | if np.random.random() <= self.sample_expert_p or len(inst.history) >= self.max_trajectory_length:
499 | finished_instances.append(inst)
500 | inst.done = True
501 | n_sampled_states += len(inst.history)
502 | instances = [inst for inst in instances if not inst.done]
503 | # TODO: instead of fixing a maximum trajectory length, fix a max number of instances, then get a distribution over traj lengths
504 | # if n_sampled_states > self.sample_batch_size or len(instances) == 0:
505 | if len(instances) == 0:
506 | break
507 |
508 | instances = self.gen_episode(instances)
509 | finished_instances.extend(instances)
510 | self.env.oracle = self.env.oracle.cpu()
511 | self.model = self.model.train()
512 |
513 | sampled_states = []
514 | for inst in finished_instances:
515 | sampled_states.extend(inst.make_data(self.align_tokenizer))
516 | return sampled_states
517 |
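# Note on the DAgger-style mixing above: at each rollout step an instance is
# frozen with probability sample_expert_p, keeping only the states gathered so
# far (initially just the oracle/dataset state). post_epoch() multiplies
# sample_expert_p by sampling_annealing_rate once warmup is over, gradually
# shifting the training distribution from oracle states to model rollouts.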
518 | def pre_train(self):
519 | # self.sampling_rate_updates = 0
520 | self.post_epoch(-1)
521 |
522 | def post_epoch(self, e):
523 | # Sample trajectories here!
524 | def get_batch(data, size):
525 | idxs = np.random.choice(np.arange(len(data)), size, replace=False)
526 | for i in idxs:
527 | yield data[int(i)]
528 |
529 | if e < self.n_warmup_epochs or (e+1) % self.sample_train_every_n_epoch == 0:
530 | print("Sampling train batch")
531 | if e >= self.n_warmup_epochs:
532 | # self.sampling_rate_updates += 1
533 | self.sample_expert_p *= self.sampling_annealing_rate
534 | # np.exp(self.sampling_rate_updates * np.log(self.sampling_annealing_rate)) * self.dagger_sampling_rate
535 | print(self.sample_expert_p)
536 | train_batch = self.sample_batch(get_batch(self.data['train'], self.sample_batch_size))
537 | self.train_loader = DataLoader(train_batch, collate_fn = self.prep_batch, batch_size=self.batch_size,
538 | shuffle=True, drop_last=True)
539 |
540 | if (e+1) % self.sample_val_every_n_epoch == 0:
541 | print("Sampling val batch")
542 | val_batch = self.sample_batch(get_batch(self.data['val'], self.val_sample_batch_size))
543 | self.val_loader = DataLoader(val_batch, collate_fn = self.prep_batch, batch_size=self.batch_size,
544 | shuffle=False, drop_last=False)
545 |
546 |
547 | # def sample_trajectory(self, datum):
548 | # alignments = []
549 | # alignment = Alignment(alignment=datum['alignment'], scores=datum['alignment_scores'])
550 | # alignment = self.env.reset(alignment=alignment, return_alignment=True)
551 | # i = 0
552 | # while True:
553 | # alignments.append(alignment)
554 | # if np.random.random() > self.sample_expert_p\
555 | # and i < self.max_trajectory_length:
556 | # canvas = alignment.get_source_canvas()
557 | # canvas = self.gen_model(canvas)
558 | # alignment,_ = self.env.step(canvas=canvas, return_alignment=True, device=self.device)
559 | # i += 1
560 | # else:
561 | # break
562 | # return alignments
563 | #
564 | # def prep_batch(self, batch):
565 | # alignments = []
566 | # for b in batch:
567 | # alignments.extend(self.sample_trajectory(b))
568 | # if len(alignments) >= self.batch_size: break
569 | # batch = self.model.prep_batch(alignments, device=self.device)
570 | # return batch
571 |
572 | def gen_model(self, canvases):
573 | if self.model_name in ('bart', 'bert', 'bart-large'):
574 | canvases = list(tqdm.tqdm(self.model.batch_depth_decode(
575 | canvases,
576 | top_k=self.top_k,
577 | max_batch_tokens=2048,
578 | device=self.device,
579 | # parallel=self.parallel_decode,
580 | return_idx=True,
581 | queue_size=2000,
582 | max_iter=self.max_iter,
583 | # n_processes=self.n_processes
584 | ), total=len(canvases)))
585 | canvases = [(i,c) for c,i in canvases]
586 | canvases = [c for i,c in sorted(canvases)]
587 | return canvases
588 | elif self.model_name in ('barts2s', 'barts2s-large'): #TODO
589 | canvases = self.model.batch_generate(canvases, device=self.device,
590 | do_sample=self.do_sample, top_p=self.top_p, max_length=self.max_length)
591 | # print(canvases)
592 | return canvases
593 |
594 | @classmethod
595 | def add_args(cls, parser):
596 | super().add_args(parser)
597 |
598 | model_gen_group = parser.add_argument_group('model generation')
599 | model_gen_group.add_argument('--top_k', type=int, default=10)
600 | model_gen_group.add_argument('--top_p', type=float, default=0.95)
601 | model_gen_group.add_argument('--stop_threshold', type=float, default=0.9)
602 | model_gen_group.add_argument('--parallel_decode', type=lambda x:bool(strtobool(x)), default=False)
603 | model_gen_group.add_argument('--n_processes', type=int, default=10)
604 | model_gen_group.add_argument('--do_sample', type=lambda x:bool(strtobool(x)), default=True)
605 | model_gen_group.add_argument('--max_length', type=int, default=64)
606 | model_gen_group.add_argument('--max_iter', type=int, default=32)
607 |
608 | dagger_group = parser.add_argument_group('dagger')
609 | dagger_group.add_argument('--n_warmup_epochs', type=int, default=0)
610 | dagger_group.add_argument('--sampling_annealing_rate', type=float, default=1.0)
611 | dagger_group.add_argument('--dagger_sampling_rate', type=float, default=0.5)
612 | dagger_group.add_argument('--max_trajectory_length', type=int, default=2)
613 | dagger_group.add_argument('--sample_batch_size', type=int, default=100)
614 | dagger_group.add_argument('--sample_val_every_n_epoch', type=int, default=50)
615 | dagger_group.add_argument('--sample_train_every_n_epoch', type=int, default=15)
616 | dagger_group.add_argument('--val_sample_batch_size', type=int, default=100)
617 |
618 | class ForwardTrain(Train):
619 |
620 | """
621 | Forward training job. Similar to dagger but generates less frequently.
622 | Not used anymore
623 | """
624 |
625 | def __init__(
626 | self,
627 | top_k=10,
628 | top_p=1.0,
629 | stop_threshold=1.0,
630 | max_iter=16,
631 | do_sample=True,
632 | max_length=64,
633 | n_forward_iter=2,
634 | resume_forward_iter=0,
635 | sample_alg='depth',
636 | sample_gen_reverse_steps=64,
637 | force_regen=False,
638 | n_processes=None,
639 | parallel_decode=False,
640 | **kwargs):
641 |
642 | super().__init__(**kwargs)
643 |
644 | self.top_k = top_k
645 | self.sample_alg = sample_alg
646 | self.top_p = top_p
647 | self.stop_threshold = stop_threshold
648 | self.max_iter = max_iter
649 | self.do_sample = do_sample
650 | self.max_length = max_length
651 | self.sample_gen_reverse_steps = sample_gen_reverse_steps
652 | self.n_forward_iter = n_forward_iter
653 | self.resume_forward_iter = resume_forward_iter
654 | self.force_regen = force_regen
655 | self.n_processes = n_processes
656 | self.parallel_decode = parallel_decode
657 |
658 | self.meta_log_dir = self.log_dir
659 | self.meta_run_name = self.run_name
660 |
661 | def gen_model(self, input_generator):
662 | if self.model_name in ('bart', 'bart-large', 'bert'):
663 | if self.sample_alg == 'sample':
664 | for canvas, idx in self.model.batch_decode(
665 | input_generator,
666 | top_k = self.top_k,
667 | max_batch_tokens=2048,
668 | device=self.device,
669 | # parallel=self.parallel_decode,
670 | return_idx=True,
671 | queue_size=2000,
672 | max_iter=self.max_iter,
673 | # n_processes=self.n_processes
674 | ):
675 | yield canvas, idx
676 | elif self.sample_alg == 'depth':
677 | for canvas, idx in self.model.batch_depth_decode(
678 | input_generator,
679 | top_k = self.top_k,
680 | max_batch_tokens=2048,
681 | device=self.device,
682 | # parallel=self.parallel_decode,
683 | return_idx=True,
684 | queue_size=2000,
685 | max_iter=self.max_iter,
686 | stop_threshold=self.stop_threshold,
687 | # n_processes=self.n_processes
688 | ):
689 | yield canvas, idx
690 | elif self.model_name in ('barts2s', 'barts2s-large'):
691 | for i,canvas in enumerate(input_generator):
692 | canvas = self.model.generate(
693 | canvas, device=self.device, do_sample=self.do_sample, top_p=self.top_p, max_length=self.max_length
694 | )
695 | yield canvas, i
696 | return
697 |
698 | def batch_generate(self, data, save_path, chunk_size=10000, **kwargs):# batch_size=1 for now because it is still not scaling up well (because of differences in canvas lengths, etc.)
699 | def generator_(data, data_buffer):
700 | for i,datum in enumerate(data):
701 | if self.keep_token_type_ids:
702 | token_type_ids = datum.get('token_type_ids')
703 | else:
704 | token_type_ids = None
705 | alignment = Alignment(datum['alignment'], datum['alignment_scores'], token_type_ids)
706 | alignment = self.env.oracle_edit(alignment=alignment, return_alignment=True)
707 |
708 | # push forward
709 | non_const_ops = alignment.get_non_const_ops()
710 | n_forward_steps = max(0, len(non_const_ops) - self.sample_gen_reverse_steps)
711 | forward_ops = np.random.choice(non_const_ops, n_forward_steps)
712 | alignment.push_forward(forward_ops)
713 | data_buffer[i] = alignment
714 | yield alignment.get_source_canvas()
715 |
716 | # chunk generator
717 | def chunks(iterator, n):
718 | for first in iterator: # take one item out (exits loop if `iterator` is empty)
719 | rest_of_chunk = itertools.islice(iterator, 0, n - 1)
720 | yield itertools.chain([first], rest_of_chunk) # concatenate the first item back
721 |
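# Behavior sketch for chunks(): it lazily slices one shared iterator, so each
# chunk must be fully consumed before advancing to the next.
#
#   it = iter(range(5))
#   [list(c) for c in chunks(it, 2)]   # -> [[0, 1], [2, 3], [4]]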
722 | self.model = self.model.eval()
723 | with open(save_path, 'wt') as f:
724 | for i,chunk in enumerate(chunks(iter(data), chunk_size)):
725 | print(f'On chunk {i}')
726 | timeouts = 0
727 | data_buffer = {}
728 | popped_idxs = set()
729 | input_generator = generator_(chunk, data_buffer)
730 | output_generator = self.gen_model(input_generator)
731 | for canvas, idx in tqdm.tqdm(output_generator, total=chunk_size):
732 | if canvas is None:
733 | timeouts += 1
734 | print('timeout')
735 | continue
736 | try:
737 | alignment = data_buffer.pop(idx)
738 | except KeyError as e:
739 | print(np.max(list(data_buffer.keys())))
740 | print(idx in popped_idxs)
741 | raise e
742 | popped_idxs.add(idx)
743 | canvas = canvas.clean()
744 | json_str = json.dumps({
745 | 'source_tokens': canvas.tokens,
746 | 'token_type_ids': canvas.type_ids,
747 | # 'idx': idx,
748 | 'source_text': canvas.render(self.align_tokenizer),
749 | 'target_tokens': alignment.get_target_tokens(),
750 | 'target_text': alignment.get_target_canvas().render(self.align_tokenizer),
751 | })
752 | f.write(json_str + '\n')
753 | print(f'{timeouts} timeouts')
754 |
755 | def align_batch(self, batch):
756 | tokens_a = batch['source_tokens']
757 | tokens_b = batch['target_tokens']
758 | alignments = list(
759 | batch_align(
760 | tokens_a, tokens_b, self.align_model, self.align_tokenizer,
761 | baseline_score=0.3, device=self.device
762 | )
763 | )
764 | alignment_scores = [a.scores for a in alignments]
765 | alignments = [a.alignment for a in alignments]
766 | batch['alignment'] = alignments
767 | batch['alignment_scores'] = alignment_scores
768 | return batch
769 |
770 | def load_data(self, **kwargs):
771 | super().load_data(**kwargs)
772 | self.gen_data = self.data
773 |
774 | def pre_train(self):
775 | self.train_loader = DataLoader(self.data['train'], collate_fn = self.prep_batch, batch_size=self.batch_size,
776 | shuffle=True, drop_last=True)
777 | self.val_loader = DataLoader(self.data['val'], collate_fn = self.prep_batch, batch_size=self.batch_size,
778 | shuffle=False, drop_last=False)
779 |
780 | def train(self):
781 | if self.resume_forward_iter > 0:
782 | forward_iter = self.resume_forward_iter - 1
783 | iter_save_dir = os.path.join(self.save_dir, f'forward_iter_{forward_iter}')
784 | for i in range(forward_iter):
785 | gen_save_path = os.path.join(self.save_dir, f'forward_iter_{i}', 'gen_data')
786 | self.gen_data = load_from_disk(gen_save_path)
787 | self.add_data(self.gen_data)
788 |
789 | gen_save_path = os.path.join(iter_save_dir, 'gen_data')
790 | if not os.path.exists(gen_save_path) or self.force_regen:
791 | weights_path = os.path.join(iter_save_dir, 'WEIGHTS.bin')
792 | self.kwargs.update({'resume_from_ckpt': weights_path})
793 | self.load_model(**self.kwargs)
794 | self.kwargs.update({'resume_from_ckpt': None})
795 | self.generate(forward_iter, iter_save_dir)
796 | else:
797 | self.gen_data = load_from_disk(gen_save_path)
798 | self.add_data(self.gen_data)
799 |
800 | self.load_model(**self.kwargs)
801 | self.load_optimizer(**self.kwargs)
802 |
803 | for forward_iter in range(self.resume_forward_iter, self.n_forward_iter):
804 |
805 | iter_save_dir = os.path.join(self.save_dir, f'forward_iter_{forward_iter}')
806 | if not os.path.exists(iter_save_dir):
807 | os.makedirs(iter_save_dir)
808 | self.model_save_path = os.path.join(iter_save_dir, 'WEIGHTS.bin')
809 |
810 | iter_logdir = os.path.join(self.meta_log_dir, f'forward_iter_{forward_iter}')
811 | if not os.path.exists(iter_logdir):
812 | os.makedirs(iter_logdir)
813 | self.log_dir = iter_logdir
814 |
815 | self.run_name = '-'.join((self.meta_run_name, f'forward_iter_{forward_iter}'))
816 | # self.logger = SummaryWriter(log_dir = iter_logdir)
817 |
818 | super().train()
819 |
820 | if forward_iter == (self.n_forward_iter - 1):
821 | break
822 |
823 | # load best checkpoint
824 | self.kwargs.update({'resume_from_ckpt': self.model_save_path})
825 | self.load_model(**self.kwargs)
826 | self.kwargs.update({'resume_from_ckpt': None})
827 |
828 | self.generate(forward_iter, iter_save_dir)
829 |
830 | print('reloading model')
831 |
832 | self.load_model(**self.kwargs)
833 | self.load_optimizer(**self.kwargs)
834 |
835 | def generate(self, forward_iter, iter_save_dir):
836 |
837 | print('generating')
838 | train_save_path = os.path.join(iter_save_dir, 'train_data')
839 | val_save_path = os.path.join(iter_save_dir, 'val_data')
840 | self.batch_generate(self.gen_data['train'], train_save_path)
841 | self.batch_generate(self.gen_data['val'], val_save_path)
842 |
843 | def listdict2dictlist(ld):
844 | keys = ld[0].keys()
845 | return {
846 | k: [d[k] for d in ld] for k in keys
847 | }
848 |
849 | train_generations, val_generations = [], []
850 | with open(train_save_path, 'rt') as f:
851 | for line in f:
852 | train_generations.append(json.loads(line))
853 | with open(val_save_path, 'rt') as f:
854 | for line in f:
855 | val_generations.append(json.loads(line))
856 |
857 | generated_data = DatasetDict({
858 | 'train': Dataset.from_dict(listdict2dictlist(train_generations)),
859 | 'val': Dataset.from_dict(listdict2dictlist(val_generations))
860 | })
861 |
862 | # generated_data = DatasetDict({
863 | # 'train': load_dataset('json', train_save_path),
864 | # 'val': load_dataset('json', val_save_path)
865 | # })
866 |
867 | self.align_model = self.align_model.to(self.device)
868 | print('aligning')
869 | generated_data = generated_data.map(self.align_batch, batched=True, batch_size=64)
870 | self.align_model = self.align_model.cpu()
871 | self.gen_data = generated_data
872 |
873 | os.remove(train_save_path)
874 | os.remove(val_save_path)
875 |
876 | gen_save_path = os.path.join(iter_save_dir, 'gen_data')
877 | generated_data.save_to_disk(gen_save_path)
878 |
879 | self.add_data(generated_data)
880 |
881 | def add_data(self, dataset):
882 |
883 | self.data = DatasetDict({
884 | 'train': custom_cat_datasets(dataset['train'], self.data['train']),
885 | 'val': custom_cat_datasets(dataset['val'], self.data['val'])
886 | })
887 |
888 | @classmethod
889 | def add_args(cls, parser):
890 | super().add_args(parser)
891 |
892 | parser.add_argument('--sample_alg', choices=['depth', 'sample'], default='depth')
893 | parser.add_argument('--top_k', type=int, default=10)
894 | parser.add_argument('--top_p', type=float, default=0.95)
895 | parser.add_argument('--stop_threshold', type=float, default=0.9)
896 | parser.add_argument('--max_iter', type=int, default=32)
897 | parser.add_argument('--do_sample', type=lambda x:bool(strtobool(x)), default=False)
898 | parser.add_argument('--max_length', type=int, default=64)
899 | parser.add_argument('--n_forward_iter', type=int, default=2)
900 | parser.add_argument('--resume_forward_iter', type=int, default=0)
901 | parser.add_argument('--sample_gen_reverse_steps', type=int, default=64)
902 | parser.add_argument('--force_regen', type=lambda x:bool(strtobool(x)), default=False)
903 | parser.add_argument('--parallel_decode', type=lambda x:bool(strtobool(x)), default=False)
904 | parser.add_argument('--n_processes', type=int)
905 |
906 | #class NBTrain(Train):
907 | #
908 | # def __init__(
909 | # self,
910 | # top_k=10):
911 | #
912 | # self.top_k = top_k
913 |
914 |
915 |
916 | if __name__ == '__main__':
917 |
918 | parser = argparse.ArgumentParser()
919 | parser.add_argument('--cuda_device', type=int, default=0)
920 | parser.add_argument('--use_timestamp', type=lambda x:bool(strtobool(x)), default=False)
921 | parser.add_argument('--seed', type=int, default=42)
922 | parser.add_argument('--run_dir', type=str,
923 | default='/Mounts/rbg-storage1/users/faltings/cache/infosol/experiment_results/train/')
924 | parser.add_argument('--run_name', type=str,
925 | default='debug')
926 | parser.add_argument('--project_name', type=str,
927 | default='infosol')
928 | subparsers = parser.add_subparsers()
929 |
930 | base_parser = subparsers.add_parser('base')
931 | base_parser.set_defaults(func=Train)
932 | Train.add_args(base_parser)
933 |
934 | dagger_parser = subparsers.add_parser('dagger')
935 | dagger_parser.set_defaults(func=DaggerTrain)
936 | DaggerTrain.add_args(dagger_parser)
937 |
938 | forward_parser = subparsers.add_parser('forward')
939 | forward_parser.set_defaults(func=ForwardTrain)
940 | ForwardTrain.add_args(forward_parser)
941 |
942 | args = parser.parse_args()
943 |
944 | device=torch.device('cuda:{}'.format(args.cuda_device)) if torch.cuda.is_available() else torch.device('cpu')
945 |
946 | if args.use_timestamp:
947 | run_id = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
948 | run_dir = os.path.join(args.run_dir, args.run_name, run_id)
949 | else: run_dir = os.path.join(args.run_dir, args.run_name)
950 | if not os.path.exists(run_dir):
951 | os.makedirs(run_dir)
952 |
953 | config_path = os.path.join(run_dir, 'config.conf')
954 | with open(config_path, 'w') as f:
955 | for k,v in vars(args).items():
956 | f.write('--' + k + '\n' + str(v) + '\n')
957 |
958 | log_dir = os.path.join(run_dir, 'logs')
959 | # if os.path.exists(log_dir):
960 | # shutil.rmtree(log_dir)
961 | # for logfile in os.listdir(log_dir):
962 | # os.remove(os.path.join(log_dir, logfile))
963 | if not os.path.exists(log_dir):
964 | os.makedirs(log_dir)
965 |
966 | rng = np.random.default_rng(seed = args.seed)
967 | train = args.func(
968 | config = vars(args),
969 | save_dir = run_dir,
970 | log_dir = log_dir,
971 | device = device,
972 | rng = rng,
973 | **vars(args))
974 | train.train()
975 |
--------------------------------------------------------------------------------
/infosol/utils/data/cnn_dailymail.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from nltk.tokenize import sent_tokenize
4 |
5 | def prep_data_inst(tokenizer,
6 | data_inst,
7 | n_article_sentences = 3,
8 | prompt_length=10,
9 | use_highlights=True):
10 | highlight_prefix = 'Summary: '
11 | article_prefix = 'Article: '
12 | highlights = highlight_prefix + ' '.join(data_inst['highlights'].split('\n')) + '\n'
13 | highlight_ids = tokenizer(highlights, return_tensors='pt')['input_ids'].squeeze()
14 | article_text = ' '.join(sent_tokenize(data_inst['article'])[:n_article_sentences]) + '\n'
15 | article_ids = tokenizer(article_text, return_tensors='pt')['input_ids'].squeeze()
16 | article_prefix_length = 0
17 | if use_highlights:
18 | article_text = article_prefix + article_text
19 | article_length_no_prefix = article_ids.size(0)
20 | article_ids = tokenizer(article_text, return_tensors='pt')['input_ids'].squeeze()
21 | article_prefix_length = article_ids.size(0) - article_length_no_prefix
22 | if use_highlights:
23 | input_ids = torch.cat((highlight_ids, article_ids[:prompt_length]))
24 | highlight_length = highlight_ids.size(0)
25 | else:
26 | input_ids = article_ids[:prompt_length]
27 | highlight_length = 0
28 | target_ids = article_ids
29 | return input_ids, target_ids, highlight_length, article_prefix_length
30 |
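# Usage sketch (dataset/config names are the standard HF ones, shown for
# illustration; sent_tokenize also needs the nltk 'punkt' data):
#
#   from transformers import AutoTokenizer
#   from datasets import load_dataset
#   tok = AutoTokenizer.from_pretrained('facebook/bart-base')
#   inst = load_dataset('cnn_dailymail', '3.0.0', split='validation')[0]
#   input_ids, target_ids, hl_len, prefix_len = prep_data_inst(tok, inst)
#   # input_ids: summary ids + the first prompt_length article ids
#   # target_ids: ids of the first n_article_sentences article sentences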
--------------------------------------------------------------------------------
/infosol/utils/keywords.py:
--------------------------------------------------------------------------------
1 | import re
2 | import numpy as np
3 |
4 | def order_keywords(target_text, keywords, highlighter):
5 | highlighted_string = highlighter.highlight(target_text, keywords)
6 | kw_re = re.compile('<kw>[^<]*</kw>') # NOTE: highlight tag name assumed; match it to the highlighter's markup
7 | ordered_keywords = kw_re.findall(highlighted_string)
8 | ordered_keywords = [re.sub('</kw>', '', re.sub('<kw>', '', kw)) for kw in ordered_keywords]
9 | for i in range(len(ordered_keywords)):
10 | kw_matched = False
11 | for kw_pair in keywords:
12 | if kw_pair[0] == ordered_keywords[i]:
13 | ordered_keywords[i] = kw_pair
14 | kw_matched = True
15 | break
16 | if not kw_matched: ordered_keywords[i] = (ordered_keywords[i], 0)
17 | return ordered_keywords
18 |
19 | def choose_top_kws(kws, n_kws):
20 | scores = [k[1] for k in kws]
21 | top_kws = [kws[i] for i in np.sort(np.argsort(scores)[-n_kws:])]
22 | return top_kws
23 |
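# Worked example for choose_top_kws: the top-scoring keywords are kept in
# their original order, not in score order.
#
#   kws = [('cat', 0.9), ('dog', 0.1), ('bird', 0.5)]
#   choose_top_kws(kws, 2)   # -> [('cat', 0.9), ('bird', 0.5)]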
--------------------------------------------------------------------------------
/infosol/utils/pointer_utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 |
4 | sys.path.append('/home/v-fefal/POINTER')
5 | from pytorch_transformers import BertForMaskedLM
6 | from pytorch_transformers.tokenization_bert import BertTokenizer
7 | from inference import convert_example_to_features, greedy_search, PregeneratedDataset
8 | from util import MAX_TURN, PREVENT_FACTOR, PROMOTE_FACTOR, PREVENT_LIST, REDUCE_LIST, STOP_LIST, boolean_string
9 |
10 | class POINTERArgs:
11 |
12 | def __init__(self, bert_model, do_lower_case=False, noi_decay=1, reduce_decay=1, prevent=True,
13 | reduce_stop=True, lessrepeat=True, max_seq_length=256, no_ins_at_first=False, verbose=0):
14 | self.bert_model = bert_model
15 | self.do_lower_case = do_lower_case
16 | self.noi_decay = noi_decay
17 | self.reduce_decay = reduce_decay
18 | self.prevent = prevent
19 | self.reduce_stop = reduce_stop
20 | self.lessrepeat = lessrepeat
21 | self.max_seq_length = max_seq_length
22 | self.no_ins_at_first = no_ins_at_first
23 | self.verbose = verbose
24 |
25 | class POINTERWrapper:
26 |
27 | def __init__(self, args):
28 | self.args = args
29 | self.tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
30 | self.model = BertForMaskedLM.from_pretrained(args.bert_model)
31 | self.model.eval()
32 |
33 | prevent = [ self.tokenizer.vocab.get(x) for x in PREVENT_LIST] if args.prevent else None
34 | if args.reduce_stop:
35 | # import pdb; pdb.set_trace()
36 | reduce_l = REDUCE_LIST | STOP_LIST
37 | reduce = None
38 | if args.prevent:
39 | reduce = [ self.tokenizer.vocab.get(x) for x in reduce_l]
40 | reduce = [s for s in reduce if s]
41 | self.prevent = prevent
42 | self.reduce = reduce
43 |
44 | def prep_input(self, inp_tokens, device):
45 | # canvas = [c.strip().lstrip() for c in canvas]
46 | features = convert_example_to_features(inp_tokens, self.tokenizer, self.args.max_seq_length,
47 | no_ins_at_first = self.args.no_ins_at_first, id=0, tokenizing=True)
48 | out = (features.input_ids, features.input_mask, features.segment_ids, features.lm_label_ids, features.no_ins)
49 | out = tuple(torch.tensor(o.reshape(1,-1)).long().to(device) for o in out)
50 | return out
51 |
52 | def generate_(self, input_ids, segment_ids, input_mask, no_ins):
53 | sep_tok = self.tokenizer.vocab['[SEP]']
54 | cls_tok = self.tokenizer.vocab['[CLS]']
55 | pad_tok = self.tokenizer.vocab['[PAD]']
56 |
57 | predict_ids = greedy_search(self.model, input_ids, segment_ids, input_mask, no_ins = no_ins, args=self.args,
58 | tokenizer=self.tokenizer, prevent=self.prevent, reduce=self.reduce)
59 | output = " ".join([str(self.tokenizer.ids_to_tokens.get(x, "noa").encode('ascii', 'ignore').decode('ascii')) for x in predict_ids[0].detach().cpu().numpy() if x!=sep_tok and x != pad_tok and x != cls_tok])
60 | output = output.replace(" ##", "")
61 | return output
62 |
63 | def generate(self, inp_tokens, device=torch.device('cpu')):
64 | input_ids, input_mask, segment_ids, lm_label_ids, no_ins = self.prep_input(inp_tokens, device)
65 | return self.generate_(input_ids, segment_ids, input_mask, no_ins)
66 |
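67 | # Illustrative usage sketch (assumes a local POINTER checkout on sys.path and a
68 | # downloaded checkpoint; the path below is hypothetical):
69 | #
70 | # args = POINTERArgs(bert_model='/path/to/pointer_checkpoint')
71 | # wrapper = POINTERWrapper(args)
72 | # text = wrapper.generate(['summer', 'beach', 'vacation'])
73 | # # POINTER progressively inserts words around the given keywords to form a sentence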
--------------------------------------------------------------------------------
/jobs/interactive/cnn-bart-s2s-len:
--------------------------------------------------------------------------------
1 | BartS2S --model_path models/dagger/cnn-bart-s2s/WEIGHTS.bin --out_path out/main/interactive/cnn-bart-s2s/1x6_lp3.3 --data_path data/cnn_bart --max_data 2000 --idf_path data/misc/cnn_bart_idfs.pickle --n_episodes 1 --n_oracle_edits 6 --adjacent_ops True --complete_words True --contiguous_edits True --bleu_ngrams 1 --length_penalty 3.3
2 | BartS2S --model_path models/dagger/cnn-bart-s2s/WEIGHTS.bin --out_path out/main/interactive/cnn-bart-s2s/2x3_lp1.5 --data_path data/cnn_bart --max_data 2000 --idf_path data/misc/cnn_bart_idfs.pickle --n_episodes 2 --n_oracle_edits 3 --adjacent_ops True --complete_words True --contiguous_edits True --bleu_ngrams 1 --length_penalty 1.5
3 | BartS2S --model_path models/dagger/cnn-bart-s2s/WEIGHTS.bin --out_path out/main/interactive/cnn-bart-s2s/3x2_lp0.8 --data_path data/cnn_bart --max_data 2000 --idf_path data/misc/cnn_bart_idfs.pickle --n_episodes 3 --n_oracle_edits 2 --adjacent_ops True --complete_words True --contiguous_edits True --bleu_ngrams 1 --length_penalty 0.8
4 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.13.1
2 | datasets==2.4.0
3 | matplotlib==3.5.1
4 | nltk==3.7
5 | setuptools==65.5.1
6 | tqdm==4.64.0
7 | transformers==4.18.0
8 | numpy==1.22.4
9 | wandb==0.13.2
10 | scikit-learn==1.1.1
11 |
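12 | # install with: pip install -r requirements.txt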
--------------------------------------------------------------------------------
/scripts/dowload_models.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | wget "https://onedrive.live.com/download?cid=7AA4E17800AB44C8&resid=7AA4E17800AB44C8%21510310&authkey=ABxsQgnGdJk7wWI&download=1" -O infosol_models.tar.zst
4 | zstd -dc infosol_models.tar.zst | tar -xvf - ./models
5 |
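6 | # NOTE: assumes wget and zstd are installed; the archive unpacks its ./models directory into the repo root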
--------------------------------------------------------------------------------
/scripts/make_data.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import json
4 | import torch
5 | import itertools
6 | import tqdm
7 | import pickle
8 | import random
9 | import argparse
10 | import numpy as np
11 | import matplotlib.pyplot as plt
12 |
13 | from sklearn.feature_extraction.text import CountVectorizer
14 | from datasets import Dataset, DatasetDict, load_from_disk, load_dataset, concatenate_datasets
15 | from transformers import AutoTokenizer, BertModel, BartModel, AutoModel
16 | from infosol.alignment import batch_align, get_non_const_ops, sample_actions
17 | from nltk.tokenize import sent_tokenize
18 |
19 | # NOTE: re-loads bart-large solely to count its trainable parameters; `params` is not used elsewhere in this script
20 | model = AutoModel.from_pretrained('facebook/bart-large')
21 | model_parameters = filter(lambda p: p.requires_grad, model.parameters())
22 | params = sum([np.prod(p.size()) for p in model_parameters])
23 |
29 | def make_len_filter(max_len=32, min_len=0):
30 | def filter_(data_inst):
31 | return len(data_inst['target_tokens']) <= max_len and len(data_inst['target_tokens']) >= min_len
32 | return filter_
33 |
34 | def align_batch(data_batch, model, tokenizer, baseline_score=0.3, device=torch.device('cpu')):
35 | tokens_a = data_batch['source_tokens']
36 | tokens_b = data_batch['target_tokens']
37 | alignments = list(batch_align(tokens_a, tokens_b,
38 | model, tokenizer, baseline_score=baseline_score, device=device))
39 | alignment_scores = [a.scores for a in alignments]
40 | alignments = [a.alignment for a in alignments]
41 | data_batch['alignment'] = alignments
42 | data_batch['alignment_scores'] = alignment_scores
43 | return data_batch
44 |
45 | def is_edit(data_inst):
46 | return len(get_non_const_ops(data_inst['alignment'])) > 0
47 |
48 | def compute_tf_idf_scores(data, tokenizer):
49 | sentences = [datum['target_text'] for datum in data]
50 | cvectorizer = CountVectorizer(analyzer=tokenizer.tokenize)
51 | counts = cvectorizer.fit_transform(sentences)
52 |
53 | document_counts = counts > 0
54 |
55 | doc_counts_sum = document_counts.sum(axis=0)
56 | counts_sum = counts.sum(axis=0)
57 |
58 | idf_scores = {k: -np.log(doc_counts_sum[0,cvectorizer.vocabulary_[k]]) for k in tqdm.tqdm(cvectorizer.vocabulary_)}
59 | idf_scores = {k: np.log(document_counts.shape[0]) + idf_scores[k] for k in idf_scores}
60 |
61 | total = counts_sum.sum()
62 | tf_dict = {k: counts_sum[0, cvectorizer.vocabulary_[k]]/total for k in cvectorizer.vocabulary_}
63 |
64 | return idf_scores, tf_dict
65 |
66 | def add_type_ids(datum):
67 | datum['token_type_ids'] = [0] * len(datum['source_tokens'])
68 | return datum
69 |
70 | def remove_type_ids(datum):
71 | datum.pop('token_type_ids')
72 | return datum
73 |
74 | def tokenize_text(data_inst, tokenizer):
75 | data_inst['target_tokens'] = tokenizer.tokenize(data_inst['target_text'])
76 | data_inst['source_tokens'] = tokenizer.tokenize(data_inst['source_text'])
77 | return data_inst
78 |
79 | def random_subset(data, size, seed=42):
80 | return data.shuffle(
81 | generator=np.random.default_rng(seed)
82 | ).select(
83 | range(size)
84 | )
85 |
86 | def random_subset_dict(dataset_dict, split_sizes, seed=42):
87 | return DatasetDict({
88 | n: random_subset(dataset_dict[n], s, seed=seed) for n,s in zip(dataset_dict, split_sizes)
89 | })
90 |
91 | def process_cnn_instances(data_inst, tokenizer):
92 | return_instances = []
93 | for sent in data_inst['highlights'][0].split('\n'):
94 | return_inst = {}
95 | target_text = sent
96 | target_tokens = tokenizer.tokenize(target_text)
97 | alignment = [('', t) for t in target_tokens]
98 | alignment_scores = [0.3] * len(alignment)
99 | source_tokens = []
100 | source_text = ''
101 | token_type_ids = []
102 | return_inst = {
103 | 'id': data_inst['id'][0],
104 | 'target_text': target_text,
105 | 'target_tokens': target_tokens,
106 | 'source_text': source_text,
107 | 'source_tokens': source_tokens,
108 | 'alignment': alignment,
109 | 'alignment_scores': alignment_scores,
110 | 'token_type_ids': token_type_ids
111 | }
112 | return_instances.append(return_inst)
113 | return_instances = {k: [inst[k] for inst in return_instances] for k in return_instances[0]}
114 | return return_instances
115 |
116 | def make_cnn_data(tokenizer):
117 | data = load_dataset('cnn_dailymail', '3.0.0')
118 | data = data.map(lambda x: process_cnn_instances(x, tokenizer), batched=True, batch_size=1, remove_columns=['article', 'highlights', 'id'])
119 | data = data.filter(make_len_filter(max_len=64, min_len=10))
120 | return data
121 |
122 | if __name__ == '__main__':
123 |
124 | parser = argparse.ArgumentParser()
125 | parser.add_argument('--data_dir', type=str, default='data') # where to save data
126 | args = parser.parse_args()
127 |
128 | bart_tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')
129 | bart_model = BartModel.from_pretrained('facebook/bart-base').encoder
130 | device = torch.device('cuda')
131 | bart_model = bart_model.to(device)
132 |
134 |
135 | bart_cnn_data = make_cnn_data(bart_tokenizer)
136 |
137 | save_path = os.path.join(args.data_dir, 'cnn_bart')
138 | bart_cnn_data.save_to_disk(save_path)
139 |
140 | cnn_bart_idfs, cnn_bart_tfs = compute_tf_idf_scores(bart_cnn_data['train'].select(range(500000)), bart_tokenizer)
141 |
142 | misc_dir = os.path.join(args.data_dir, 'misc')
143 | if not os.path.exists(misc_dir):
144 | os.makedirs(misc_dir)
145 |
146 | bart_idf_save_path = os.path.join(misc_dir, 'cnn_bart_idfs.pickle')
147 | with open(bart_idf_save_path, 'wb') as f:
148 | pickle.dump(cnn_bart_idfs, f)
149 |
150 | bart_tf_save_path = os.path.join(misc_dir, 'cnn_bart_tfs.pickle')
151 | with open(bart_tf_save_path, 'wb') as f:
152 | pickle.dump(cnn_bart_tfs, f)
153 |
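154 | # Illustrative notes (commentary with hypothetical values, not used by the script):
155 | #
156 | # compute_tf_idf_scores implements idf(k) = log(N) - log(df_k) = log(N / df_k).
157 | # With N = 4 documents and a token appearing in 2 of them,
158 | # idf = log(4) - log(2) = log(2) ≈ 0.693; tf is the token's share of all counts.
159 | #
160 | # process_cnn_instances maps each highlight sentence to an insertion-only instance
161 | # with an empty source, e.g. (tokenization shown schematically):
162 | # target_tokens = ['The', 'cat'] -> alignment = [('', 'The'), ('', 'cat')],
163 | # alignment_scores = [0.3, 0.3], source_tokens = [].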
--------------------------------------------------------------------------------
/scripts/make_eval_jobs.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pickle
3 | import itertools
4 | import json
5 | import os
6 | import re
7 |
9 |
10 | """
11 | Expects the following directory structure:
12 | model_dir /
13 |     exp_name /
14 |         model_name /
15 |             WEIGHTS.bin
16 |         model_name
17 |     exp_name /
18 | data_dir /
19 |     cnn /
20 |         filtered_bart_64.pickle # cnn dataset
21 |     misc /
22 |         cnn_bart_idfs.pickle # cnn idf scores
23 |
24 | Will write the argument files under job_dir, and the scripts will write results to out_dir/job_name. Pass the argument files to the run_eval.py script
25 | to run the experiments.
26 | """
27 |
28 | command_template = '{} \
29 | --model_path {} \
30 | --out_path {} \
31 | --data_path {} \
32 | --max_data {} \
33 | --idf_path {} \
34 | --n_episodes {} \
35 | --n_oracle_edits {} \
36 | --adjacent_ops {} \
37 | --complete_words {} \
38 | --contiguous_edits {} \
39 | --bleu_ngrams {}'
40 |
41 | def job_args_from_job_name(job_name):
42 | if 'bert' in job_name:
43 | if 'cnn' in job_name:
44 | data_path_ = os.path.join(data_dir, 'cnn', 'filtered_bert_64')
45 | idf_path_ = os.path.join(data_dir, 'misc', 'cnn_bert_idfs.pickle')
46 | elif 'yelp' in job_name:
47 | data_path_ = os.path.join(data_dir, 'yelp_pe', 'bert_gen_100')
48 | idf_path_ = os.path.join(data_dir, 'misc', 'yelp_bert_idfs.pickle')
49 | elif 'bart' in job_name:
50 | if 'cnn' in job_name:
51 | data_path_ = os.path.join(data_dir, 'cnn', 'filtered_bart_64')
52 | idf_path_ = os.path.join(data_dir, 'misc', 'cnn_bart_idfs.pickle')
53 | elif 'yelp' in job_name:
54 | data_path_ = os.path.join(data_dir, 'yelp_pe', 'bart_gen_100')
55 | idf_path_ = os.path.join(data_dir, 'misc', 'yelp_bart_idfs.pickle')
56 |
57 | if 'bert' in job_name:
58 | func_ = 'BertEditor'
59 | elif 'bart-s2s' in job_name:
60 | func_ = 'BartS2S'
61 | elif 'bart_editor_large' in job_name:
62 | func_ = 'BartLargeEditor'
63 | elif 'bart' in job_name:
64 | func_ = 'BartEditor'
65 |
66 | return data_path_, idf_path_, func_
67 |
68 | def make_interactive_eval_job(name, func, model_path, idf_path):
69 | out_dir_ = os.path.join(int_out_dir, name)
70 | if not os.path.exists(out_dir_):
71 | os.makedirs(out_dir_)
72 | n_episodes_ = [1,2,3]
73 | n_oracle_edits_ = [6,3,2]
74 | arg_tuples = []
75 |
76 | for n_ep, n_oe in zip(n_episodes_, n_oracle_edits_):
77 | out_path_ = os.path.join(out_dir_, f"{n_ep}x{n_oe}")
78 | arg_tuple = (
79 | func, model_path, out_path_,
80 | data_path, max_data, idf_path,
81 | n_ep, n_oe,
82 | adjacent_ops, complete_words, contiguous_edits,
83 | bleu_ngrams)
84 | arg_tuples.append(arg_tuple)
85 |
86 | job_path = os.path.join(int_job_dir, name)
87 | with open(job_path, 'wt') as f:
88 | for a in arg_tuples:
89 | f.write(command_template.format(*a) + '\n')
90 |
91 | if __name__ == '__main__':
92 |
93 | parser = argparse.ArgumentParser()
94 | parser.add_argument('--job_dir', type=str,
95 | default='jobs') # where to write job args
96 | parser.add_argument('--model_dir', type=str,
97 | default='models') # where to find model checkpoints
98 | parser.add_argument('--out_dir', type=str,
99 | default='out') # where evaluation jobs should write results
100 | parser.add_argument('--data_dir', type=str,
101 | default='data') # data dir
102 | args = parser.parse_args()
103 |
104 | # MAIN ============================
105 |
106 | ## DEFAULT SETTINGS
107 |
108 | job_dir = os.path.join(args.job_dir, 'main')
109 | if not os.path.exists(job_dir):
110 | os.makedirs(job_dir)
111 |
112 | model_dir = args.model_dir
113 | exp_name = 'dagger'
114 | model_name = 'cnn-bart-editor'
115 |
116 | func = 'BartEditor'
117 | model_path = os.path.join(model_dir, exp_name, model_name, 'WEIGHTS.bin')
118 | out_dir = os.path.join(args.out_dir, 'main')
119 | out_path = os.path.join(out_dir, 'main.pickle')
120 | data_dir = args.data_dir
121 | data_path = os.path.join(data_dir, 'cnn_bart')
122 | max_data = 2000
123 | idf_path = os.path.join(data_dir, 'misc', 'cnn_bart_idfs.pickle')
124 | n_episodes = 4
125 | n_oracle_edits = 3
126 | adjacent_ops = True
127 | complete_words = True
128 | contiguous_edits = True
129 | bleu_ngrams = 1
130 |
131 | ## Different Oracles
132 |
133 | out_dir_ = os.path.join(out_dir, 'oracles')
134 | if not os.path.exists(out_dir_):
135 | os.makedirs(out_dir_)
136 | oracle_names = ['unrestricted', 'contiguous', 'adjacent']
137 | oracle_settings = [(False, False), (False, True), (True, False)]
138 | arg_tuples = []
139 | for (adjacent_ops_, contiguous_edits_), job_name in zip(oracle_settings, oracle_names):
140 | out_path_ = os.path.join(out_dir_, job_name)
141 | arg_tuple = (
142 | func, model_path, out_path_,
143 | data_path, max_data, idf_path,
144 | n_episodes, n_oracle_edits,
145 | adjacent_ops_, complete_words, contiguous_edits_,
146 | bleu_ngrams)
147 | arg_tuples.append(arg_tuple)
148 |
149 | job_path = os.path.join(job_dir, 'oracles')
150 | with open(job_path, 'wt') as f:
151 | for a in arg_tuples:
152 | f.write(command_template.format(*a) + '\n')
153 |
154 |
155 | ## DAGGER
156 |
157 | model_names = os.listdir(os.path.join(model_dir, exp_name))
158 | block_list = []
159 | model_names = [n for n in model_names if n not in block_list]
160 |
161 | out_dir_ = os.path.join(out_dir, exp_name)
162 | if not os.path.exists(out_dir_):
163 | os.makedirs(out_dir_)
164 | arg_tuples = []
165 | for job_name in model_names:
166 | model_path_ = os.path.join(model_dir, exp_name, job_name, 'WEIGHTS.bin')
167 | out_path_ = os.path.join(out_dir_, job_name)
168 |
169 | data_path_, idf_path_, func_ = job_args_from_job_name(job_name) # placeholder: derives paths from the job name and assumes the <dataset>-<model> naming convention
170 |
171 | arg_tuple = (
172 | func_, model_path_, out_path_,
173 | data_path_, max_data, idf_path_,
174 | n_episodes, n_oracle_edits,
175 | adjacent_ops, complete_words, contiguous_edits,
176 | bleu_ngrams)
177 | arg_tuples.append(arg_tuple)
178 | #baseline
179 | arg_tuples.append(('Baseline', model_path, os.path.join(out_dir_, 'baseline'),
180 | data_path, max_data, idf_path, n_episodes, n_oracle_edits,
181 | adjacent_ops, complete_words, contiguous_edits, bleu_ngrams))
182 |
183 | job_strings = [command_template.format(*a) for a in arg_tuples]
184 |
185 | job_path = os.path.join(job_dir, exp_name)
186 | with open(job_path, 'wt') as f:
187 | for line in job_strings:
188 | f.write(line + '\n')
189 |
190 | # Interactive vs. One Shot ===========================
191 |
192 | int_job_dir = os.path.join(args.job_dir, 'interactive')
193 | if not os.path.exists(int_job_dir):
194 | os.makedirs(int_job_dir)
195 | int_out_dir = os.path.join(out_dir, 'interactive')
196 | int_model_dir = os.path.join(model_dir, 'dagger') # evaluation done with dagger models
197 |
198 | names = [
199 | 'cnn-bart-editor', 'cnn-bart_editor_large',
200 | 'cnn-bart-s2s'
201 | ]
202 | funcs = [
203 | 'BartEditor', 'BartLargeEditor',
204 | 'BartS2S'
205 | ]
206 | idf_names = [
207 | 'cnn_bart_idfs', 'cnn_bart_idfs',
208 | 'cnn_bart_idfs'
209 | ]
210 |
211 | int_job_args = []
212 | for n,f,idf_n in zip(names, funcs, idf_names):
213 | mp = os.path.join(int_model_dir, n, 'WEIGHTS.bin')
214 | ip = os.path.join(data_dir, 'misc', idf_n + '.pickle')
215 | int_job_args.append((n,f,mp,ip))
216 |
217 | for a in int_job_args:
218 | make_interactive_eval_job(*a)
219 |
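220 | # Illustrative end-to-end usage (paths follow the defaults above):
221 | #
222 | # python scripts/make_eval_jobs.py --job_dir jobs --model_dir models --out_dir out --data_dir data
223 | # python scripts/run_eval.py --args_path jobs/main/oracles --cuda_device 0
224 | #
225 | # Each written job file holds one command_template line per evaluation; see
226 | # jobs/interactive/cnn-bart-s2s-len for generated examples.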
--------------------------------------------------------------------------------
/scripts/run_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from infosol.evaluate import EvaluateBartEditor, EvaluateBertEditor, EvaluateBartS2S, EvaluateBartLarge, EvaluateNoModel
4 |
5 | if __name__ == '__main__':
6 |
7 | meta_parser = argparse.ArgumentParser()
8 | meta_parser.add_argument('--args_path', type=str)
9 | meta_parser.add_argument('--cuda_device', type=int)
10 | meta_args = meta_parser.parse_args()
11 |
12 | parser = argparse.ArgumentParser()
13 | subparsers = parser.add_subparsers()
14 |
15 | parser_barteditor = subparsers.add_parser('BartEditor')
16 | parser_barteditor.set_defaults(func=EvaluateBartEditor)
17 | EvaluateBartEditor.add_args(parser_barteditor)
18 |
19 | parser_berteditor = subparsers.add_parser('BertEditor')
20 | parser_berteditor.set_defaults(func=EvaluateBertEditor)
21 | EvaluateBertEditor.add_args(parser_berteditor)
22 |
23 | parser_barts2s = subparsers.add_parser('BartS2S')
24 | parser_barts2s.set_defaults(func=EvaluateBartS2S)
25 | EvaluateBartS2S.add_args(parser_barts2s)
26 |
27 | parser_bartlarge = subparsers.add_parser('BartLargeEditor')
28 | parser_bartlarge.set_defaults(func=EvaluateBartLarge)
29 | EvaluateBartLarge.add_args(parser_bartlarge)
30 |
31 | parser_baseline = subparsers.add_parser('Baseline')
32 | parser_baseline.set_defaults(func=EvaluateNoModel)
33 | EvaluateNoModel.add_args(parser_baseline)
34 |
35 | with open(meta_args.args_path, 'rt') as f:
36 | for i,line in enumerate(f):
37 | print(f'#### On job {i} ####')
38 | args_list = line.strip().split() # split on any whitespace, tolerating runs of spaces in job lines
39 | args_list.extend(['--cuda_device', str(meta_args.cuda_device)])
40 | args = parser.parse_args(args_list)
41 | eval_instance = args.func()
42 | eval_instance.setup(args)
43 | eval_instance.run()
44 |
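45 | # Args-file format (see jobs/interactive/cnn-bart-s2s-len): each line starts with a
46 | # subparser name (BartEditor, BertEditor, BartS2S, BartLargeEditor or Baseline)
47 | # followed by that evaluator's flags. Illustrative invocation:
48 | #
49 | # python scripts/run_eval.py --args_path jobs/interactive/cnn-bart-s2s-len --cuda_device 0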
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(name='infosol', version='1.0', packages=find_packages())
4 |
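5 | # development install (illustrative): pip install -e .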
--------------------------------------------------------------------------------