├── README.md
├── assets
│   └── method.png
├── collision_ext_sum.py
├── collision_paraphrase.py
├── collision_polyencoder.py
├── collision_retrieval.py
├── constant.py
├── models
│   ├── __init__.py
│   ├── bert_layers.py
│   ├── bert_models.py
│   ├── polyencoder
│   │   ├── __init__.py
│   │   ├── config_polyencoder.py
│   │   ├── layers.py
│   │   ├── modeling_polyencoder.py
│   │   └── tokenization_polyencoder.py
│   ├── presumm
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── encoder.py
│   │   ├── model_builder.py
│   │   └── neural.py
│   └── scorer.py
├── requirements.txt
├── scripts
│   ├── ft_bert_lm.py
│   └── ft_polyencoder_lm.py
└── utils
    ├── __init__.py
    ├── constraints_utils.py
    ├── logging_utils.py
    ├── optimization_utils.py
    └── tokenizer_utils.py

/README.md:
--------------------------------------------------------------------------------
# Adversarial Semantic Collisions
This repo contains the implementation for the EMNLP 2020 paper:
[Adversarial Semantic Collisions](https://arxiv.org/pdf/2011.04743.pdf).
![method](assets/method.png)

## Dependencies
The code is tested on Python 3 with torch==1.4.0 and transformers==2.8.0.
Other requirements can be found in `requirements.txt`.

## Datasets and Models
We considered four tasks in this paper. The data and models can be downloaded from [here](https://zenodo.org/record/4263446#.X6iYUnVKjCJ) (the decompressed files can take up to 18GB of disk space).
Please extract the data and models into the `COLLISION_DIR` directory defined in `constant.py`.

### Target Models
* For the paraphrase identification task, the models are trained with the HuggingFace example [scripts](https://github.com/huggingface/transformers/blob/master/examples/text-classification/run_glue.py).
* For the response suggestion task, the models are collected from [ParlAI](https://parl.ai/projects/polyencoder/).
* For the document retrieval task, the models are collected from [Birch](https://github.com/castorini/birch).
* For the extractive summarization task, the models are collected from [PreSumm](https://github.com/nlpyang/PreSumm).


### Language Models for Natural Collisions
For generating natural collisions (see Section 4.2.2 in our paper), we need to train language models (LMs) with the same
vocabulary as the target models we are attacking.
We provide pre-trained LMs in the download link above and their training scripts in the `scripts/` folder.
The LMs are fine-tuned from BERT or [Poly-encoder](https://arxiv.org/pdf/1905.01969.pdf) on [WikiText-103](https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/).


## Generating Semantic Collisions
We can now run collision attacks on the test sets of the four tasks.
Example scripts are provided below, where (A), (R), and (N) denote aggressive,
regularized, and natural collisions, respectively.
The provided attack arguments are the same for all test examples;
a dedicated hyper-parameter search for each individual example could yield better attack performance.

Add the `--verbose` flag if you want to inspect the intermediate steps of collision generation.
Remove the `--fp16` flag if you have not installed [apex](https://github.com/NVIDIA/apex) for mixed-precision training.
Set `--gpu=i` to use the i-th GPU.
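Before launching any of the attacks below, it can be useful to confirm that the downloaded archive was extracted into the layout that `constant.py` expects. The following sanity check is not part of the released scripts; it is a minimal sketch that simply reuses the path constants defined in `constant.py`:
```
import os

# Path constants from constant.py; COLLISION_DIR must point to where the
# downloaded archive was extracted (see the Datasets and Models section).
from constant import (PARA_DIR, BIRCH_DIR, CONV2AI_PATH, PRESUMM_MODEL_PATH,
                      BERT_LM_MODEL_DIR, POLY_LM_MODEL_DIR)

for path in [PARA_DIR, BIRCH_DIR, CONV2AI_PATH, PRESUMM_MODEL_PATH,
             BERT_LM_MODEL_DIR, POLY_LM_MODEL_DIR]:
    status = 'ok' if os.path.exists(path) else 'MISSING'
    print(f'{status:8s} {path}')
```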

### **Paraphrase Identification**
```
(A) python3 collision_paraphrase.py --topk=30 --perturb_iter=30 --max_iter=10 --stemp=1.0 --lr=1e-3 --seq_len=20 --fp16

(R) python3 collision_paraphrase.py --topk=15 --perturb_iter=30 --max_iter=10 --stemp=1.0 --lr=1e-3 --seq_len=30 --regularize --beta=0.8 --fp16

(N) python3 collision_paraphrase.py --topk=128 --perturb_iter=5 --stemp=0.1 --lr=1e-3 --seq_len=25 --nature --beta=0.05 --fp16
```

### **Response Suggestions**
```
(A) python3 collision_polyencoder.py --topk=30 --perturb_iter=30 --stemp=1.0 --lr=1e-3 --seq_len=20 --poly --num_filters=1000 --fp16

(R) python3 collision_polyencoder.py --topk=20 --perturb_iter=30 --stemp=1.0 --lr=1e-3 --seq_len=25 --regularize --beta=0.8 --poly --num_filters=1000 --fp16

(N) python3 collision_polyencoder.py --topk=128 --perturb_iter=5 --stemp=0.1 --lr=1e-3 --seq_len=20 --nature --beta=0.15 --poly --num_filters=1000 --fp16
```
Remove the `--poly` flag if you want to attack the Bi-encoder model instead.

### **Document Retrieval**
```
(A) python3 collision_retrieval.py --num_beams=5 --topk=50 --perturb_iter=30 --stemp=1.0 --lr=1e-3 --seq_len=30 --num_filters=1000 --fp16

(R) python3 collision_retrieval.py --num_beams=5 --topk=40 --perturb_iter=30 --stemp=1.0 --lr=1e-3 --seq_len=60 --num_filters=1000 --regularize --beta=0.85 --fp16

(N) python3 collision_retrieval.py --num_beams=10 --topk=150 --perturb_iter=5 --stemp=0.1 --lr=1e-3 --seq_len=35 --num_filters=1000 --nature --beta=0.02 --fp16

```
The default dataset is Core17; set the `--data_name=core18` flag if you want to generate collisions for Core18.
Add the `--verbose` flag to see how the document ranks change after inserting collisions.
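For reference, `collision_polyencoder.py` computes the rank it reports by inserting the collision's score among the scores of the genuine candidates and argsorting in descending order. A condensed, illustrative sketch of that logic (the scores below are made up):
```
import numpy as np

candidate_scores = [1.3, 0.7, 2.1]  # scores of the genuine candidates (illustrative)
collision_score = 2.4               # score of the generated collision

scores = np.asarray([collision_score] + candidate_scores)
ranks = np.empty(len(scores))
ranks[np.argsort(-scores)] = np.arange(len(scores))
print(int(ranks[0]))  # 0 means the collision outranks every genuine candidate
```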

### **Extractive Summarization**
```
(A) python3 collision_ext_sum.py --num_beams=5 --topk=10 --perturb_iter=30 --stemp=1.0 --lr=1e-3 --seq_len=15 --fp16

(R) python3 collision_ext_sum.py --num_beams=5 --topk=10 --perturb_iter=30 --stemp=1.0 --lr=1e-3 --seq_len=30 --beta=0.8 --regularize --fp16

(N) python3 collision_ext_sum.py --num_beams=5 --topk=64 --perturb_iter=5 --stemp=1.0 --lr=1e-3 --seq_len=20 --beta=0.02 --nature --fp16
```

## Reference
```
@inproceedings{song2020adversarial,
  title={Adversarial Semantic Collisions},
  author={Song, Congzheng and Rush, Alexander M and Shmatikov, Vitaly},
  booktitle={Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP)},
  pages={4198--4210},
  year={2020}
}
```

--------------------------------------------------------------------------------
/assets/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csong27/collision-bert/43eda087bf6d632bdb150d98e934206327f8d082/assets/method.png
--------------------------------------------------------------------------------
/collision_paraphrase.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | import tqdm
5 | from transformers import BertTokenizer, glue_processors
6 | 
7 | from constant import BOS_TOKEN, BERT_LM_MODEL_DIR, PARA_DIR
8 | from models.bert_models import BertForConcatSequenceClassification, BertForLM
9 | from models.scorer import SentenceScorer
10 | from utils.constraints_utils import create_constraints, get_sub_masks, get_inputs_filter_ids, STOPWORDS
11 | from utils.logging_utils import log
12 | from utils.optimization_utils import perturb_logits
13 | from utils.tokenizer_utils import valid_tokenization
14 | 
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('--gpu', type=str, default="0", help='gpu id')
17 | parser.add_argument('--stemp', type=float, default=0.1, help='temperature of softmax')
18 | parser.add_argument('--lr', type=float, default=0.01, help='optimization step size')
19 | parser.add_argument('--max_iter', type=int, default=10, help='maximum iterations')
20 | parser.add_argument('--seq_len', type=int, default=32, help='Sequence length')
21 | parser.add_argument('--min_len', type=int, default=5, help='Min sequence length')
22 | parser.add_argument("--beta", default=0.0, type=float, help="Coefficient for language model loss.")
23 | parser.add_argument("--task_name", default='mrpc', type=str)
24 | parser.add_argument("--model_dir", default=PARA_DIR, type=str, help="Path to pre-trained model")
25 | parser.add_argument("--lm_model_dir", default=BERT_LM_MODEL_DIR, type=str, help="Path to pre-trained language model")
26 | parser.add_argument('--perturb_iter', type=int, default=5, help='PPLM iteration')
27 | parser.add_argument("--kl_scale", default=0.0, type=float, help="KL divergence coefficient")
28 | parser.add_argument("--topk", default=50, type=int, help="Top k sampling for beam search")
29 | parser.add_argument("--num_beams", default=10, type=int, help="Number of beams")
30 | parser.add_argument('--verbose', action='store_true', help='Print every iteration')
31 | parser.add_argument('--nature', action='store_true', help='Natural collision')
32 | parser.add_argument('--regularize', action='store_true', help='Use regularization to decrease perplexity')
33 | 
parser.add_argument('--fp16', action='store_true', help='fp16') 34 | parser.add_argument("--num_filters", default=500, type=int, help="Number of num_filters words to be filtered") 35 | 36 | args = parser.parse_args() 37 | 38 | 39 | def gen_aggressive_collision(inputs_a, model, tokenizer, device, lm_model=None): 40 | seq_len = args.seq_len 41 | 42 | word_embedding = model.get_input_embeddings().weight.detach() 43 | if lm_model is not None: 44 | lm_word_embedding = lm_model.get_input_embeddings().weight.detach() 45 | 46 | vocab_size = word_embedding.size(0) 47 | input_ids = tokenizer.encode(inputs_a) 48 | sub_mask = get_sub_masks(tokenizer, device) 49 | stopwords_mask = create_constraints(seq_len, tokenizer, device) 50 | 51 | input_mask = torch.zeros(vocab_size, device=device) 52 | input_ids = torch.tensor(input_ids, device=device).unsqueeze(0) 53 | # prevent generating the words in the input 54 | input_mask[input_ids] = -1e9 55 | batch_input_ids = torch.cat([input_ids] * args.topk, 0) 56 | 57 | def relaxed_to_word_embs(x): 58 | # convert relaxed inputs to word embedding by softmax attention 59 | masked_x = x + input_mask + sub_mask 60 | if args.regularize: 61 | masked_x += stopwords_mask 62 | 63 | p = torch.softmax(masked_x / args.stemp, -1) 64 | x = torch.mm(p, word_embedding) 65 | # add embeddings for period and SEP 66 | x = torch.cat([x, word_embedding[tokenizer.sep_token_id].unsqueeze(0)]) 67 | return p, x.unsqueeze(0) 68 | 69 | def get_lm_loss(p): 70 | x = torch.mm(p.detach(), lm_word_embedding).unsqueeze(0) 71 | return lm_model(inputs_embeds=x, one_hot_labels=p.unsqueeze(0))[0] 72 | 73 | # some constants 74 | sep_tensor = torch.tensor([tokenizer.sep_token_id] * args.topk, device=device) 75 | batch_sep_embeds = word_embedding[sep_tensor].unsqueeze(1) 76 | labels = torch.ones((1,), dtype=torch.long, device=device) 77 | 78 | best_collision = None 79 | best_score = -1e9 80 | prev_score = -1e9 81 | 82 | var_size = (seq_len, vocab_size) 83 | z_i = torch.zeros(*var_size, requires_grad=True, device=device) 84 | for it in range(args.max_iter): 85 | optimizer = torch.optim.Adam([z_i], lr=args.lr) 86 | for j in range(args.perturb_iter): 87 | optimizer.zero_grad() 88 | # relaxation 89 | p_inputs, inputs_embeds = relaxed_to_word_embs(z_i) 90 | # forward to BERT with relaxed inputs 91 | loss = model(input_ids, inputs_embeds=inputs_embeds, labels=labels)[0] 92 | if args.beta > 0.: 93 | lm_loss = get_lm_loss(p_inputs) 94 | loss = args.beta * lm_loss + (1 - args.beta) * loss 95 | loss.backward() 96 | optimizer.step() 97 | if args.verbose and (j + 1) % 10 == 0: 98 | log(f'It{it}-{j + 1}, loss={loss.item()}') 99 | 100 | # detach to free GPU memory 101 | z_i = z_i.detach() 102 | 103 | _, topk_tokens = torch.topk(z_i, args.topk) 104 | probs_i = torch.softmax(z_i / args.stemp, -1).unsqueeze(0).expand(args.topk, seq_len, vocab_size) 105 | 106 | output_so_far = None 107 | # beam search left to right 108 | for t in range(seq_len): 109 | t_topk_tokens = topk_tokens[t] 110 | t_topk_onehot = torch.nn.functional.one_hot(t_topk_tokens, vocab_size).float() 111 | next_clf_scores = [] 112 | for j in range(args.num_beams): 113 | next_beam_scores = torch.zeros(tokenizer.vocab_size, device=device) - 1e9 114 | if output_so_far is None: 115 | context = probs_i.clone() 116 | else: 117 | output_len = output_so_far.shape[1] 118 | beam_topk_output = output_so_far[j].unsqueeze(0).expand(args.topk, output_len) 119 | beam_topk_output = torch.nn.functional.one_hot(beam_topk_output, vocab_size) 120 | context = 
torch.cat([beam_topk_output.float(), probs_i[:, output_len:].clone()], 1) 121 | context[:, t] = t_topk_onehot 122 | context_embeds = torch.einsum('blv,vh->blh', context, word_embedding) 123 | context_embeds = torch.cat([context_embeds, batch_sep_embeds], 1) 124 | clf_logits = model(input_ids=batch_input_ids, inputs_embeds=context_embeds)[0] 125 | clf_scores = torch.log_softmax(clf_logits, -1)[:, 1].detach().float() 126 | next_beam_scores.scatter_(0, t_topk_tokens, clf_scores) 127 | next_clf_scores.append(next_beam_scores.unsqueeze(0)) 128 | 129 | next_clf_scores = torch.cat(next_clf_scores, 0) 130 | next_clf_scores = next_clf_scores + input_mask + sub_mask 131 | next_scores = next_clf_scores 132 | if args.regularize: 133 | next_scores += stopwords_mask[t] 134 | 135 | if output_so_far is None: 136 | next_scores[1:] = -1e9 137 | 138 | # re-organize to group the beam together 139 | # (we are keeping top hypothesis across beams) 140 | next_scores = next_scores.view(1, args.num_beams * vocab_size) # (batch_size, num_beams * vocab_size) 141 | next_scores, next_tokens = torch.topk(next_scores, args.num_beams, dim=1, largest=True, sorted=True) 142 | # next batch beam content 143 | next_sent_beam = [] 144 | for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(zip(next_tokens[0], next_scores[0])): 145 | # get beam and token IDs 146 | beam_id = beam_token_id // vocab_size 147 | token_id = beam_token_id % vocab_size 148 | next_sent_beam.append((beam_token_score, token_id, beam_id)) 149 | 150 | next_batch_beam = next_sent_beam 151 | # sanity check / prepare next batch 152 | assert len(next_batch_beam) == args.num_beams 153 | beam_tokens = torch.tensor([x[1] for x in next_batch_beam], device=device) 154 | beam_idx = torch.tensor([x[2] for x in next_batch_beam], device=device) 155 | 156 | # re-order batch 157 | if output_so_far is None: 158 | output_so_far = beam_tokens.unsqueeze(1) 159 | else: 160 | output_so_far = output_so_far[beam_idx, :] 161 | output_so_far = torch.cat([output_so_far, beam_tokens.unsqueeze(1)], dim=-1) 162 | 163 | pad_output_so_far = torch.cat([output_so_far, sep_tensor[:args.num_beams].unsqueeze(1)], 1) 164 | concat_input_ids = torch.cat([batch_input_ids[:args.num_beams], pad_output_so_far], 1) 165 | token_type_ids = torch.cat([torch.zeros_like(batch_input_ids[:args.num_beams]), 166 | torch.ones_like(pad_output_so_far)], 1) 167 | clf_logits = model(input_ids=concat_input_ids, token_type_ids=token_type_ids)[0] 168 | actual_clf_scores = torch.softmax(clf_logits, -1)[:, 1] 169 | sorter = torch.argsort(actual_clf_scores, -1, descending=True) 170 | if args.verbose: 171 | decoded = [ 172 | f'{actual_clf_scores[i].item() * 100:.2f}%, ' 173 | f'{tokenizer.decode(output_so_far[i].cpu().tolist())}' 174 | for i in sorter 175 | ] 176 | log(f'It={it}, target={inputs_a} | ' + ' | '.join(decoded)) 177 | 178 | valid_idx = sorter[0] 179 | valid = False 180 | for idx in sorter: 181 | valid, _ = valid_tokenization(output_so_far[idx], tokenizer) 182 | if valid: 183 | valid_idx = idx 184 | break 185 | 186 | # re-initialize z_i 187 | curr_best = output_so_far[valid_idx] 188 | next_z_i = torch.nn.functional.one_hot(curr_best, vocab_size).float() 189 | eps = 0.1 190 | next_z_i = (next_z_i * (1 - eps)) + (1 - next_z_i) * eps / (vocab_size - 1) 191 | z_i = torch.nn.Parameter(torch.log(next_z_i), True) 192 | 193 | curr_score = actual_clf_scores[valid_idx].item() 194 | if valid and curr_score > best_score: 195 | best_score = curr_score 196 | best_collision = 
tokenizer.decode(curr_best.cpu().tolist()) 197 | 198 | if prev_score >= curr_score: 199 | break 200 | prev_score = curr_score 201 | 202 | return best_collision, best_score 203 | 204 | 205 | def find_filters(query, model, tokenizer, device, k=500): 206 | words = [w for w in tokenizer.vocab if w.isalpha() and w not in STOPWORDS] 207 | inputs = tokenizer.batch_encode_plus([[query, w] for w in words], 208 | pad_to_max_length=True) 209 | all_input_ids = torch.tensor(inputs['input_ids'], device=device) 210 | all_token_type_ids = torch.tensor(inputs['token_type_ids'], device=device) 211 | all_attention_masks = torch.tensor(inputs['attention_mask'], device=device) 212 | n = len(words) 213 | batch_size = 512 214 | n_batches = n // batch_size + 1 215 | all_scores = [] 216 | for i in tqdm.trange(n_batches, desc='Filtering vocab'): 217 | input_ids = all_input_ids[i * batch_size: (i + 1) * batch_size] 218 | token_type_ids = all_token_type_ids[i * batch_size: (i + 1) * batch_size] 219 | attention_masks = all_attention_masks[i * batch_size: (i + 1) * batch_size] 220 | outputs = model.forward(input_ids, attention_masks, token_type_ids) 221 | scores = outputs[0][:, 1] 222 | all_scores.append(scores) 223 | 224 | all_scores = torch.cat(all_scores) 225 | _, top_indices = torch.topk(all_scores, k) 226 | filters = set([words[i.item()] for i in top_indices]) 227 | return [w for w in filters if w.isalpha()] 228 | 229 | 230 | def gen_natural_collision(inputs_a, inputs_b, model, tokenizer, device, lm_model, eval_lm_model=None): 231 | collition_init = tokenizer.convert_tokens_to_ids([BOS_TOKEN]) 232 | start_idx = 1 233 | num_beams = args.num_beams 234 | repetition_penalty = 5.0 235 | curr_len = len(collition_init) 236 | 237 | filters = find_filters(inputs_a, model, tokenizer, device, k=args.num_filters) 238 | best_ids = get_inputs_filter_ids(inputs_a, tokenizer) 239 | best_ids += get_inputs_filter_ids(inputs_b, tokenizer) 240 | 241 | # scores for each sentence in the beam 242 | beam_scores = torch.zeros((num_beams,), dtype=torch.float, device=device) 243 | beam_scores[1:] = -1e9 244 | 245 | output_so_far = torch.tensor([collition_init] * num_beams, device=device) 246 | past = None 247 | vocab_size = tokenizer.vocab_size 248 | topk = args.topk 249 | input_ids = tokenizer.encode(inputs_a) 250 | input_ids = torch.tensor(input_ids, device=device).unsqueeze(0) 251 | batch_input_ids = torch.cat([input_ids] * topk, 0) 252 | sep_tensor = torch.tensor([tokenizer.sep_token_id] * topk, device=device) 253 | input_mask = torch.zeros(vocab_size, device=device) 254 | # prevent output num_filters neighbor words 255 | input_mask[best_ids] = -1e9 256 | input_mask[tokenizer.convert_tokens_to_ids(['.', '@', '='])] = -1e9 257 | unk_ids = tokenizer.encode('', add_special_tokens=False) 258 | input_mask[unk_ids] = -1e9 259 | input_mask[tokenizer.convert_tokens_to_ids(filters)] = -1e9 260 | 261 | first_mask = get_sub_masks(tokenizer, device) 262 | is_first = True 263 | word_embedding = model.get_input_embeddings().weight.detach() 264 | batch_sep_embeds = word_embedding[sep_tensor].unsqueeze(1) 265 | batch_labels = torch.ones((num_beams,), dtype=torch.long, device=device) 266 | 267 | def classifier_loss(p, context): 268 | context = torch.nn.functional.one_hot(context, len(word_embedding)) 269 | one_hot = torch.cat([context.float(), p.unsqueeze(1)], 1) 270 | x = torch.einsum('blv,vh->blh', one_hot, word_embedding) 271 | # add embeddings for SEP 272 | x = torch.cat([x, batch_sep_embeds[:num_beams]], 1) 273 | cls_loss = 
model(batch_input_ids[:num_beams], inputs_embeds=x, labels=batch_labels)[0] 274 | return cls_loss 275 | 276 | best_score = -1e9 277 | best_collision = None 278 | 279 | while curr_len < args.seq_len: 280 | model_inputs = lm_model.prepare_inputs_for_generation(output_so_far, past=past) 281 | outputs = lm_model(**model_inputs) 282 | present = outputs[1] 283 | # (batch_size * num_beams, vocab_size) 284 | next_token_logits = outputs[0][:, -1, :] 285 | lm_scores = torch.log_softmax(next_token_logits, dim=-1) 286 | 287 | if args.perturb_iter > 0: 288 | # perturb internal states of LM 289 | def target_model_wrapper(p): 290 | return classifier_loss(p, output_so_far.detach()[:, start_idx:]) 291 | 292 | next_token_logits = perturb_logits( 293 | next_token_logits, 294 | args.lr, 295 | target_model_wrapper, 296 | num_iterations=args.perturb_iter, 297 | kl_scale=args.kl_scale, 298 | temperature=args.stemp, 299 | device=device, 300 | verbose=args.verbose 301 | ) 302 | 303 | if repetition_penalty > 1.0: 304 | lm_model.enforce_repetition_penalty_(next_token_logits, 1, num_beams, output_so_far, repetition_penalty) 305 | next_token_logits = next_token_logits / args.stemp 306 | 307 | # (batch_size * num_beams, vocab_size) 308 | next_lm_scores = lm_scores + beam_scores[:, None].expand_as(lm_scores) 309 | _, topk_tokens = torch.topk(next_token_logits, topk) 310 | 311 | # get target model score here 312 | next_clf_scores = [] 313 | for i in range(num_beams): 314 | next_beam_scores = torch.zeros(tokenizer.vocab_size, device=device) - 1e9 315 | if output_so_far.shape[1] > start_idx: 316 | curr_beam_topk = output_so_far[i, start_idx:].unsqueeze(0).expand( 317 | topk, output_so_far.shape[1] - start_idx) 318 | # (topk, curr_len + next_token + sep) 319 | curr_beam_topk = torch.cat([curr_beam_topk, topk_tokens[i].unsqueeze(1), sep_tensor.unsqueeze(1)], 1) 320 | else: 321 | curr_beam_topk = torch.cat([topk_tokens[i].unsqueeze(1), sep_tensor.unsqueeze(1)], 1) 322 | concat_input_ids = torch.cat([batch_input_ids, curr_beam_topk], 1) 323 | token_type_ids = torch.cat([torch.zeros_like(batch_input_ids), torch.ones_like(curr_beam_topk), ], 1) 324 | clf_logits = model(input_ids=concat_input_ids, token_type_ids=token_type_ids)[0] 325 | clf_scores = torch.log_softmax(clf_logits, -1)[:, 1].detach() 326 | next_beam_scores.scatter_(0, topk_tokens[i], clf_scores.float()) 327 | next_clf_scores.append(next_beam_scores.unsqueeze(0)) 328 | next_clf_scores = torch.cat(next_clf_scores, 0) 329 | 330 | if is_first: 331 | next_clf_scores += beam_scores[:, None].expand_as(lm_scores) 332 | next_clf_scores += first_mask 333 | is_first = False 334 | 335 | next_scores = (1 - args.beta) * next_clf_scores + args.beta * next_lm_scores 336 | next_scores += input_mask 337 | 338 | # re-organize to group the beam together 339 | # (we are keeping top hypothesis across beams) 340 | next_scores = next_scores.view(num_beams * vocab_size) 341 | next_lm_scores = next_lm_scores.view(num_beams * vocab_size) 342 | next_scores, next_tokens = torch.topk(next_scores, num_beams, largest=True, sorted=True) 343 | next_lm_scores = next_lm_scores[next_tokens] 344 | 345 | # next batch beam content 346 | next_sent_beam = [] 347 | for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(zip(next_tokens, next_lm_scores)): 348 | # get beam and token IDs 349 | beam_id = beam_token_id // vocab_size 350 | token_id = beam_token_id % vocab_size 351 | next_sent_beam.append((beam_token_score, token_id, beam_id)) 352 | 353 | next_batch_beam = next_sent_beam 354 | 355 | # 
sanity check / prepare next batch 356 | assert len(next_batch_beam) == num_beams 357 | beam_scores = beam_scores.new([x[0] for x in next_batch_beam]) 358 | beam_tokens = output_so_far.new([x[1] for x in next_batch_beam]) 359 | beam_idx = output_so_far.new([x[2] for x in next_batch_beam]) 360 | 361 | # re-order batch 362 | output_so_far = output_so_far[beam_idx, :] 363 | output_so_far = torch.cat([output_so_far, beam_tokens.unsqueeze(1)], dim=-1) 364 | 365 | # sanity check 366 | pad_output_so_far = torch.cat([output_so_far[:, start_idx:], sep_tensor[:num_beams].unsqueeze(1)], 1) 367 | concat_input_ids = torch.cat([batch_input_ids[:num_beams], pad_output_so_far], 1) 368 | token_type_ids = torch.cat([torch.zeros_like(batch_input_ids[:num_beams]), 369 | torch.ones_like(pad_output_so_far)], 1) 370 | clf_logits = model(input_ids=concat_input_ids, token_type_ids=token_type_ids)[0] 371 | actual_clf_scores = torch.softmax(clf_logits, 1)[:, 1] 372 | sorter = torch.argsort(actual_clf_scores, -1, descending=True) 373 | if args.verbose: 374 | decoded = [ 375 | f'{actual_clf_scores[i].item() * 100:.2f}%, ' 376 | f'{tokenizer.decode(output_so_far[i, start_idx:].cpu().tolist())}' 377 | for i in sorter 378 | ] 379 | log(f'Target={inputs_a}, ' + ' | '.join(decoded)) 380 | 381 | if curr_len > args.min_len: 382 | valid_idx = sorter[0] 383 | valid = False 384 | for idx in sorter: 385 | valid, _ = valid_tokenization(output_so_far[idx, start_idx:], tokenizer) 386 | if valid: 387 | valid_idx = idx 388 | break 389 | curr_score = actual_clf_scores[valid_idx].item() 390 | curr_collision = tokenizer.decode(output_so_far[valid_idx, start_idx:].cpu().tolist()) 391 | if valid and curr_score > best_score: 392 | best_score = curr_score 393 | best_collision = curr_collision 394 | 395 | if args.verbose: 396 | lm_perp = eval_lm_model.perplexity(curr_collision) 397 | log(f'LM perp={lm_perp.item()}') 398 | 399 | # re-order internal states 400 | past = lm_model._reorder_cache(present, beam_idx) 401 | # update current length 402 | curr_len = curr_len + 1 403 | return best_collision, best_score 404 | 405 | 406 | def main(): 407 | device = torch.device(f'cuda:{args.gpu}') 408 | 409 | model_dir = os.path.join(args.model_dir, args.task_name.lower()) 410 | tokenizer = BertTokenizer.from_pretrained(model_dir) 411 | log(f'Loading model from {model_dir}') 412 | 413 | model = BertForConcatSequenceClassification.from_pretrained(model_dir) 414 | model.to(device) 415 | model.eval() 416 | for param in model.parameters(): 417 | param.requires_grad = False 418 | 419 | eval_lm_model = SentenceScorer(device) 420 | lm_model = BertForLM.from_pretrained(args.lm_model_dir) 421 | lm_model.to(device) 422 | lm_model.eval() 423 | for param in lm_model.parameters(): 424 | param.requires_grad = False 425 | 426 | if args.fp16: 427 | from apex import amp 428 | model, lm_model = amp.initialize([model, lm_model]) 429 | 430 | log(f'Loading data from {args.task_name.upper()}') 431 | data = glue_processors[args.task_name.lower()]().get_dev_examples(model_dir) 432 | 433 | n = 0 434 | for inputs in data: 435 | if inputs.label == '1': 436 | n += 1 437 | if args.nature: 438 | collision, score = gen_natural_collision( 439 | inputs.text_a, inputs.text_b, model, tokenizer, device, 440 | lm_model=lm_model, eval_lm_model=eval_lm_model) 441 | else: 442 | collision, score = gen_aggressive_collision( 443 | inputs.text_a, model, tokenizer, device, lm_model=lm_model) 444 | 445 | lm_perp = eval_lm_model.perplexity(collision) 446 | msg = f'Input={inputs.text_a}\n' \ 447 | 
f'Ground truth paraphrase={inputs.text_b}\n' \ 448 | f'Collision={collision}\n' \ 449 | f'Confidence of being paraphrase={score}\n' \ 450 | f'LM perp={lm_perp.item()}\n' 451 | log(msg) 452 | 453 | 454 | if __name__ == '__main__': 455 | main() 456 | -------------------------------------------------------------------------------- /collision_polyencoder.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import numpy as np 5 | import torch 6 | import tqdm 7 | from pattern.text.en import singularize, pluralize 8 | 9 | from constant import CONV2AI_PATH, POLY_LM_MODEL_DIR 10 | from models.polyencoder import PretrainedBiEncoder, PretrainedPolyEncoder, PolyEncoderTokenizer, PolyEncoderLM 11 | from models.scorer import SentenceScorer 12 | from utils.constraints_utils import create_poly_constraints, get_poly_sub_masks, get_inputs_filter_ids, STOPWORDS 13 | from utils.logging_utils import log 14 | from utils.optimization_utils import perturb_logits 15 | from utils.tokenizer_utils import valid_tokenization 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--gpu', type=str, default="0", help='gpu id') 19 | parser.add_argument('--stemp', type=float, default=1.0, help='temperature of softmax') 20 | parser.add_argument('--ptemp', type=float, default=1.0, help='temperature of Sinkhorn permutation') 21 | parser.add_argument('--lr', type=float, default=0.001, help='optimization step size') 22 | parser.add_argument('--max_iter', type=int, default=10, help='maximum iteraiton') 23 | parser.add_argument('--seq_len', type=int, default=20, help='Sequence length') 24 | parser.add_argument('--min_len', type=int, default=5, help='Min sequence length') 25 | parser.add_argument("--beta", default=0.0, type=float, help="Coefficient for language model loss.") 26 | parser.add_argument('--verbose', action='store_true', help='Print every iteration') 27 | parser.add_argument('--poly', action='store_true', help='Polyencoder or Biencoder') 28 | parser.add_argument("--lm_model_dir", default=POLY_LM_MODEL_DIR, type=str, help="Path to pre-trained LM") 29 | parser.add_argument('--perturb_iter', type=int, default=5, help='PPLM iteration') 30 | parser.add_argument("--kl_scale", default=0.0, type=float, help="KL divergence coefficient") 31 | parser.add_argument("--topk", default=30, type=int, help="Top k sampling for beam search") 32 | parser.add_argument("--num_beams", default=5, type=int, help="Number of beams") 33 | parser.add_argument('--nature', action='store_true', help='Nature collision') 34 | parser.add_argument('--save', action='store_true', help='Save collision to file') 35 | parser.add_argument("--num_filters", default=1000, type=int, help="Number of num_filters words to be filtered") 36 | parser.add_argument('--regularize', action='store_true', help='Use regularize to decrease perplexity') 37 | parser.add_argument('--fp16', action='store_true', help='fp16') 38 | 39 | args = parser.parse_args() 40 | 41 | 42 | def load_persona_chat(): 43 | with open(CONV2AI_PATH, "r", encoding="utf-8") as f: 44 | dataset = json.loads(f.read())['valid'] 45 | 46 | data_list = [] 47 | for line in dataset: 48 | persona, utterances = line['personality'], line['utterances'] 49 | for utterance in utterances: 50 | candidates, history = utterance['candidates'], utterance['history'] 51 | data_list.append((persona, history, candidates)) 52 | return data_list 53 | 54 | 55 | def gen_aggressive_collision(inputs_a, margin, model, tokenizer, device, 
lm_model=None): 56 | seq_len = args.seq_len 57 | 58 | word_embedding = model.get_input_embeddings().weight.detach() 59 | if lm_model is not None: 60 | lm_word_embedding = lm_model.get_input_embeddings().weight.detach() 61 | 62 | vocab_size = word_embedding.size(0) 63 | input_ids = tokenizer.encode(inputs_a) 64 | sub_mask = get_poly_sub_masks(tokenizer, device) 65 | stopwords_mask = create_poly_constraints(seq_len, tokenizer, device) 66 | input_mask = torch.zeros(vocab_size, device=device) 67 | input_ids = torch.tensor(input_ids, device=device).unsqueeze(0) 68 | # prevent output num_filters neighbor words 69 | input_filter_ids = get_inputs_filter_ids(inputs_a, tokenizer) 70 | input_mask[input_filter_ids] = -1e9 71 | first_mask = torch.zeros_like(sub_mask) 72 | first_mask[[tokenizer.dict[w] for w in tokenizer.dict.tok2ind if w.startswith('__')]] = -1e10 73 | 74 | def relaxed_to_word_embs(x): 75 | # convert relaxed inputs to word embedding by softmax attention 76 | masked_x = x + input_mask + sub_mask 77 | if args.regularize: 78 | masked_x += stopwords_mask 79 | 80 | p = torch.softmax(masked_x / args.stemp, -1) 81 | x = torch.mm(p, word_embedding) 82 | # add embeddings for period and SEP 83 | x = torch.cat([word_embedding[tokenizer.cls_token_id].unsqueeze(0), x, 84 | word_embedding[tokenizer.sep_token_id].unsqueeze(0)]) 85 | return p, x.unsqueeze(0) 86 | 87 | def get_lm_loss(p): 88 | x = torch.mm(p.detach(), lm_word_embedding).unsqueeze(0) 89 | return lm_model(inputs_embeds=x, one_hot_labels=p.unsqueeze(0))[0] 90 | 91 | # some constants 92 | cls_tensor = torch.tensor([tokenizer.cls_token_id] * args.topk, device=device) 93 | sep_tensor = torch.tensor([tokenizer.sep_token_id] * args.topk, device=device) 94 | batch_cls_emb = word_embedding[cls_tensor].unsqueeze(1) 95 | batch_sep_emb = word_embedding[sep_tensor].unsqueeze(1) 96 | 97 | best_collision = None 98 | best_score = -1e9 99 | prev_score = -1e9 100 | 101 | var_size = (seq_len, vocab_size) 102 | z_i = torch.zeros(*var_size, requires_grad=True, device=device) 103 | for it in range(args.max_iter): 104 | optimizer = torch.optim.Adam([z_i], lr=args.lr) 105 | for j in range(args.perturb_iter): 106 | optimizer.zero_grad() 107 | # relaxation 108 | p_inputs, inputs_embeds = relaxed_to_word_embs(z_i) 109 | # forward to BERT with relaxed inputs 110 | outputs = model(ctxt_input_ids=input_ids, cand_inputs_embeds=inputs_embeds) 111 | loss = torch.relu(margin - outputs[0].squeeze()).sum() 112 | if args.beta > 0.: 113 | lm_loss = get_lm_loss(p_inputs) 114 | loss = args.beta * lm_loss + (1 - args.beta) * loss 115 | loss.backward() 116 | optimizer.step() 117 | if args.verbose and (j + 1) % 10 == 0: 118 | log(f'It{it}-{j + 1}, loss={loss.item()}') 119 | 120 | # detach to free GPU memory 121 | z_i = z_i.detach() 122 | 123 | _, topk_tokens = torch.topk(z_i, args.topk) 124 | probs_i = torch.softmax(z_i / args.stemp, -1).unsqueeze(0).expand(args.topk, seq_len, vocab_size) 125 | 126 | output_so_far = None 127 | # beam search left to right 128 | for t in range(seq_len): 129 | t_topk_tokens = topk_tokens[t] 130 | t_topk_onehot = torch.nn.functional.one_hot(t_topk_tokens, vocab_size).float() 131 | next_clf_scores = [] 132 | for j in range(args.num_beams): 133 | next_beam_scores = torch.zeros(tokenizer.vocab_size, device=device) - 1e9 134 | if output_so_far is None: 135 | context = probs_i.clone() 136 | else: 137 | output_len = output_so_far.shape[1] 138 | beam_topk_output = output_so_far[j].unsqueeze(0).expand(args.topk, output_len) 139 | beam_topk_output = 
torch.nn.functional.one_hot(beam_topk_output, vocab_size) 140 | context = torch.cat([beam_topk_output.float(), probs_i[:, output_len:].clone()], 1) 141 | context[:, t] = t_topk_onehot 142 | context_emb = torch.einsum('blv,vh->blh', context, word_embedding) 143 | 144 | context_emb = torch.cat([batch_cls_emb, context_emb, batch_sep_emb], 1) 145 | outputs = model(ctxt_input_ids=input_ids, cand_inputs_embeds=context_emb)[0] 146 | clf_scores = outputs.squeeze().detach().float() 147 | next_beam_scores.scatter_(0, t_topk_tokens, clf_scores) 148 | next_clf_scores.append(next_beam_scores.unsqueeze(0)) 149 | 150 | next_clf_scores = torch.cat(next_clf_scores, 0) 151 | next_scores = next_clf_scores + input_mask + sub_mask 152 | 153 | if args.regularize: 154 | next_scores += stopwords_mask[t] 155 | 156 | if output_so_far is None: 157 | next_scores[1:] = -1e9 158 | next_scores += first_mask 159 | 160 | # re-organize to group the beam together 161 | # (we are keeping top hypothesis accross beams) 162 | next_scores = next_scores.view(1, args.num_beams * vocab_size) # (batch_size, num_beams * vocab_size) 163 | next_scores, next_tokens = torch.topk(next_scores, args.num_beams, dim=1, largest=True, sorted=True) 164 | 165 | # next batch beam content 166 | next_sent_beam = [] 167 | for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(zip(next_tokens[0], next_scores[0])): 168 | # get beam and token IDs 169 | beam_id = beam_token_id // vocab_size 170 | token_id = beam_token_id % vocab_size 171 | next_sent_beam.append((beam_token_score, token_id, beam_id)) 172 | 173 | next_batch_beam = next_sent_beam 174 | 175 | # sanity check / prepare next batch 176 | assert len(next_batch_beam) == args.num_beams 177 | beam_tokens = torch.tensor([x[1] for x in next_batch_beam], device=device) 178 | beam_idx = torch.tensor([x[2] for x in next_batch_beam], device=device) 179 | 180 | # re-order batch 181 | if output_so_far is None: 182 | output_so_far = beam_tokens.unsqueeze(1) 183 | else: 184 | output_so_far = output_so_far[beam_idx, :] 185 | output_so_far = torch.cat([output_so_far, beam_tokens.unsqueeze(1)], dim=-1) 186 | 187 | pad_output_so_far = torch.cat([cls_tensor[:args.num_beams].unsqueeze(1), 188 | output_so_far, 189 | sep_tensor[:args.num_beams].unsqueeze(1)], 1) 190 | actual_clf_scores = model(ctxt_input_ids=input_ids, cand_input_ids=pad_output_so_far)[0] 191 | actual_clf_scores = actual_clf_scores.squeeze() 192 | sorter = torch.argsort(actual_clf_scores, -1, descending=True) 193 | if args.verbose: 194 | decoded = [ 195 | f'{actual_clf_scores[i].item():.2f}, ' 196 | f'{tokenizer.decode(output_so_far[i].cpu().tolist())}' 197 | for i in sorter 198 | ] 199 | log(f'It={it}, margin={margin}, target={inputs_a} | ' + ' | '.join(decoded)) 200 | 201 | valid_idx = sorter[0] 202 | valid = False 203 | for idx in sorter: 204 | valid, _ = valid_tokenization(output_so_far[idx], tokenizer) 205 | if valid: 206 | valid_idx = idx 207 | break 208 | 209 | # re-initialize z_i 210 | curr_best = output_so_far[valid_idx] 211 | next_z_i = torch.nn.functional.one_hot(curr_best, vocab_size).float() 212 | eps = 0.1 213 | next_z_i = (next_z_i * (1 - eps)) + (1 - next_z_i) * eps / (vocab_size - 1) 214 | z_i = torch.nn.Parameter(torch.log(next_z_i), True) 215 | 216 | curr_score = actual_clf_scores[valid_idx].item() 217 | if valid and curr_score > best_score: 218 | best_score = curr_score 219 | best_collision = tokenizer.decode(curr_best.cpu().tolist()) 220 | 221 | if curr_score == prev_score: 222 | break 223 | prev_score = curr_score 
224 | 225 | return best_collision, best_score 226 | 227 | 228 | def find_filters(query, model, tokenizer, device, k=500): 229 | words = [w for w in tokenizer.dict.tok2ind if w.isalpha() and w not in STOPWORDS] 230 | inputs_a = torch.tensor(tokenizer.encode(query), device=device).unsqueeze(0) 231 | inputs_b = [tokenizer.encode(w) for w in words] 232 | inputs_b = torch.tensor(inputs_b, device=device) 233 | n = len(words) 234 | batch_size = 1024 235 | n_batches = n // batch_size + 1 236 | all_scores = [] 237 | for i in tqdm.trange(n_batches, desc='Filtering vocab'): 238 | cand_input_ids = inputs_b[i * batch_size: (i + 1) * batch_size] 239 | scores = model(ctxt_input_ids=inputs_a, cand_input_ids=cand_input_ids)[0].squeeze() 240 | all_scores.append(scores) 241 | all_scores = torch.cat(all_scores) 242 | _, top_indices = torch.topk(all_scores, k) 243 | filters = set([words[i.item()] for i in top_indices]) 244 | return [w for w in filters if w.isalpha()] 245 | 246 | 247 | def add_single_plural(text, tokenizer): 248 | tokens = tokenizer.tokenize(text) 249 | vocab = tokenizer.get_vocab() 250 | contains = [] 251 | for word in vocab: 252 | if word.isalpha() and len(word) > 2: 253 | for t in tokens: 254 | if len(t) > 2 and word != t and (word.startswith(t) or t.startswith(word)): 255 | contains.append(word) 256 | 257 | for t in tokens[:]: 258 | if not t.isalpha(): 259 | continue 260 | sig_t = singularize(t) 261 | plu_t = pluralize(t) 262 | if sig_t != t and sig_t in vocab: 263 | tokens.append(sig_t) 264 | if plu_t != t and plu_t in vocab: 265 | tokens.append(plu_t) 266 | return [w for w in tokens + contains if w not in STOPWORDS] 267 | 268 | 269 | def gen_natural_collision(inputs_a, inputs_b, model, tokenizer, device, lm_model, eval_lm_model, margin): 270 | filters = find_filters(inputs_a, model, tokenizer, device, args.num_filters) 271 | num_filters_ids = tokenizer.convert_tokens_to_ids(filters) 272 | 273 | newline_token_id = tokenizer.dict['__newln__'] 274 | 275 | sub_mask = torch.zeros(tokenizer.vocab_size, device=device) 276 | sub_mask[newline_token_id] = -1e9 277 | filter_ids = [tokenizer.dict[w] for w in tokenizer.dict.tok2ind if not w.isalnum()] 278 | first_mask = torch.zeros_like(sub_mask) 279 | first_mask[filter_ids] = -1e9 280 | 281 | collition_init = [tokenizer.dict['__start__']] 282 | start_idx = 1 283 | num_beams = args.num_beams 284 | repetition_penalty = 5.0 285 | curr_len = len(collition_init) 286 | 287 | # scores for each sentence in the beam 288 | beam_scores = torch.zeros((num_beams,), dtype=torch.float, device=device) 289 | beam_scores[1:] = -1e9 290 | 291 | output_so_far = torch.tensor([collition_init] * num_beams, device=device) 292 | past = None 293 | vocab_size = tokenizer.vocab_size 294 | topk = args.topk 295 | input_ids = tokenizer.encode(inputs_a) 296 | input_mask = torch.zeros(vocab_size, device=device) 297 | # prevent output num_filters neighbor words 298 | input_filter_ids = get_inputs_filter_ids(inputs_a, tokenizer) + get_inputs_filter_ids(inputs_b, tokenizer) 299 | remove_tokens = add_single_plural(inputs_a, tokenizer) 300 | remove_ids = tokenizer.convert_tokens_to_ids(remove_tokens) 301 | input_mask[remove_ids] = -1e10 302 | input_mask[input_filter_ids] = -1e9 303 | input_mask[tokenizer.dict['.']] = -1e9 304 | input_mask[tokenizer.dict['=']] = -1e9 305 | input_mask[tokenizer.dict['@']] = -1e9 306 | input_mask[tokenizer.dict['@@.']] = -1e9 307 | input_mask[tokenizer.dict['__null__']] = -1e9 308 | input_mask[num_filters_ids] = -1e9 309 | unk_ids = tokenizer.encode('', 
add_special_tokens=False) 310 | input_mask[unk_ids] = -1e9 311 | input_mask[[tokenizer.dict[w] for w in tokenizer.dict.tok2ind if w.startswith('__')]] = -1e10 312 | 313 | input_ids = torch.tensor(input_ids, device=device).unsqueeze(0) 314 | sep_tensor = torch.tensor([tokenizer.sep_token_id] * topk, device=device) 315 | cls_tensor = torch.tensor([tokenizer.cls_token_id] * topk, device=device) 316 | 317 | is_first = True 318 | word_embedding = model.get_input_embeddings().weight.detach() 319 | batch_sep_emb = word_embedding[sep_tensor].unsqueeze(1) 320 | batch_cls_emb = word_embedding[cls_tensor].unsqueeze(1) 321 | 322 | def classifier_loss(p, context): 323 | context = torch.nn.functional.one_hot(context, len(word_embedding)) 324 | one_hot = torch.cat([context.float(), p.unsqueeze(1)], 1) 325 | x = torch.einsum('blv,vh->blh', one_hot, word_embedding) 326 | # add embeddings for SEP 327 | x = torch.cat([batch_cls_emb[:num_beams], x, batch_sep_emb[:num_beams]], 1) 328 | scores = model(ctxt_input_ids=input_ids, cand_inputs_embeds=x)[0] 329 | loss = (margin - scores.squeeze()).mean() 330 | return loss 331 | 332 | best_score = -1e9 333 | best_collision = None 334 | 335 | while curr_len < args.seq_len: 336 | model_inputs = lm_model.prepare_inputs_for_generation(output_so_far, past=past) 337 | outputs = lm_model(**model_inputs) 338 | present = outputs[1] 339 | # (batch_size * num_beams, vocab_size) 340 | next_token_logits = outputs[0][:, -1, :] 341 | lm_scores = torch.log_softmax(next_token_logits, dim=-1) 342 | 343 | if args.perturb_iter > 0: 344 | # perturb internal states of LM 345 | def target_model_wrapper(p): 346 | return classifier_loss(p, output_so_far.detach()[:, start_idx:]) 347 | 348 | next_token_logits = perturb_logits( 349 | next_token_logits, 350 | args.lr, 351 | target_model_wrapper, 352 | num_iterations=args.perturb_iter, 353 | kl_scale=args.kl_scale, 354 | temperature=args.stemp, 355 | device=device, 356 | verbose=args.verbose, 357 | logit_mask=input_mask, 358 | ) 359 | 360 | if repetition_penalty > 1.0: 361 | lm_model.enforce_repetition_penalty_(next_token_logits, 1, num_beams, output_so_far, repetition_penalty) 362 | next_token_logits = next_token_logits / args.stemp 363 | 364 | # (batch_size * num_beams, vocab_size) 365 | next_lm_scores = lm_scores + beam_scores[:, None].expand_as(lm_scores) + sub_mask 366 | _, topk_tokens = torch.topk(next_token_logits, topk) 367 | 368 | # get target model score here 369 | next_clf_scores = [] 370 | for i in range(num_beams): 371 | next_beam_scores = torch.zeros(tokenizer.vocab_size, device=device) - 1e9 372 | if output_so_far.shape[1] > start_idx: 373 | curr_beam_topk = output_so_far[i, start_idx:].unsqueeze(0).expand( 374 | topk, output_so_far.shape[1] - start_idx) 375 | # (topk, curr_len + next_token + sep) 376 | curr_beam_topk = torch.cat([cls_tensor.unsqueeze(1), curr_beam_topk, 377 | topk_tokens[i].unsqueeze(1), 378 | sep_tensor.unsqueeze(1)], 1) 379 | else: 380 | curr_beam_topk = torch.cat([cls_tensor.unsqueeze(1), 381 | topk_tokens[i].unsqueeze(1), 382 | sep_tensor.unsqueeze(1)], 1) 383 | clf_scores = model(ctxt_input_ids=input_ids, cand_input_ids=curr_beam_topk)[0].squeeze() 384 | next_beam_scores.scatter_(0, topk_tokens[i], clf_scores.float()) 385 | next_clf_scores.append(next_beam_scores.unsqueeze(0)) 386 | next_clf_scores = torch.cat(next_clf_scores, 0) 387 | 388 | if is_first: 389 | next_clf_scores += beam_scores[:, None].expand_as(lm_scores) 390 | next_clf_scores += first_mask 391 | is_first = False 392 | 393 | next_scores = (1 - 
args.beta) * next_clf_scores + args.beta * next_lm_scores 394 | next_scores += input_mask 395 | 396 | # re-organize to group the beam together (we are keeping top hypothesis accross beams) 397 | next_scores = next_scores.view(num_beams * vocab_size) 398 | next_lm_scores = next_lm_scores.view(num_beams * vocab_size) 399 | next_scores, next_tokens = torch.topk(next_scores, num_beams, largest=True, sorted=True) 400 | next_lm_scores = next_lm_scores[next_tokens] 401 | # next batch beam content 402 | next_sent_beam = [] 403 | for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(zip(next_tokens, next_lm_scores)): 404 | # get beam and token IDs 405 | beam_id = beam_token_id // vocab_size 406 | token_id = beam_token_id % vocab_size 407 | next_sent_beam.append((beam_token_score, token_id, beam_id)) 408 | 409 | next_batch_beam = next_sent_beam 410 | 411 | # sanity check / prepare next batch 412 | assert len(next_batch_beam) == num_beams 413 | beam_scores = beam_scores.new([x[0] for x in next_batch_beam]) 414 | beam_tokens = output_so_far.new([x[1] for x in next_batch_beam]) 415 | beam_idx = output_so_far.new([x[2] for x in next_batch_beam]) 416 | 417 | # re-order batch 418 | output_so_far = output_so_far[beam_idx, :] 419 | output_so_far = torch.cat([output_so_far, beam_tokens.unsqueeze(1)], dim=-1) 420 | 421 | # sanity check 422 | pad_output_so_far = torch.cat([cls_tensor[: num_beams].unsqueeze(1), 423 | output_so_far[:, start_idx:], 424 | sep_tensor[:num_beams].unsqueeze(1)], 1) 425 | actual_clf_scores = model(ctxt_input_ids=input_ids, cand_input_ids=pad_output_so_far)[0].squeeze() 426 | 427 | sorter = torch.argsort(actual_clf_scores, -1, descending=True) 428 | if args.verbose: 429 | decoded = [ 430 | f'{actual_clf_scores[i].item():.2f}, ' 431 | f'{tokenizer.decode(output_so_far[i, start_idx:].cpu().tolist())}' 432 | for i in sorter 433 | ] 434 | log(f'Margin={margin}, target={inputs_a} | ' + ' | '.join(decoded)) 435 | 436 | if curr_len > args.min_len: 437 | valid_idx = sorter[0] 438 | valid = False 439 | for idx in sorter: 440 | valid, _ = valid_tokenization(output_so_far[idx, start_idx:], tokenizer) 441 | if valid: 442 | valid_idx = idx 443 | break 444 | 445 | curr_score = actual_clf_scores[valid_idx].item() 446 | curr_collision = tokenizer.decode(output_so_far[valid_idx, start_idx:].cpu().tolist()) 447 | if valid and curr_score > best_score: 448 | best_score = curr_score 449 | best_collision = curr_collision 450 | 451 | if args.verbose: 452 | lm_perp = eval_lm_model.perplexity(curr_collision) 453 | log(f'LM perp={lm_perp.item()}') 454 | 455 | # re-order internal states 456 | past = lm_model._reorder_cache(present, beam_idx) 457 | # update current length 458 | curr_len = curr_len + 1 459 | 460 | return best_collision, best_score 461 | 462 | 463 | def main(): 464 | device = torch.device(f'cuda:{args.gpu}') 465 | data_list = load_persona_chat() 466 | tokenizer = PolyEncoderTokenizer.from_pretrained() 467 | if args.poly: 468 | model = PretrainedPolyEncoder.from_pretrained() 469 | else: 470 | model = PretrainedBiEncoder.from_pretrained() 471 | 472 | history_size = model.opt['history_size'] 473 | text_truncate = model.opt['text_truncate'] 474 | model.to(device) 475 | model.eval() 476 | for param in model.parameters(): 477 | param.requires_grad = False 478 | 479 | log(f'Loading LM model from {args.lm_model_dir}') 480 | lm_model = PolyEncoderLM.from_pretrained(checkpoint=args.lm_model_dir) 481 | lm_model.to(device) 482 | lm_model.eval() 483 | for param in lm_model.parameters(): 484 | 
        param.requires_grad = False
485 |     eval_lm_model = SentenceScorer(device)
486 | 
487 |     if args.fp16:
488 |         from apex import amp
489 |         model, lm_model = amp.initialize([model, lm_model])
490 | 
491 |     for data in data_list:
492 |         query = '\n'.join(data[0] + data[1][-history_size:])
493 |         candidates = data[2]
494 |         truth = candidates[-1]
495 |         query_ids = torch.tensor(tokenizer.encode(query, max_length=text_truncate), device=device).unsqueeze(0)
496 |         candidates = torch.tensor(tokenizer.batch_encode_plus(
497 |             candidates, pad_to_max_length=True)['input_ids'], device=device)
498 |         output = model.forward(ctxt_input_ids=query_ids, cand_input_ids=candidates)
499 |         scores = output[0].squeeze()
500 |         if args.nature:
501 |             collision, score = gen_natural_collision(
502 |                 query, truth, model, tokenizer, device, lm_model, eval_lm_model, scores.max())
503 |         else:
504 |             collision, score = gen_aggressive_collision(query, scores.max(), model, tokenizer, device, lm_model)
505 | 
506 |         # get the rank of the collision
507 |         scores = scores.cpu().tolist()
508 |         scores = np.asarray([score] + scores)
509 |         n = len(scores)
510 |         ranks = np.empty(n)
511 |         ranks[np.argsort(-scores)] = np.arange(n)
512 | 
513 |         lm_perp = eval_lm_model.perplexity(collision)
514 |         msg = f'Input={query}\n' \
515 |               f'Ground truth response={truth}\n' \
516 |               f'Collision={collision}\n' \
517 |               f'Collision similarity score={score}\n' \
518 |               f'Rank={ranks[0]}\n' \
519 |               f'LM perp={lm_perp.item()}\n'
520 |         log(msg)
521 | 
522 | 
523 | if __name__ == '__main__':
524 |     main()
525 | 
--------------------------------------------------------------------------------
/constant.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | COLLISION_DIR = os.environ.get('COLLISION_DIR', 'collision/')
4 | assert os.path.exists(COLLISION_DIR), 'Please create a directory for extracting data and models.'
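# COLLISION_DIR can be set through the environment, e.g. COLLISION_DIR=/path/to/extracted/data,
# before running the attack scripts; otherwise the relative default 'collision/' is used,
# and that directory must already exist (see the assert above).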
5 | 6 | BOS_TOKEN = '[unused0]' 7 | EOS_TOKEN = '[unused1]' 8 | 9 | # Paraphrase related constants 10 | PARA_DIR = os.path.join(COLLISION_DIR, 'paraphrase') 11 | 12 | # BIRCH related constants 13 | BIRCH_DIR = os.path.join(COLLISION_DIR, 'birch') 14 | BIRCH_MODEL_DIR = os.path.join(BIRCH_DIR, 'models') 15 | BIRCH_DATA_DIR = os.path.join(BIRCH_DIR, 'data') 16 | BIRCH_INDEX_DIR = os.path.join(BIRCH_DIR, 'index') 17 | BIRCH_PRED_DIR = os.path.join(BIRCH_DATA_DIR, 'predictions') 18 | BIRCH_ALPHAS = [1.0, 0.5, 0.1] 19 | BIRCH_GAMMA = 0.6 20 | 21 | # Polyencoder related constants 22 | PARLAI_DIR = os.path.join(COLLISION_DIR, 'parlai') 23 | CONV2AI_PATH = os.path.join(PARLAI_DIR, 'personachat_self_original.json') 24 | 25 | # PreSumm related constants 26 | PRESUMM_DIR = os.path.join(COLLISION_DIR, 'presumm') 27 | PRESUMM_DATA_DIR = os.path.join(PRESUMM_DIR, 'data') 28 | PRESUMM_MODEL_DIR = os.path.join(PRESUMM_DIR, 'models') 29 | PRESUMM_MODEL_PATH = os.path.join(PRESUMM_MODEL_DIR, 'bertext_cnndm_transformer_ckpt.pt') 30 | 31 | 32 | # LM Model dir 33 | BERT_LM_MODEL_DIR = os.path.join(COLLISION_DIR, 'wiki103', 'bert') 34 | POLY_LM_MODEL_DIR = os.path.join(COLLISION_DIR, 'wiki103', 'polyencoder') 35 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csong27/collision-bert/43eda087bf6d632bdb150d98e934206327f8d082/models/__init__.py -------------------------------------------------------------------------------- /models/bert_layers.py: -------------------------------------------------------------------------------- 1 | from transformers.modeling_distilbert import MultiHeadSelfAttention, TransformerBlock, Transformer 2 | from transformers.modeling_bert import BertSelfAttention, BertAttention, BertLayer, BertEncoder 3 | import math 4 | import copy 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class BertSelfAttentionPast(BertSelfAttention): 10 | def forward( 11 | self, 12 | hidden_states, 13 | attention_mask=None, 14 | head_mask=None, 15 | encoder_hidden_states=None, 16 | encoder_attention_mask=None, 17 | layer_past=None, 18 | cache_query=False, 19 | ): 20 | mixed_query_layer = self.query(hidden_states) 21 | 22 | # If this is instantiated as a cross-attention module, the keys 23 | # and values come from an encoder; the attention mask needs to be 24 | # such that the encoder's padding tokens are not attended to. 
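# Note on caching: when layer_past is provided, key/value tensors (and query tensors if
# cache_query is set) from earlier decoding steps are concatenated with the current
# projections below, so the already-generated prefix does not have to be re-encoded at
# every step; the stacked `present` tensor is returned so callers can pass it back as layer_past.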
25 | if encoder_hidden_states is not None: 26 | mixed_key_layer = self.key(encoder_hidden_states) 27 | mixed_value_layer = self.value(encoder_hidden_states) 28 | attention_mask = encoder_attention_mask 29 | else: 30 | mixed_key_layer = self.key(hidden_states) 31 | mixed_value_layer = self.value(hidden_states) 32 | 33 | query_layer = self.transpose_for_scores(mixed_query_layer) 34 | key_layer = self.transpose_for_scores(mixed_key_layer) 35 | value_layer = self.transpose_for_scores(mixed_value_layer) 36 | 37 | if layer_past is not None: 38 | if cache_query: 39 | past_q = layer_past[2] 40 | query_layer = torch.cat((past_q, query_layer), dim=-2) 41 | 42 | past_k, past_v = layer_past[0], layer_past[1] 43 | key_layer = torch.cat((past_k, key_layer), dim=-2) 44 | value_layer = torch.cat((past_v, value_layer), dim=-2) 45 | 46 | if cache_query: 47 | present = torch.stack([key_layer, value_layer, query_layer]) 48 | else: 49 | present = torch.stack([key_layer, value_layer]) 50 | 51 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 52 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 53 | if layer_past is None and attention_mask is not None: 54 | attention_scores += attention_mask 55 | 56 | # Normalize the attention scores to probabilities. 57 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 58 | 59 | # This is actually dropping out entire tokens to attend to, which might 60 | # seem a bit unusual, but is taken from the original Transformer paper. 61 | attention_probs = self.dropout(attention_probs) 62 | 63 | # Mask heads if we want to 64 | if head_mask is not None: 65 | attention_probs = attention_probs * head_mask 66 | 67 | context_layer = torch.matmul(attention_probs, value_layer) 68 | 69 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 70 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 71 | context_layer = context_layer.view(*new_context_layer_shape) 72 | 73 | outputs = (context_layer, attention_probs, present) if self.output_attentions else (context_layer, present) 74 | return outputs 75 | 76 | 77 | class BertAttentionPast(BertAttention): 78 | def __init__(self, config): 79 | super().__init__(config) 80 | self.self = BertSelfAttentionPast(config) 81 | 82 | def forward( 83 | self, 84 | hidden_states, 85 | attention_mask=None, 86 | head_mask=None, 87 | encoder_hidden_states=None, 88 | encoder_attention_mask=None, 89 | layer_past=None, 90 | cache_query=False, 91 | ): 92 | self_outputs = self.self( 93 | hidden_states, attention_mask, head_mask, encoder_hidden_states, 94 | encoder_attention_mask, layer_past, cache_query 95 | ) 96 | attention_output = self.output(self_outputs[0], hidden_states) 97 | outputs = (attention_output,) + self_outputs[1:] 98 | return outputs 99 | 100 | 101 | class BertLayerPast(BertLayer): 102 | def __init__(self, config): 103 | super().__init__(config) 104 | self.attention = BertAttentionPast(config) 105 | 106 | def forward( 107 | self, 108 | hidden_states, 109 | attention_mask=None, 110 | head_mask=None, 111 | encoder_hidden_states=None, 112 | encoder_attention_mask=None, 113 | layer_past=None, 114 | cache_query=False 115 | ): 116 | self_attention_outputs = self.attention(hidden_states, attention_mask, 117 | head_mask, layer_past=layer_past, 118 | cache_query=cache_query) 119 | attention_output = self_attention_outputs[0] 120 | outputs = self_attention_outputs[1:] 121 | 122 | if self.is_decoder and encoder_hidden_states is not None: 123 | cross_attention_outputs = 
self.crossattention( 124 | attention_output, attention_mask, head_mask, 125 | encoder_hidden_states, encoder_attention_mask 126 | ) 127 | attention_output = cross_attention_outputs[0] 128 | outputs = outputs + cross_attention_outputs[1:] 129 | 130 | intermediate_output = self.intermediate(attention_output) 131 | layer_output = self.output(intermediate_output, attention_output) 132 | outputs = (layer_output,) + outputs 133 | return outputs 134 | 135 | 136 | class BertEncoderPast(BertEncoder): 137 | def __init__(self, config): 138 | super().__init__(config) 139 | self.output_past = getattr(config, 'output_past', True) 140 | self.layer = nn.ModuleList( 141 | [BertLayerPast(config) for _ in range(config.num_hidden_layers)]) 142 | 143 | def forward( 144 | self, 145 | hidden_states, 146 | attention_mask=None, 147 | head_mask=None, 148 | encoder_hidden_states=None, 149 | encoder_attention_mask=None, 150 | past=None, 151 | cache_query=False 152 | ): 153 | if past is None: 154 | past = [None] * len(self.layer) 155 | 156 | all_hidden_states = () 157 | all_attentions = () 158 | presents = () 159 | 160 | for i, (layer_module, layer_past) in enumerate(zip(self.layer, past)): 161 | if self.output_hidden_states: 162 | all_hidden_states = all_hidden_states + (hidden_states,) 163 | 164 | layer_outputs = layer_module( 165 | hidden_states, attention_mask, head_mask[i], encoder_hidden_states, 166 | encoder_attention_mask, layer_past, cache_query 167 | ) 168 | hidden_states = layer_outputs[0] 169 | 170 | if self.output_attentions: 171 | all_attentions = all_attentions + (layer_outputs[1],) 172 | 173 | present = layer_outputs[-1] 174 | if self.output_past: 175 | presents = presents + (present,) 176 | 177 | # Add last layer 178 | if self.output_hidden_states: 179 | all_hidden_states = all_hidden_states + (hidden_states,) 180 | 181 | outputs = (hidden_states,) 182 | if self.output_past: 183 | outputs = outputs + (presents,) 184 | if self.output_hidden_states: 185 | outputs = outputs + (all_hidden_states,) 186 | if self.output_attentions: 187 | outputs = outputs + (all_attentions,) 188 | return outputs 189 | 190 | 191 | class MaskedMultiHeadSelfAttention(MultiHeadSelfAttention): 192 | def forward(self, query, key, value, layer_past=None, 193 | mask=None, head_mask=None): 194 | bs, q_length, dim = query.size() 195 | dim_per_head = self.dim // self.n_heads 196 | 197 | def shape(x): 198 | """ separate heads """ 199 | return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2) 200 | 201 | def unshape(x): 202 | """ group heads """ 203 | return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head) 204 | 205 | q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head) 206 | k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head) 207 | v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head) 208 | if layer_past is not None: 209 | past_k, past_v = layer_past[0], layer_past[1] 210 | k = torch.cat((past_k, k), dim=-2) 211 | v = torch.cat((past_v, v), dim=-2) 212 | 213 | present = torch.stack([k, v]) 214 | 215 | q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head) 216 | # (bs, n_heads, q_length, k_length) 217 | scores = torch.matmul(q, k.transpose(2, 3)) 218 | if layer_past is None and mask is not None: 219 | scores += mask 220 | 221 | weights = nn.Softmax(dim=-1)(scores) # (bs, n_heads, q_length, k_length) 222 | weights = self.dropout(weights) # (bs, n_heads, q_length, k_length) 223 | 224 | # Mask heads if we want to 225 | if head_mask is not None: 
226 | weights = weights * head_mask 227 | 228 | context = torch.matmul(weights, v) # (bs, n_heads, q_length, dim_per_head) 229 | context = unshape(context) # (bs, q_length, dim) 230 | context = self.out_lin(context) # (bs, q_length, dim) 231 | 232 | if self.output_attentions: 233 | return context, present, weights 234 | else: 235 | return context, present 236 | 237 | 238 | class MaskedTransformerBlock(TransformerBlock): 239 | def __init__(self, config): 240 | super().__init__(config) 241 | self.attention = MaskedMultiHeadSelfAttention(config) 242 | 243 | def forward(self, x, layer_past=None, attn_mask=None, head_mask=None): 244 | sa_output = self.attention(query=x, key=x, value=x, layer_past=layer_past, 245 | mask=attn_mask, head_mask=head_mask) 246 | if self.output_attentions: 247 | # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length) 248 | sa_output, sa_present, sa_weights = sa_output 249 | else: 250 | assert type(sa_output) == tuple 251 | sa_output, sa_present = sa_output 252 | sa_output = self.sa_layer_norm(sa_output + x) # (bs, seq_length, dim) 253 | 254 | # Feed Forward Network 255 | ffn_output = self.ffn(sa_output) # (bs, seq_length, dim) 256 | ffn_output = self.output_layer_norm( 257 | ffn_output + sa_output) # (bs, seq_length, dim) 258 | 259 | output = (ffn_output, sa_present) 260 | if self.output_attentions: 261 | output = (sa_weights,) + output 262 | return output 263 | 264 | 265 | class MaskedTransformer(Transformer): 266 | def __init__(self, config): 267 | super().__init__(config) 268 | self.output_past = getattr(config, 'output_past', True) 269 | layer = MaskedTransformerBlock(config) 270 | self.layer = nn.ModuleList( 271 | [copy.deepcopy(layer) for _ in range(config.n_layers)]) 272 | 273 | def forward(self, x, past=None, attn_mask=None, head_mask=None): 274 | if past is None: 275 | past = [None] * len(self.layer) 276 | 277 | all_hidden_states = () 278 | all_attentions = () 279 | presents = () 280 | 281 | hidden_state = x 282 | for i, (layer_module, layer_past) in enumerate(zip(self.layer, past)): 283 | if self.output_hidden_states: 284 | all_hidden_states = all_hidden_states + (hidden_state,) 285 | 286 | layer_outputs = layer_module(x=hidden_state, layer_past=layer_past, 287 | attn_mask=attn_mask, head_mask=head_mask[i]) 288 | hidden_state = layer_outputs[-2] 289 | present = layer_outputs[-1] 290 | if self.output_past: 291 | presents = presents + (present,) 292 | 293 | if self.output_attentions: 294 | assert len(layer_outputs) == 3 295 | attentions = layer_outputs[0] 296 | all_attentions = all_attentions + (attentions,) 297 | else: 298 | assert len(layer_outputs) == 2 299 | 300 | # Add last layer 301 | if self.output_hidden_states: 302 | all_hidden_states = all_hidden_states + (hidden_state,) 303 | 304 | outputs = (hidden_state,) 305 | if self.output_past: 306 | outputs = outputs + (presents,) 307 | if self.output_hidden_states: 308 | outputs = outputs + (all_hidden_states,) 309 | if self.output_attentions: 310 | outputs = outputs + (all_attentions,) 311 | return outputs 312 | -------------------------------------------------------------------------------- /models/bert_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import BertModel, BertForSequenceClassification, BertForNextSentencePrediction, BertForMaskedLM 3 | from transformers.modeling_bert import BertEmbeddings 4 | 5 | from models.bert_layers import BertEncoderPast 6 | 7 | 8 | class 
BertForConcatNextSentencePrediction(BertForNextSentencePrediction): 9 | def __init__(self, config): 10 | super().__init__(config) 11 | self.bert = BertConcatModel(config) 12 | self.init_weights() 13 | 14 | def forward( 15 | self, 16 | input_ids=None, 17 | attention_mask=None, 18 | token_type_ids=None, 19 | position_ids=None, 20 | head_mask=None, 21 | inputs_embeds=None, 22 | next_sentence_label=None, 23 | ): 24 | outputs = self.bert( 25 | input_ids, 26 | attention_mask=attention_mask, 27 | token_type_ids=token_type_ids, 28 | position_ids=position_ids, 29 | head_mask=head_mask, 30 | inputs_embeds=inputs_embeds, 31 | ) 32 | 33 | pooled_output = outputs[1] 34 | 35 | seq_relationship_score = self.cls(pooled_output) 36 | # add hidden states and attention if they are here 37 | outputs = (seq_relationship_score, pooled_output) + outputs[2:] 38 | if next_sentence_label is not None: 39 | loss_fct = torch.nn.CrossEntropyLoss() 40 | next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) 41 | outputs = (next_sentence_loss,) + outputs 42 | 43 | # (next_sentence_loss), seq_relationship_score, 44 | # (hidden_states), (attentions) 45 | return outputs 46 | 47 | 48 | class BertForConcatSequenceClassification(BertForSequenceClassification): 49 | def __init__(self, config): 50 | super().__init__(config) 51 | self.bert = BertConcatModel(config) 52 | self.init_weights() 53 | 54 | def forward( 55 | self, 56 | input_ids=None, 57 | attention_mask=None, 58 | token_type_ids=None, 59 | position_ids=None, 60 | head_mask=None, 61 | inputs_embeds=None, 62 | labels=None, 63 | ): 64 | outputs = self.bert( 65 | input_ids, 66 | attention_mask=attention_mask, 67 | token_type_ids=token_type_ids, 68 | position_ids=position_ids, 69 | head_mask=head_mask, 70 | inputs_embeds=inputs_embeds, 71 | ) 72 | 73 | pooled_output = outputs[1] 74 | 75 | pooled_output = self.dropout(pooled_output) 76 | logits = self.classifier(pooled_output) 77 | 78 | outputs = (logits, pooled_output) + outputs[2:] 79 | 80 | if labels is not None: 81 | if self.num_labels == 1: 82 | # We are doing regression 83 | loss_fct = torch.nn.MSELoss() 84 | loss = loss_fct(logits.view(-1), labels.view(-1)) 85 | else: 86 | loss_fct = torch.nn.CrossEntropyLoss() 87 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 88 | outputs = (loss,) + outputs 89 | 90 | return outputs 91 | 92 | 93 | class BertConcatEmbeddings(BertEmbeddings): 94 | def forward(self, input_ids=None, token_type_ids=None, position_ids=None, 95 | inputs_embeds=None): 96 | if input_ids is not None and inputs_embeds is not None: 97 | input_shape = (input_ids.size(0), input_ids.size(1) + inputs_embeds.size(1)) 98 | elif input_ids is not None: 99 | input_shape = input_ids.size() 100 | elif inputs_embeds is not None: 101 | input_shape = inputs_embeds.size()[:-1] 102 | else: 103 | raise ValueError( 104 | "You have to specify either input_ids or inputs_embeds") 105 | 106 | seq_length = input_shape[1] 107 | device = input_ids.device if input_ids is not None else inputs_embeds.device 108 | if position_ids is None: 109 | position_ids = torch.arange(seq_length, dtype=torch.long, device=device) 110 | position_ids = position_ids.unsqueeze(0).expand(input_shape) 111 | if token_type_ids is None: 112 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) 113 | 114 | if inputs_embeds is None: 115 | inputs_embeds = self.word_embeddings(input_ids) 116 | elif input_ids is not None: 117 | inputs_a_embeds = self.word_embeddings(input_ids) 118 | 
inputs_embeds = torch.cat([inputs_a_embeds, inputs_embeds], dim=1) 119 | 120 | position_embeddings = self.position_embeddings(position_ids) 121 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 122 | embeddings = inputs_embeds + position_embeddings + token_type_embeddings 123 | embeddings = self.LayerNorm(embeddings) 124 | embeddings = self.dropout(embeddings) 125 | return embeddings 126 | 127 | 128 | class BertConcatModel(BertModel): 129 | def __init__(self, config): 130 | super().__init__(config) 131 | self.embeddings = BertConcatEmbeddings(config) 132 | self.init_weights() 133 | 134 | def forward( 135 | self, 136 | input_ids=None, 137 | attention_mask=None, 138 | token_type_ids=None, 139 | position_ids=None, 140 | head_mask=None, 141 | inputs_embeds=None, 142 | encoder_hidden_states=None, 143 | encoder_attention_mask=None, 144 | ): 145 | past_length = 0 146 | if input_ids is not None and inputs_embeds is not None: 147 | input_shape = (input_ids.size(0), input_ids.size(1) + inputs_embeds.size(1)) 148 | elif input_ids is not None: 149 | input_shape = input_ids.size() 150 | elif inputs_embeds is not None: 151 | input_shape = inputs_embeds.size()[:-1] 152 | else: 153 | raise ValueError( 154 | "You have to specify either input_ids or inputs_embeds") 155 | 156 | device = input_ids.device if input_ids is not None else inputs_embeds.device 157 | 158 | batch_size, seq_length = input_shape 159 | if attention_mask is None: 160 | attention_mask = torch.ones((batch_size, seq_length + past_length), device=device) # (bs, seq_length) 161 | 162 | if attention_mask.dim() == 3: 163 | extended_attention_mask = attention_mask[:, None, :, :] 164 | elif attention_mask.dim() == 2: 165 | if self.config.is_decoder: 166 | batch_size, seq_length = input_shape 167 | seq_ids = torch.arange(past_length + seq_length, device=device) 168 | causal_mask = seq_ids[None, None, :].repeat( 169 | batch_size, past_length + seq_length, 1) <= seq_ids[None, :, None] 170 | causal_mask = causal_mask.to(attention_mask.dtype) 171 | extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] 172 | else: 173 | extended_attention_mask = attention_mask[:, None, None, :] 174 | else: 175 | raise ValueError( 176 | "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( 177 | input_shape, attention_mask.shape)) 178 | 179 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 180 | # masked positions, this operation will create a tensor which is 0.0 for 181 | # positions we want to attend and -10000.0 for masked positions. 182 | # Since we are adding it to the raw scores before the softmax, this is 183 | # effectively the same as removing these entirely. 
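        # e.g. an attention_mask row of [1, 1, 0] becomes additive biases [0.0, 0.0, -10000.0]:
        # attended positions are left unchanged while masked positions are pushed towards -inf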
184 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 185 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 186 | 187 | # If a 2D or 3D attention mask is provided for the cross-attention 188 | # we need to make it broadcastable to 189 | # [batch_size, num_heads, seq_length, seq_length] 190 | if self.config.is_decoder and encoder_hidden_states is not None: 191 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 192 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 193 | if encoder_attention_mask is None: 194 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 195 | 196 | if encoder_attention_mask.dim() == 3: 197 | encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] 198 | elif encoder_attention_mask.dim() == 2: 199 | encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] 200 | else: 201 | raise ValueError("Wrong shape for encoder_hidden_shape (shape {}) or " 202 | "encoder_attention_mask (shape {})".format(encoder_hidden_shape, 203 | encoder_attention_mask.shape)) 204 | 205 | # fp16 compatibility 206 | encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype) 207 | encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 208 | else: 209 | encoder_extended_attention_mask = None 210 | 211 | if head_mask is not None: 212 | if head_mask.dim() == 1: 213 | head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 214 | head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) 215 | elif head_mask.dim() == 2: 216 | head_mask = (head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)) 217 | head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 218 | else: 219 | head_mask = [None] * self.config.num_hidden_layers 220 | 221 | position_ids = torch.arange(past_length, past_length + seq_length, dtype=torch.long, device=device) 222 | position_ids = position_ids.unsqueeze(0).expand(input_shape) 223 | 224 | if input_ids is not None and inputs_embeds is not None: 225 | if token_type_ids is None: 226 | input_a_shape = input_ids.size() 227 | token_a_type_ids = torch.zeros(input_a_shape, dtype=torch.long, device=device) 228 | input_b_shape = inputs_embeds.size()[:-1] 229 | token_b_type_ids = torch.ones(input_b_shape, dtype=torch.long, device=device) 230 | token_type_ids = torch.cat([token_a_type_ids, token_b_type_ids], dim=1) 231 | 232 | embedding_output = self.embeddings( 233 | input_ids=input_ids, position_ids=position_ids, 234 | token_type_ids=token_type_ids, inputs_embeds=inputs_embeds, 235 | ) 236 | else: 237 | embedding_output = self.embeddings( 238 | input_ids=input_ids, position_ids=position_ids, 239 | token_type_ids=token_type_ids, inputs_embeds=inputs_embeds, 240 | ) 241 | 242 | encoder_outputs = self.encoder( 243 | embedding_output, 244 | attention_mask=extended_attention_mask, 245 | head_mask=head_mask, 246 | encoder_hidden_states=encoder_hidden_states, 247 | encoder_attention_mask=encoder_extended_attention_mask, 248 | ) 249 | sequence_output = encoder_outputs[0] 250 | pooled_output = self.pooler(sequence_output) 251 | 252 | outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] 253 | return outputs 254 | 255 | 256 | class BertAutoRegressiveModel(BertModel): 257 | def __init__(self, config): 258 | super().__init__(config) 259 | self.encoder = BertEncoderPast(config)
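        # BertEncoderPast returns each layer's key/value states (`presents`) and accepts them
        # back via `past`, so autoregressive decoding can reuse cached prefix states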
260 | self.init_weights() 261 | 262 | def forward( 263 | self, 264 | input_ids=None, 265 | attention_mask=None, 266 | token_type_ids=None, 267 | position_ids=None, 268 | head_mask=None, 269 | inputs_embeds=None, 270 | encoder_hidden_states=None, 271 | encoder_attention_mask=None, 272 | past=None 273 | ): 274 | if past is None: 275 | past_length = 0 276 | else: 277 | past_length = past[0][0].size(-2) 278 | 279 | if input_ids is not None and inputs_embeds is not None: 280 | raise ValueError( 281 | "You cannot specify both input_ids and inputs_embeds at the same time") 282 | elif input_ids is not None: 283 | input_shape = input_ids.size() 284 | elif inputs_embeds is not None: 285 | input_shape = inputs_embeds.size()[:-1] 286 | else: 287 | raise ValueError( 288 | "You have to specify either input_ids or inputs_embeds") 289 | 290 | device = input_ids.device if input_ids is not None else inputs_embeds.device 291 | batch_size, seq_length = input_shape 292 | if attention_mask is None: 293 | attention_mask = torch.ones((batch_size, seq_length + past_length), device=device) # (bs, seq_length) 294 | seq_ids = torch.arange(past_length + seq_length, device=device) 295 | # add a upper triangle mask for auto-regressive language model 296 | causal_mask = seq_ids[None, None, :].repeat(batch_size, past_length + seq_length, 1) <= seq_ids[None, :, None] 297 | causal_mask = causal_mask.to(attention_mask.dtype) 298 | extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] 299 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 300 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 301 | encoder_extended_attention_mask = None 302 | 303 | # Prepare head mask if needed 304 | # 1.0 in head_mask indicate we keep the head 305 | # attention_probs has shape bsz x n_heads x N x N 306 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] 307 | # and head_mask is converted to shape 308 | # [num_hidden_layers x batch x num_heads x seq_length x seq_length] 309 | if head_mask is not None: 310 | if head_mask.dim() == 1: 311 | head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 312 | head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) 313 | elif head_mask.dim() == 2: 314 | head_mask = (head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)) 315 | head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 316 | else: 317 | head_mask = [None] * self.config.num_hidden_layers 318 | 319 | position_ids = torch.arange(past_length, past_length + seq_length, dtype=torch.long, device=device) 320 | position_ids = position_ids.unsqueeze(0).expand(input_shape) 321 | 322 | embedding_output = self.embeddings( 323 | input_ids=input_ids, position_ids=position_ids, 324 | token_type_ids=token_type_ids, inputs_embeds=inputs_embeds 325 | ) 326 | encoder_outputs = self.encoder( 327 | embedding_output, 328 | attention_mask=extended_attention_mask, 329 | head_mask=head_mask, 330 | encoder_hidden_states=encoder_hidden_states, 331 | encoder_attention_mask=encoder_extended_attention_mask, 332 | past=past 333 | ) 334 | sequence_output = encoder_outputs[0] 335 | pooled_output = self.pooler(sequence_output) 336 | 337 | # add hidden_states and attentions if they are here 338 | outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] 339 | # sequence_output, pooled_output, (hidden_states), (attentions) 340 | return outputs 341 | 342 | 343 | class 
BertForLM(BertForMaskedLM): 344 | def __init__(self, config): 345 | super().__init__(config) 346 | self.bert = BertAutoRegressiveModel(config) 347 | self.start_idx = 1 348 | self.init_weights() 349 | 350 | def prepare_inputs_for_generation(self, input_ids, past): 351 | # only last token for inputs_ids if past is defined in kwargs 352 | if past: 353 | input_ids = input_ids[:, -1].unsqueeze(-1) 354 | return {"input_ids": input_ids, "past": past} 355 | 356 | def forward( 357 | self, 358 | input_ids=None, 359 | attention_mask=None, 360 | token_type_ids=None, 361 | position_ids=None, 362 | head_mask=None, 363 | inputs_embeds=None, 364 | masked_lm_labels=None, 365 | encoder_hidden_states=None, 366 | encoder_attention_mask=None, 367 | labels=None, 368 | one_hot_labels=None, 369 | past=None 370 | ): 371 | label_start_idx = 1 372 | if inputs_embeds is not None: 373 | start_embeds = self.get_input_embeddings().weight[self.start_idx] 374 | inputs_embeds = torch.cat([start_embeds.view(1, 1, -1), inputs_embeds], 1) 375 | label_start_idx = 0 376 | 377 | outputs = self.bert( 378 | input_ids, 379 | attention_mask=attention_mask, 380 | token_type_ids=token_type_ids, 381 | position_ids=position_ids, 382 | head_mask=head_mask, 383 | inputs_embeds=inputs_embeds, 384 | encoder_hidden_states=encoder_hidden_states, 385 | encoder_attention_mask=encoder_attention_mask, 386 | past=past 387 | ) 388 | 389 | sequence_output = outputs[0] 390 | prediction_scores = self.cls(sequence_output) 391 | # Add hidden states and attention if they are here 392 | outputs = (prediction_scores,) + outputs[2:] 393 | 394 | # we are doing next-token prediction; 395 | # shift prediction scores and input ids by one 396 | if one_hot_labels is not None: 397 | prediction_scores = prediction_scores[:, :-1, :].contiguous() 398 | lm_labels = one_hot_labels[:, label_start_idx:, :].contiguous() 399 | nll = -torch.log_softmax(prediction_scores, -1) 400 | ltr_lm_loss = torch.sum(nll * lm_labels, -1).mean() 401 | outputs = (ltr_lm_loss,) + outputs 402 | elif labels is not None: 403 | prediction_scores = prediction_scores[:, :-1, :].contiguous() 404 | lm_labels = labels[:, label_start_idx:].contiguous() 405 | loss_fct = torch.nn.CrossEntropyLoss() 406 | ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1)) 407 | outputs = (ltr_lm_loss,) + outputs 408 | return outputs 409 | -------------------------------------------------------------------------------- /models/polyencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from models.polyencoder.modeling_polyencoder import PretrainedPolyEncoder, PretrainedBiEncoder, PolyEncoderLM 2 | from models.polyencoder.tokenization_polyencoder import PolyEncoderTokenizer 3 | -------------------------------------------------------------------------------- /models/polyencoder/config_polyencoder.py: -------------------------------------------------------------------------------- 1 | from parlai.core.opt import load_opt_file 2 | from constant import PARLAI_DIR 3 | 4 | PARLAI_MODEL_DIR = PARLAI_DIR + '/models/' 5 | POLYENC_MODEL_DIR = PARLAI_MODEL_DIR + 'model_poly/' 6 | POLYENC_OPT_FILE = POLYENC_MODEL_DIR + 'model.opt' 7 | BI_MODEL_DIR = PARLAI_MODEL_DIR + 'model_bi/' 8 | BIENC_OPT_FILE = BI_MODEL_DIR + 'model.opt' 9 | PRETRAINED_BI_MODEL_DIR = PARLAI_MODEL_DIR + 'bi_model_huge_reddit/' 10 | PRETRAINED_BIENC_OPT_FILE = PRETRAINED_BI_MODEL_DIR + 'model.opt' 11 | PRETRAINED_POLY_MODEL_DIR = PARLAI_MODEL_DIR + 
'poly_model_huge_reddit/' 12 | PRETRAINED_POLYENC_OPT_FILE = PRETRAINED_POLY_MODEL_DIR + 'model.opt' 13 | 14 | 15 | def load_poly_encoder_opt(): 16 | opt = load_opt_file(POLYENC_OPT_FILE) 17 | if isinstance(opt['fixed_candidates_path'], str): 18 | opt['fixed_candidates_path'] = PARLAI_DIR + opt['fixed_candidates_path'] 19 | opt['data_path'] = PARLAI_DIR + 'data' 20 | opt['datapath'] = PARLAI_DIR + 'data' 21 | opt['model_file'] = POLYENC_MODEL_DIR + 'model' 22 | opt['dict_file'] = POLYENC_MODEL_DIR + 'model.dict' 23 | opt['encode_candidate_vecs'] = False 24 | return opt 25 | 26 | 27 | def load_bi_encoder_opt(): 28 | opt = load_opt_file(BIENC_OPT_FILE) 29 | if isinstance(opt['fixed_candidates_path'], str): 30 | opt['fixed_candidates_path'] = PARLAI_DIR + opt['fixed_candidates_path'] 31 | opt['data_path'] = PARLAI_DIR + 'data' 32 | opt['datapath'] = PARLAI_DIR + 'data' 33 | opt['model_file'] = BI_MODEL_DIR + 'model' 34 | opt['dict_file'] = BI_MODEL_DIR + 'model.dict' 35 | opt['encode_candidate_vecs'] = False 36 | return opt 37 | 38 | 39 | def load_pretrained_poly_encoder_opt(): 40 | opt = load_opt_file(PRETRAINED_POLYENC_OPT_FILE) 41 | if isinstance(opt['fixed_candidates_path'], str): 42 | opt['fixed_candidates_path'] = PARLAI_DIR + opt['fixed_candidates_path'] 43 | opt['data_path'] = PARLAI_DIR + 'data' 44 | opt['datapath'] = PARLAI_DIR + 'data' 45 | opt['model_file'] = PRETRAINED_POLY_MODEL_DIR + 'model' 46 | opt['dict_file'] = PRETRAINED_POLY_MODEL_DIR + 'model.dict' 47 | opt['encode_candidate_vecs'] = False 48 | return opt 49 | 50 | 51 | def load_pretrained_bi_encoder_opt(): 52 | opt = load_opt_file(PRETRAINED_BIENC_OPT_FILE) 53 | if isinstance(opt['fixed_candidates_path'], str): 54 | opt['fixed_candidates_path'] = PARLAI_DIR + opt['fixed_candidates_path'] 55 | opt['data_path'] = PARLAI_DIR + 'data' 56 | opt['datapath'] = PARLAI_DIR + 'data' 57 | opt['model_file'] = PRETRAINED_BI_MODEL_DIR + 'model' 58 | opt['dict_file'] = PRETRAINED_BI_MODEL_DIR + 'model.dict' 59 | opt['encode_candidate_vecs'] = False 60 | return opt 61 | -------------------------------------------------------------------------------- /models/polyencoder/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from parlai.agents.transformer.modules import MultiHeadAttention, \ 4 | TransformerEncoderLayer, _normalize, LAYER_NORM_EPS, \ 5 | create_position_codes, gelu 6 | from torch.nn import LayerNorm 7 | 8 | 9 | class MultiHeadAttentionPast(MultiHeadAttention): 10 | def forward(self, hidden_states, attention_mask=None, layer_past=None): 11 | batch_size, query_len, dim = hidden_states.size() 12 | assert (dim == self.dim), 'Dimensions do not match: {} query vs {} configured'.format(dim, self.dim) 13 | n_heads = self.n_heads 14 | dim_per_head = dim // n_heads 15 | scale = math.sqrt(dim_per_head) 16 | 17 | def prepare_head(tensor): 18 | # input is [batch_size, seq_len, n_heads * dim_per_head] 19 | # output is [batch_size * n_heads, seq_len, dim_per_head] 20 | bsz, seq_len, _ = tensor.size() 21 | tensor = tensor.view(batch_size, tensor.size(1), n_heads, dim_per_head) 22 | tensor = tensor.transpose(1, 2).contiguous().view(batch_size * n_heads, seq_len, dim_per_head) 23 | return tensor 24 | 25 | q = self.q_lin(hidden_states) 26 | k = self.k_lin(hidden_states) 27 | v = self.v_lin(hidden_states) 28 | 29 | if layer_past is not None: 30 | past_k, past_v = layer_past[0], layer_past[1] 31 | k = torch.cat((past_k, k), dim=-2) 32 | v = torch.cat((past_v, v), 
dim=-2) 33 | 34 | present = torch.stack([k, v]) 35 | q = prepare_head(q) 36 | k = prepare_head(k) 37 | v = prepare_head(v) 38 | 39 | dot_prod = q.div_(scale).bmm(k.transpose(1, 2)) 40 | # [B * n_heads, query_len, key_len] 41 | if layer_past is None and attention_mask is not None: 42 | dot_prod += attention_mask 43 | 44 | attn_weights = torch.softmax(dot_prod, dim=-1, dtype=torch.float).type_as(hidden_states) 45 | attn_weights = self.attn_dropout(attn_weights) # --attention-dropout 46 | 47 | attentioned = attn_weights.bmm(v) 48 | attentioned = ( 49 | attentioned.type_as(hidden_states) 50 | .view(batch_size, n_heads, query_len, dim_per_head) 51 | .transpose(1, 2) 52 | .contiguous() 53 | .view(batch_size, query_len, dim) 54 | ) 55 | out = self.out_lin(attentioned) 56 | return out, present 57 | 58 | 59 | class TransformerEncoderLayerPast(TransformerEncoderLayer): 60 | def __init__( 61 | self, 62 | n_heads, 63 | embedding_size, 64 | ffn_size, 65 | attention_dropout=0.0, 66 | relu_dropout=0.0, 67 | dropout=0.0, 68 | activation='relu', 69 | variant=None, 70 | ): 71 | super().__init__(n_heads, embedding_size, ffn_size, attention_dropout, 72 | relu_dropout, dropout, activation, variant) 73 | self.attention = MultiHeadAttentionPast(n_heads, embedding_size, dropout=attention_dropout) 74 | 75 | def forward(self, tensor, attention_mask=None, layer_past=None): 76 | """ 77 | Forward pass. 78 | """ 79 | 80 | residual = tensor 81 | if self.variant == 'prelayernorm': 82 | tensor = _normalize(tensor, self.norm1) 83 | attended_tensor, layer_past = self.attention(tensor, attention_mask, layer_past) 84 | tensor = residual + self.dropout(attended_tensor) 85 | if self.variant == 'aiayn' or self.variant == 'xlm': 86 | tensor = _normalize(tensor, self.norm1) 87 | residual = tensor 88 | if self.variant == 'prelayernorm': 89 | tensor = _normalize(tensor, self.norm2) 90 | tensor = residual + self.dropout(self.ffn(tensor)) 91 | if self.variant == 'aiayn' or self.variant == 'xlm': 92 | tensor = _normalize(tensor, self.norm2) 93 | # tensor *= mask.unsqueeze(-1).type_as(tensor) 94 | return tensor, layer_past 95 | 96 | 97 | class TransformerAREncoder(torch.nn.Module): 98 | def __init__( 99 | self, 100 | n_heads, 101 | n_layers, 102 | embedding_size, 103 | ffn_size, 104 | embedding=None, 105 | dropout=0.0, 106 | attention_dropout=0.0, 107 | relu_dropout=0.0, 108 | padding_idx=0, 109 | learn_positional_embeddings=False, 110 | embeddings_scale=False, 111 | n_positions=1024, 112 | activation='relu', 113 | variant='aiayn', 114 | n_segments=0, 115 | output_scaling=1.0, 116 | ): 117 | super(TransformerAREncoder, self).__init__() 118 | self.embedding_size = embedding_size 119 | self.ffn_size = ffn_size 120 | self.n_layers = n_layers 121 | self.n_heads = n_heads 122 | self.dim = embedding_size 123 | self.embeddings_scale = embeddings_scale 124 | self.padding_idx = padding_idx 125 | # this is --dropout, not --relu-dropout or --attention-dropout 126 | self.dropout_frac = dropout 127 | self.dropout = torch.nn.Dropout(p=self.dropout_frac) 128 | self.variant = variant 129 | self.n_segments = n_segments 130 | 131 | self.n_positions = n_positions 132 | self.out_dim = embedding_size 133 | assert ( 134 | embedding_size % n_heads == 0 135 | ), 'Transformer embedding size must be a multiple of n_heads' 136 | 137 | # check input formats: 138 | if embedding is not None: 139 | assert ( 140 | embedding_size is None or embedding_size == 141 | embedding.weight.shape[1] 142 | ), "Embedding dim must match the embedding size." 
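        # the supplied embedding table determines the LM vocabulary; PolyEncoderLM.from_pretrained
        # later ties it to the output head (`self.cls.decoder.weight`)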
143 | 144 | if embedding is not None: 145 | self.embeddings = embedding 146 | else: 147 | raise AssertionError( 148 | "This code should not execute. Left here in case we want to enable it." 149 | ) 150 | 151 | # create the positional embeddings 152 | self.position_embeddings = torch.nn.Embedding(n_positions, embedding_size) 153 | if not learn_positional_embeddings: 154 | create_position_codes(n_positions, embedding_size, out=self.position_embeddings.weight) 155 | else: 156 | torch.nn.init.normal_(self.position_embeddings.weight, 0, embedding_size ** -0.5) 157 | 158 | # embedding normalization 159 | if self.variant == 'xlm' or self.variant == 'prelayernorm': 160 | self.norm_embeddings = LayerNorm(self.dim, eps=LAYER_NORM_EPS) 161 | elif self.variant == 'aiayn': 162 | pass 163 | else: 164 | raise ValueError("Can't handle --variant {}".format(self.variant)) 165 | 166 | if self.n_segments >= 1: 167 | self.segment_embeddings = torch.nn.Embedding(self.n_segments, self.dim) 168 | 169 | # build the model 170 | self.layers = torch.nn.ModuleList() 171 | for _ in range(self.n_layers): 172 | self.layers.append( 173 | TransformerEncoderLayerPast( 174 | n_heads, 175 | embedding_size, 176 | ffn_size, 177 | attention_dropout=attention_dropout, 178 | relu_dropout=relu_dropout, 179 | dropout=dropout, 180 | variant=variant, 181 | activation=activation, 182 | ) 183 | ) 184 | self.output_scaling = output_scaling 185 | 186 | def forward(self, input_ids=None, attention_mask=None, position_ids=None, 187 | segments=None, past=None, inputs_embeds=None): 188 | if past is None: 189 | past_length = 0 190 | past = [None] * len(self.layers) 191 | else: 192 | past_length = past[0][0].size(-2) 193 | 194 | if input_ids is None: 195 | assert inputs_embeds is not None 196 | input_shape = inputs_embeds.size()[:2] 197 | device = inputs_embeds.device 198 | else: 199 | input_shape = input_ids.size() 200 | device = input_ids.device 201 | batch_size, seq_length = input_shape 202 | 203 | if attention_mask is None: 204 | attention_mask = torch.ones((1, seq_length + past_length), device=device) # (bs, seq_length) 205 | seq_ids = torch.arange(past_length + seq_length, device=device) 206 | # add a upper triangle mask for auto-regressive language model 207 | causal_mask = seq_ids[None, None, :].repeat(1, past_length + seq_length, 1) <= seq_ids[None, :, None] 208 | causal_mask = causal_mask.to(attention_mask.dtype) 209 | extended_attention_mask = causal_mask[:, :, :] * attention_mask[:, None, :] 210 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 211 | 212 | if position_ids is None: 213 | position_ids = torch.arange(past_length, past_length + seq_length, dtype=torch.long, device=device) 214 | position_ids = position_ids.unsqueeze(0).expand(input_shape) 215 | 216 | tensor = self.embeddings(input_ids) if inputs_embeds is None else inputs_embeds 217 | if self.embeddings_scale: 218 | tensor = tensor * math.sqrt(self.dim) 219 | 220 | assert position_ids.max().item() <= self.n_positions 221 | position_embs = self.position_embeddings(position_ids) 222 | tensor = tensor + position_embs 223 | 224 | if self.n_segments >= 1: 225 | if segments is None: 226 | segments = torch.zeros_like(position_ids) 227 | tensor = tensor + self.segment_embeddings(segments) 228 | 229 | if self.variant == 'xlm': 230 | tensor = _normalize(tensor, self.norm_embeddings) 231 | 232 | # --dropout on the embeddings 233 | tensor = self.dropout(tensor) 234 | presents = () 235 | for i in range(self.n_layers): 236 | layer_past = past[i] 237 | tensor, 
layer_present = self.layers[i](tensor, extended_attention_mask, layer_past) 238 | presents = presents + (layer_present,) 239 | 240 | if self.variant == 'prelayernorm': 241 | tensor = _normalize(tensor, self.norm_embeddings) 242 | tensor *= self.output_scaling 243 | return tensor, presents 244 | 245 | 246 | class LMPredictionHead(torch.nn.Module): 247 | def __init__(self, opt, vocab_size): 248 | super().__init__() 249 | hidden_size = opt['embedding_size'] 250 | activation = opt['activation'] 251 | self.dense = torch.nn.Linear(hidden_size, hidden_size) 252 | if activation == 'relu': 253 | self.nonlinear = torch.relu 254 | elif activation == 'gelu': 255 | self.nonlinear = gelu 256 | else: 257 | raise ValueError( 258 | "Don't know how to handle --activation {}".format(activation) 259 | ) 260 | self.LayerNorm = LayerNorm(hidden_size, eps=LAYER_NORM_EPS) 261 | self.decoder = torch.nn.Linear(hidden_size, vocab_size) 262 | 263 | def forward(self, hidden_states): 264 | hidden_states = self.dense(hidden_states) 265 | hidden_states = self.nonlinear(hidden_states) 266 | hidden_states = self.LayerNorm(hidden_states) 267 | hidden_states = self.decoder(hidden_states) 268 | return hidden_states 269 | -------------------------------------------------------------------------------- /models/polyencoder/modeling_polyencoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | 5 | from transformers import WEIGHTS_NAME 6 | from parlai.agents.transformer.modules import TransformerEncoder, \ 7 | TransformerMemNetModel, TransformerResponseWrapper, _normalize, get_n_positions_from_options 8 | from parlai.agents.transformer.polyencoder import PolyEncoderModule 9 | from parlai.core.dict import DictionaryAgent 10 | from models.polyencoder.layers import TransformerAREncoder, LMPredictionHead 11 | from models.polyencoder.config_polyencoder import load_poly_encoder_opt, \ 12 | load_bi_encoder_opt, load_pretrained_poly_encoder_opt, load_pretrained_bi_encoder_opt 13 | from collections import OrderedDict 14 | 15 | 16 | class PolyEncoderTransformerEncoder(TransformerEncoder): 17 | def forward(self, input, positions=None, segments=None, inputs_embeds=None): 18 | if input is None: 19 | device = inputs_embeds.device 20 | input_shape = inputs_embeds.size()[:2] 21 | mask = torch.ones(input_shape, dtype=torch.bool).to(device) 22 | else: 23 | input_shape = input.shape 24 | device = input.device 25 | mask = input != self.padding_idx 26 | 27 | if positions is None: 28 | positions = (mask.cumsum(dim=1, dtype=torch.int64) - 1).clamp_(min=0) 29 | 30 | tensor = self.embeddings(input) if inputs_embeds is None else inputs_embeds 31 | if self.embeddings_scale: 32 | tensor = tensor * np.sqrt(self.dim) 33 | 34 | if positions.max().item() > self.n_positions: 35 | raise ValueError( 36 | 'You are inputting a sequence of {x} length, but only have ' 37 | '--n-positions {y}. 
Set --truncate or increase --n-positions'.format( 38 | x=positions.max().item(), y=self.n_positions 39 | ) 40 | ) 41 | position_embs = self.position_embeddings(positions).expand_as(tensor) 42 | tensor = tensor + position_embs 43 | 44 | if self.n_segments >= 1: 45 | if segments is None: 46 | segments = torch.zeros(input_shape, dtype=torch.long, device=device) 47 | tensor = tensor + self.segment_embeddings(segments) 48 | 49 | if self.variant == 'xlm': 50 | tensor = _normalize(tensor, self.norm_embeddings) 51 | 52 | # --dropout on the embeddings 53 | tensor = self.dropout(tensor) 54 | 55 | tensor *= mask.unsqueeze(-1).type_as(tensor) 56 | 57 | if getattr(self.layers, 'is_model_parallel', False): 58 | # factored out for readability. It is equivalent to the other condition 59 | tensor = self._apply_model_parallel(tensor, mask) 60 | else: 61 | for i in range(self.n_layers): 62 | tensor = self.layers[i](tensor, mask) 63 | 64 | if self.variant == 'prelayernorm': 65 | tensor = _normalize(tensor, self.norm_embeddings) 66 | tensor *= self.output_scaling 67 | if self.reduction_type == 'first': 68 | return tensor[:, 0, :] 69 | elif self.reduction_type == 'max': 70 | return tensor.max(dim=1)[0] 71 | elif self.reduction_type == 'mean': 72 | divisor = mask.float().sum(dim=1).unsqueeze(-1).clamp(min=1).type_as( 73 | tensor) 74 | output = tensor.sum(dim=1) / divisor 75 | return output 76 | elif self.reduction_type is None or 'none' in self.reduction_type: 77 | return tensor, mask 78 | else: 79 | raise ValueError("Can't handle --reduction-type {}".format(self.reduction_type)) 80 | 81 | 82 | def _build_encoder( 83 | opt, 84 | dictionary, 85 | embedding=None, 86 | padding_idx=None, 87 | reduction_type='mean', 88 | n_positions=1024, 89 | n_segments=0, 90 | ): 91 | n_layers = ( 92 | opt['n_encoder_layers'] 93 | if opt.get('n_encoder_layers', -1) > 0 94 | else opt['n_layers'] 95 | ) 96 | return PolyEncoderTransformerEncoder( 97 | n_heads=opt['n_heads'], 98 | n_layers=n_layers, 99 | embedding_size=opt['embedding_size'], 100 | ffn_size=opt['ffn_size'], 101 | vocabulary_size=len(dictionary), 102 | embedding=embedding, 103 | dropout=opt['dropout'], 104 | attention_dropout=opt['attention_dropout'], 105 | relu_dropout=opt['relu_dropout'], 106 | padding_idx=padding_idx, 107 | learn_positional_embeddings=opt['learn_positional_embeddings'], 108 | embeddings_scale=opt['embeddings_scale'], 109 | reduction_type=reduction_type, 110 | n_positions=n_positions, 111 | n_segments=n_segments, 112 | activation=opt['activation'], 113 | variant=opt['variant'], 114 | output_scaling=opt['output_scaling'], 115 | ) 116 | 117 | 118 | class PolyEncoderModel(PolyEncoderModule): 119 | def __init__(self, opt, dict_, null_idx): 120 | super(PolyEncoderModel, self).__init__(opt, dict_, null_idx) 121 | self.encoder_ctxt = self.get_encoder( 122 | opt=opt, 123 | dict_=dict_, 124 | null_idx=null_idx, 125 | reduction_type=None, 126 | for_context=True, 127 | ) 128 | self.encoder_cand = self.get_encoder( 129 | opt=opt, 130 | dict_=dict_, 131 | null_idx=null_idx, 132 | reduction_type=opt['reduction_type'], 133 | for_context=False, 134 | ) 135 | 136 | def get_encoder(self, opt, dict_, null_idx, reduction_type, 137 | for_context: bool): 138 | n_positions = get_n_positions_from_options(opt) 139 | embeddings = self._get_embeddings(dict_=dict_, null_idx=null_idx, embedding_size=opt['embedding_size']) 140 | return PolyEncoderTransformerEncoder( 141 | n_heads=opt['n_heads'], 142 | n_layers=opt['n_layers'], 143 | embedding_size=opt['embedding_size'], 144 | 
ffn_size=opt['ffn_size'], 145 | vocabulary_size=len(dict_), 146 | embedding=embeddings, 147 | dropout=opt['dropout'], 148 | attention_dropout=opt['attention_dropout'], 149 | relu_dropout=opt['relu_dropout'], 150 | padding_idx=null_idx, 151 | learn_positional_embeddings=opt['learn_positional_embeddings'], 152 | embeddings_scale=opt['embeddings_scale'], 153 | reduction_type=reduction_type, 154 | n_positions=n_positions, 155 | n_segments=opt.get('n_segments', 2), 156 | activation=opt['activation'], 157 | variant=opt['variant'], 158 | output_scaling=opt['output_scaling'], 159 | ) 160 | 161 | 162 | class PretrainedPolyEncoder(torch.nn.Module): 163 | def __init__(self): 164 | super().__init__() 165 | opt = load_poly_encoder_opt() 166 | d = DictionaryAgent(opt) 167 | self.opt = opt 168 | self.model = PolyEncoderModel(opt, d, d[d.null_token]) 169 | 170 | def get_input_embeddings(self): 171 | return self.model.encoder_cand.embeddings 172 | 173 | def forward(self, ctxt_input_ids=None, ctxt_inputs_embeds=None, 174 | cand_input_ids=None, cand_inputs_embeds=None, 175 | return_scores=True): 176 | outputs = () 177 | if ctxt_input_ids is not None or ctxt_inputs_embeds is not None: 178 | ctxt_hiddens, ctxt_masks = self.model.encoder_ctxt(ctxt_input_ids, inputs_embeds=ctxt_inputs_embeds) 179 | outputs = outputs + (ctxt_hiddens, ctxt_masks,) 180 | 181 | if cand_input_ids is not None or cand_inputs_embeds is not None: 182 | cand_hiddens = self.model.encoder_cand(cand_input_ids, inputs_embeds=cand_inputs_embeds) 183 | outputs = outputs + (cand_hiddens,) 184 | 185 | if return_scores and len(outputs) == 3: 186 | scores = self.score(*outputs) 187 | outputs = (scores,) + outputs 188 | 189 | return outputs 190 | 191 | def score(self, ctxt_hiddens, ctxt_masks, cand_hiddens): 192 | bsz = ctxt_hiddens.size(0) 193 | dim = ctxt_hiddens.size(2) 194 | 195 | if self.model.type == 'codes': 196 | ctxt_rep = self.model.attend( 197 | self.model.code_attention, 198 | queries=self.model.codes.repeat(bsz, 1, 1), 199 | keys=ctxt_hiddens, 200 | values=ctxt_hiddens, 201 | mask=ctxt_masks, 202 | ) 203 | ctxt_rep_mask = ctxt_rep.new_ones(bsz, self.model.n_codes).byte() 204 | elif self.model.type == 'n_first': 205 | # Expand the output if it is not long enough 206 | if ctxt_hiddens.size(1) < self.model.n_codes: 207 | difference = self.model.n_codes - ctxt_hiddens.size(1) 208 | extra_rep = ctxt_hiddens.new_zeros(bsz, difference, dim) 209 | ctxt_rep = torch.cat([ctxt_hiddens, extra_rep], dim=1) 210 | extra_mask = ctxt_masks.new_zeros(bsz, difference) 211 | ctxt_rep_mask = torch.cat([ctxt_masks, extra_mask], dim=1) 212 | else: 213 | ctxt_rep = ctxt_hiddens[:, 0: self.model.n_codes, :] 214 | ctxt_rep_mask = ctxt_masks[:, 0: self.model.n_codes] 215 | else: 216 | raise ValueError(self.model.type) 217 | 218 | if bsz > 1: 219 | cand_hiddens = torch.cat([cand_hiddens.unsqueeze(0)] * bsz, 0) 220 | else: 221 | cand_hiddens = cand_hiddens.unsqueeze(0) 222 | 223 | ctxt_final_rep = self.model.attend(self.model.attention, cand_hiddens, ctxt_rep, ctxt_rep, ctxt_rep_mask) 224 | scores = torch.sum(ctxt_final_rep * cand_hiddens, 2) 225 | return scores 226 | 227 | @classmethod 228 | def from_pretrained(cls): 229 | opt = load_poly_encoder_opt() 230 | model_file = opt['model_file'] 231 | state_dict = torch.load(model_file, map_location='cpu')['model'] 232 | self = cls() 233 | self.model.load_state_dict(state_dict) 234 | return self 235 | 236 | 237 | class BiEncoderModel(TransformerMemNetModel): 238 | def __init__(self, opt, dictionary): 239 | 
super().__init__(opt, dictionary) 240 | n_positions = get_n_positions_from_options(opt) 241 | self.context_encoder = _build_encoder( 242 | opt, 243 | dictionary, 244 | self.embeddings, 245 | self.pad_idx, 246 | reduction_type=self.reduction_type, 247 | n_positions=n_positions, 248 | n_segments=self.n_segments, 249 | ) 250 | 251 | if opt.get('share_encoders'): 252 | self.cand_encoder = TransformerResponseWrapper(self.context_encoder, self.context_encoder.out_dim) 253 | else: 254 | if not self.share_word_embedding: 255 | cand_embeddings = self.cand_embeddings 256 | else: 257 | cand_embeddings = self.embeddings 258 | self.cand_encoder = _build_encoder( 259 | opt, 260 | dictionary, 261 | cand_embeddings, 262 | self.pad_idx, 263 | n_positions=n_positions, 264 | reduction_type=self.reduction_type, 265 | n_segments=self.n_segments, 266 | ) 267 | 268 | # build memory encoder 269 | if opt.get('wrap_memory_encoder', False): 270 | self.memory_transformer = TransformerResponseWrapper(self.context_encoder, self.context_encoder.out_dim) 271 | else: 272 | self.memory_transformer = self.context_encoder 273 | 274 | 275 | class PretrainedBiEncoder(torch.nn.Module): 276 | def __init__(self): 277 | super().__init__() 278 | opt = load_bi_encoder_opt() 279 | d = DictionaryAgent(opt) 280 | self.opt = opt 281 | self.model = BiEncoderModel(opt, d) 282 | 283 | def get_input_embeddings(self): 284 | return self.model.cand_embeddings 285 | 286 | def forward(self, ctxt_input_ids=None, ctxt_inputs_embeds=None, 287 | cand_input_ids=None, cand_inputs_embeds=None, 288 | return_scores=True): 289 | outputs = () 290 | if ctxt_input_ids is not None or ctxt_inputs_embeds is not None: 291 | ctxt_hiddens = self.model.context_encoder(ctxt_input_ids, inputs_embeds=ctxt_inputs_embeds) 292 | outputs = outputs + (ctxt_hiddens,) 293 | 294 | if cand_input_ids is not None or cand_inputs_embeds is not None: 295 | cand_hiddens = self.model.cand_encoder(cand_input_ids, inputs_embeds=cand_inputs_embeds) 296 | outputs = outputs + (cand_hiddens,) 297 | if return_scores and len(outputs) == 2: 298 | scores = self.score(*outputs) 299 | outputs = (scores,) + outputs 300 | return outputs 301 | 302 | def score(self, context_h, cands_h): 303 | # possibly normalize the context and candidate representations 304 | if self.opt['normalize_sent_emb']: 305 | context_h = context_h / context_h.norm(2, dim=1, keepdim=True) 306 | cands_h = cands_h / cands_h.norm(2, dim=1, keepdim=True) 307 | return torch.matmul(context_h, cands_h.t()) 308 | 309 | @classmethod 310 | def from_pretrained(cls): 311 | opt = load_bi_encoder_opt() 312 | model_file = opt['model_file'] 313 | state_dict = torch.load(model_file, map_location='cpu')['model'] 314 | self = cls() 315 | self.model.load_state_dict(state_dict) 316 | return self 317 | 318 | 319 | class PolyEncoderLM(torch.nn.Module): 320 | def __init__(self, opt, encoder_name='encoder_cand'): 321 | super().__init__() 322 | n_positions = get_n_positions_from_options(opt) 323 | d = DictionaryAgent(opt) 324 | e = torch.nn.Embedding(len(d), opt['embedding_size'], d[d.null_token]) 325 | torch.nn.init.normal_(e.weight, mean=0, std=opt['embedding_size'] ** -0.5) 326 | torch.nn.init.constant_(e.weight[d[d.null_token]], 0) 327 | 328 | self.opt = opt 329 | self.vocab_size = len(d) 330 | encoder_cand = TransformerAREncoder( 331 | n_heads=opt['n_heads'], 332 | n_layers=opt['n_layers'], 333 | embedding_size=opt['embedding_size'], 334 | ffn_size=opt['ffn_size'], 335 | embedding=e, 336 | dropout=opt['dropout'], 337 | 
attention_dropout=opt['attention_dropout'], 338 | relu_dropout=opt['relu_dropout'], 339 | padding_idx=d[d.null_token], 340 | learn_positional_embeddings=opt['learn_positional_embeddings'], 341 | embeddings_scale=opt['embeddings_scale'], 342 | n_positions=n_positions, 343 | n_segments=opt.get('n_segments', 2), 344 | activation=opt['activation'], 345 | variant=opt['variant'], 346 | output_scaling=opt['output_scaling'], 347 | ) 348 | self.encoder_name = encoder_name 349 | setattr(self, encoder_name, encoder_cand) 350 | self.cls = LMPredictionHead(opt, len(d)) 351 | 352 | @property 353 | def enc(self): 354 | return getattr(self, self.encoder_name) 355 | 356 | def get_input_embeddings(self): 357 | return self.enc.embeddings 358 | 359 | def save_pretrained(self, save_directory): 360 | assert os.path.isdir(save_directory), "Saving path should be a directory" \ 361 | " where the model and configuration can be saved" 362 | # Only save the model itself if we are using distributed training 363 | # If we save using the predefined names, we can load using `from_pretrained` 364 | output_model_file = os.path.join(save_directory, WEIGHTS_NAME) 365 | torch.save(self.state_dict(), output_model_file) 366 | 367 | @classmethod 368 | def from_pretrained(cls, model_type='poly', checkpoint=None): 369 | if model_type == 'poly': 370 | opt = load_pretrained_poly_encoder_opt() 371 | encoder_name = 'encoder' 372 | else: 373 | raise ValueError(model_type) 374 | 375 | if checkpoint is None: 376 | model_file = opt['model_file'] 377 | state_dict = torch.load(model_file, map_location='cpu')['model'] 378 | else: 379 | model_file = os.path.join(checkpoint, WEIGHTS_NAME) 380 | state_dict = torch.load(model_file, map_location='cpu') 381 | 382 | self = cls(opt, encoder_name) 383 | common_state_dict = OrderedDict() 384 | for key in self.state_dict(): 385 | if key in state_dict: 386 | common_state_dict[key] = state_dict[key] 387 | if not key.startswith('cls') and key not in state_dict: 388 | raise ValueError(f'Weight not found in pretrained model for {key}') 389 | self.load_state_dict(common_state_dict, strict=False) 390 | if hasattr(self, 'cls'): 391 | self.cls.decoder.weight = self.get_input_embeddings().weight 392 | return self 393 | 394 | @staticmethod 395 | def prepare_inputs_for_generation(input_ids, past): 396 | # only last token for inputs_ids if past is defined in kwargs 397 | if past: 398 | input_ids = input_ids[:, -1].unsqueeze(-1) 399 | return {"input_ids": input_ids, "past": past} 400 | 401 | @staticmethod 402 | def enforce_repetition_penalty_(lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty): 403 | for i in range(batch_size * num_beams): 404 | for previous_token in set(prev_output_tokens[i].tolist()): 405 | # if score < 0 then repetition penalty has to 406 | # multiplied to reduce the previous token probability 407 | if lprobs[i, previous_token] < 0: 408 | lprobs[i, previous_token] *= repetition_penalty 409 | else: 410 | lprobs[i, previous_token] /= repetition_penalty 411 | 412 | @staticmethod 413 | def _reorder_cache(past, beam_idx): 414 | reordered_past = [] 415 | for layer_past in past: 416 | # get the correct batch idx from layer past batch dim 417 | # batch dim of `past` and `mems` is at 2nd position 418 | # print(layer_past.shape) 419 | reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx] 420 | reordered_layer_past = torch.cat(reordered_layer_past, dim=1) 421 | # check that shape matches 422 | assert reordered_layer_past.shape == layer_past.shape 423 | 
reordered_past.append(reordered_layer_past) 424 | past = tuple(reordered_past) 425 | return past 426 | 427 | def forward( 428 | self, 429 | input_ids=None, 430 | attention_mask=None, 431 | token_type_ids=None, 432 | position_ids=None, 433 | labels=None, 434 | one_hot_labels=None, 435 | past=None, 436 | inputs_embeds=None, 437 | ): 438 | outputs = self.enc( 439 | input_ids, 440 | attention_mask=attention_mask, 441 | segments=token_type_ids, 442 | position_ids=position_ids, 443 | past=past, 444 | inputs_embeds=inputs_embeds 445 | ) 446 | sequence_output = outputs[0] 447 | if hasattr(self, 'cls'): 448 | prediction_scores = self.cls(sequence_output) 449 | else: 450 | prediction_scores = torch.nn.functional.linear(sequence_output, self.get_input_embeddings().weight) 451 | # Add hidden states and attention if they are here 452 | outputs = (prediction_scores,) + outputs[1:] 453 | 454 | # we are doing next-token prediction; 455 | # shift prediction scores and input ids by one 456 | if one_hot_labels is not None: 457 | prediction_scores = prediction_scores[:, :-1, :].contiguous() 458 | lm_labels = one_hot_labels[:, 1:, :].contiguous() 459 | nll = -torch.log_softmax(prediction_scores, -1) 460 | ltr_lm_loss = torch.sum(nll * lm_labels, -1).mean() 461 | outputs = (ltr_lm_loss,) + outputs 462 | elif labels is not None: 463 | prediction_scores = prediction_scores[:, :-1, :].contiguous() 464 | lm_labels = labels[:, 1:].contiguous() 465 | loss_fct = torch.nn.CrossEntropyLoss() 466 | ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.vocab_size), lm_labels.view(-1)) 467 | outputs = (ltr_lm_loss,) + outputs 468 | return outputs 469 | -------------------------------------------------------------------------------- /models/polyencoder/tokenization_polyencoder.py: -------------------------------------------------------------------------------- 1 | from parlai.core.dict import DictionaryAgent 2 | from transformers import PreTrainedTokenizer 3 | from models.polyencoder.config_polyencoder import load_poly_encoder_opt 4 | 5 | 6 | class PolyEncoderTokenizer(PreTrainedTokenizer): 7 | def __init__(self, **kwargs): 8 | opt = load_poly_encoder_opt() 9 | self.dict = DictionaryAgent(opt) 10 | super().__init__( 11 | unk_token=self.dict.unk_token, 12 | pad_token=self.dict.null_token, 13 | cls_token=self.dict.start_token, 14 | sep_token=self.dict.end_token, **kwargs, 15 | ) 16 | 17 | def get_vocab(self): 18 | return self.dict.tok2ind 19 | 20 | def save_vocabulary(self, save_directory): 21 | pass 22 | 23 | @property 24 | def vocab_size(self): 25 | return len(self.dict.tok2ind) 26 | 27 | def _tokenize(self, text, **kwargs): 28 | return self.dict.tokenize(str(text)) 29 | 30 | def _convert_token_to_id(self, token): 31 | return self.dict[token] 32 | 33 | def _convert_id_to_token(self, index): 34 | return self.dict.ind2tok.get(index, self.unk_token) 35 | 36 | def convert_tokens_to_string(self, tokens): 37 | out_string = self.dict.bpe.decode(tokens, token_ids=[], delimiter=' ') 38 | return out_string 39 | 40 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): 41 | if token_ids_1 is None: 42 | return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] 43 | cls = [self.cls_token_id] 44 | sep = [self.sep_token_id] 45 | return cls + token_ids_0 + sep + token_ids_1 + sep 46 | 47 | @classmethod 48 | def from_pretrained(cls, *inputs, **kwargs): 49 | return cls() 50 | -------------------------------------------------------------------------------- /models/presumm/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/csong27/collision-bert/43eda087bf6d632bdb150d98e934206327f8d082/models/presumm/__init__.py -------------------------------------------------------------------------------- /models/presumm/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from constant import PRESUMM_DIR 4 | 5 | DATA_DIR = os.path.join(PRESUMM_DIR, 'data') 6 | MODEL_DIR = os.path.join(PRESUMM_DIR, 'models') 7 | MODEL_PATH = os.path.join(MODEL_DIR, 'bertext_cnndm_transformer_ckpt.pt') 8 | 9 | model_flags = ['hidden_size', 'ff_size', 'heads', 'emb_size', 'enc_layers', 10 | 'enc_hidden_size', 'enc_ff_size', 11 | 'dec_layers', 'dec_hidden_size', 'dec_ff_size', 'encoder', 12 | 'ff_actv', 'use_interval'] 13 | 14 | 15 | def str2bool(v): 16 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 17 | return True 18 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 19 | return False 20 | else: 21 | raise argparse.ArgumentTypeError('Boolean value expected.') 22 | 23 | 24 | def get_config(): 25 | config = argparse.Namespace( 26 | encoder='bert', 27 | max_pos=512, 28 | bert_max_pos=1024, 29 | ext_dropout=0.2, 30 | ext_layers=2, 31 | ext_hidden_size=768, 32 | ext_heads=8, 33 | ext_ff_size=2048, 34 | use_interval=True, 35 | max_tgt_len=140, 36 | ) 37 | return config 38 | 39 | 40 | def get_args(): 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("-task", default='ext', type=str, choices=['ext', 'abs']) 43 | parser.add_argument("-encoder", default='bert', type=str, choices=['bert', 'baseline']) 44 | parser.add_argument("-mode", default='test', type=str, choices=['train', 'validate', 'test']) 45 | parser.add_argument("-bert_data_path", default=DATA_DIR + '/cnndm') 46 | parser.add_argument("-model_path", default=MODEL_DIR) 47 | parser.add_argument("-result_path", default='results/cnndm') 48 | parser.add_argument("-temp_dir", default='results/') 49 | 50 | parser.add_argument("-batch_size", default=140, type=int) 51 | parser.add_argument("-test_batch_size", default=500, type=int) 52 | parser.add_argument("-max_pos", default=512, type=int) 53 | parser.add_argument("-bert_max_pos", default=1024, type=int) 54 | parser.add_argument("-use_interval", type=str2bool, nargs='?', const=True, default=True) 55 | parser.add_argument("-large", type=str2bool, nargs='?', const=True, default=False) 56 | parser.add_argument("-load_from_extractive", default='', type=str) 57 | 58 | parser.add_argument("-sep_optim", type=str2bool, nargs='?', const=True, default=False) 59 | parser.add_argument("-lr_bert", default=2e-3, type=float) 60 | parser.add_argument("-lr_dec", default=2e-3, type=float) 61 | parser.add_argument("-use_bert_emb", type=str2bool, nargs='?', const=True, default=False) 62 | 63 | parser.add_argument("-share_emb", type=str2bool, nargs='?', const=True, default=False) 64 | parser.add_argument("-finetune_bert", type=str2bool, nargs='?', const=True, default=True) 65 | parser.add_argument("-dec_dropout", default=0.2, type=float) 66 | parser.add_argument("-dec_layers", default=6, type=int) 67 | parser.add_argument("-dec_hidden_size", default=768, type=int) 68 | parser.add_argument("-dec_heads", default=8, type=int) 69 | parser.add_argument("-dec_ff_size", default=2048, type=int) 70 | parser.add_argument("-enc_hidden_size", default=512, type=int) 71 | parser.add_argument("-enc_ff_size", default=512, type=int) 72 | parser.add_argument("-enc_dropout", 
default=0.2, type=float) 73 | parser.add_argument("-enc_layers", default=6, type=int) 74 | 75 | # params for EXT 76 | parser.add_argument("-ext_dropout", default=0.2, type=float) 77 | parser.add_argument("-ext_layers", default=2, type=int) 78 | parser.add_argument("-ext_hidden_size", default=768, type=int) 79 | parser.add_argument("-ext_heads", default=8, type=int) 80 | parser.add_argument("-ext_ff_size", default=2048, type=int) 81 | 82 | parser.add_argument("-label_smoothing", default=0.1, type=float) 83 | parser.add_argument("-generator_shard_size", default=32, type=int) 84 | parser.add_argument("-alpha", default=0.6, type=float) 85 | parser.add_argument("-beam_size", default=5, type=int) 86 | parser.add_argument("-min_length", default=15, type=int) 87 | parser.add_argument("-max_length", default=150, type=int) 88 | parser.add_argument("-max_tgt_len", default=140, type=int) 89 | 90 | parser.add_argument("-param_init", default=0, type=float) 91 | parser.add_argument("-param_init_glorot", type=str2bool, nargs='?', const=True, default=True) 92 | parser.add_argument("-optim", default='adam', type=str) 93 | parser.add_argument("-lr", default=1, type=float) 94 | parser.add_argument("-beta1", default=0.9, type=float) 95 | parser.add_argument("-beta2", default=0.999, type=float) 96 | parser.add_argument("-warmup_steps", default=8000, type=int) 97 | parser.add_argument("-warmup_steps_bert", default=8000, type=int) 98 | parser.add_argument("-warmup_steps_dec", default=8000, type=int) 99 | parser.add_argument("-max_grad_norm", default=0, type=float) 100 | 101 | parser.add_argument("-save_checkpoint_steps", default=5, type=int) 102 | parser.add_argument("-accum_count", default=1, type=int) 103 | parser.add_argument("-report_every", default=1, type=int) 104 | parser.add_argument("-train_steps", default=1000, type=int) 105 | parser.add_argument("-recall_eval", type=str2bool, nargs='?', const=True, default=False) 106 | 107 | parser.add_argument('-visible_gpus', default='-1', type=str) 108 | parser.add_argument('-gpu_ranks', default='0', type=str) 109 | parser.add_argument('-log_file', default='logs/cnndm.log') 110 | parser.add_argument('-seed', default=666, type=int) 111 | 112 | parser.add_argument("-test_all", type=str2bool, nargs='?', const=True, default=False) 113 | parser.add_argument("-test_from", default='') 114 | parser.add_argument("-test_start_from", default=-1, type=int) 115 | 116 | parser.add_argument("-train_from", default='') 117 | parser.add_argument("-report_rouge", type=str2bool, nargs='?', const=True, default=True) 118 | parser.add_argument("-block_trigram", type=str2bool, nargs='?', const=True, default=True) 119 | args = parser.parse_args() 120 | args.world_size = 1 121 | return args 122 | -------------------------------------------------------------------------------- /models/presumm/encoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from models.presumm.neural import MultiHeadedAttention, PositionwiseFeedForward 7 | 8 | 9 | class Classifier(nn.Module): 10 | def __init__(self, hidden_size, ): 11 | super(Classifier, self).__init__() 12 | self.linear1 = nn.Linear(hidden_size, 1) 13 | self.sigmoid = nn.Sigmoid() 14 | 15 | def forward(self, x, mask_cls, output_logits=False): 16 | h = self.linear1(x).squeeze(-1) 17 | if output_logits: 18 | return h 19 | sent_scores = self.sigmoid(h) * mask_cls.float() 20 | return sent_scores 21 | 22 | 23 | class PositionalEncoding(nn.Module): 24 
| 25 | def __init__(self, dropout, dim, max_len=5000): 26 | pe = torch.zeros(max_len, dim) 27 | position = torch.arange(0, max_len).unsqueeze(1) 28 | div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim))) 29 | pe[:, 0::2] = torch.sin(position.float() * div_term) 30 | pe[:, 1::2] = torch.cos(position.float() * div_term) 31 | pe = pe.unsqueeze(0) 32 | super(PositionalEncoding, self).__init__() 33 | self.register_buffer('pe', pe) 34 | self.dropout = nn.Dropout(p=dropout) 35 | self.dim = dim 36 | 37 | def forward(self, emb, step=None): 38 | emb = emb * math.sqrt(self.dim) 39 | if step: 40 | emb = emb + self.pe[:, step][:, None, :] 41 | else: 42 | emb = emb + self.pe[:, :emb.size(1)] 43 | emb = self.dropout(emb) 44 | return emb 45 | 46 | def get_emb(self, emb): 47 | return self.pe[:, :emb.size(1)] 48 | 49 | 50 | class TransformerEncoderLayer(nn.Module): 51 | def __init__(self, d_model, heads, d_ff, dropout): 52 | super(TransformerEncoderLayer, self).__init__() 53 | self.self_attn = MultiHeadedAttention(heads, d_model, dropout=dropout) 54 | self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) 55 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 56 | self.dropout = nn.Dropout(dropout) 57 | 58 | def forward(self, iter, query, inputs, mask): 59 | if iter != 0: 60 | input_norm = self.layer_norm(inputs) 61 | else: 62 | input_norm = inputs 63 | 64 | mask = mask.unsqueeze(1) 65 | context = self.self_attn(input_norm, input_norm, input_norm, mask=mask) 66 | out = self.dropout(context) + inputs 67 | return self.feed_forward(out) 68 | 69 | 70 | class ExtTransformerEncoder(nn.Module): 71 | def __init__(self, d_model, d_ff, heads, dropout, num_inter_layers=0): 72 | super(ExtTransformerEncoder, self).__init__() 73 | self.d_model = d_model 74 | self.num_inter_layers = num_inter_layers 75 | self.pos_emb = PositionalEncoding(dropout, d_model) 76 | self.transformer_inter = nn.ModuleList( 77 | [TransformerEncoderLayer(d_model, heads, d_ff, dropout) 78 | for _ in range(num_inter_layers)]) 79 | self.dropout = nn.Dropout(dropout) 80 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 81 | self.wo = nn.Linear(d_model, 1, bias=True) 82 | self.sigmoid = nn.Sigmoid() 83 | 84 | def forward(self, top_vecs, mask, output_logits=False): 85 | """ See :obj:`EncoderBase.forward()`""" 86 | 87 | batch_size, n_sents = top_vecs.size(0), top_vecs.size(1) 88 | pos_emb = self.pos_emb.pe[:, :n_sents] 89 | x = top_vecs * mask[:, :, None].float() 90 | x = x + pos_emb 91 | 92 | for i in range(self.num_inter_layers): 93 | x = self.transformer_inter[i](i, x, x, 1 - mask) 94 | x = self.layer_norm(x) 95 | if output_logits: 96 | return self.wo(x).squeeze(-1) 97 | 98 | sent_scores = self.sigmoid(self.wo(x)) 99 | sent_scores = sent_scores.squeeze(-1) * mask.float() 100 | return sent_scores 101 | -------------------------------------------------------------------------------- /models/presumm/model_builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import BertModel, BertConfig 4 | from models.presumm.encoder import Classifier, ExtTransformerEncoder 5 | 6 | 7 | def get_generator(vocab_size, dec_hidden_size, device): 8 | gen_func = nn.LogSoftmax(dim=-1) 9 | generator = nn.Sequential( 10 | nn.Linear(dec_hidden_size, vocab_size), 11 | gen_func 12 | ) 13 | generator.to(device) 14 | 15 | return generator 16 | 17 | 18 | class Bert(nn.Module): 19 | def __init__(self): 20 | super(Bert, 
self).__init__() 21 | self.model = BertModel.from_pretrained('bert-base-uncased') 22 | 23 | def forward(self, x, segs, mask, x_embeds=None): 24 | top_vec, _ = self.model(input_ids=x, token_type_ids=segs, attention_mask=mask, inputs_embeds=x_embeds) 25 | return top_vec 26 | 27 | 28 | class ExtSummarizer(nn.Module): 29 | def __init__(self, config, checkpoint=None): 30 | super(ExtSummarizer, self).__init__() 31 | self.bert = Bert() 32 | 33 | self.ext_layer = ExtTransformerEncoder(self.bert.model.config.hidden_size, 34 | config.ext_ff_size, config.ext_heads, 35 | config.ext_dropout, config.ext_layers) 36 | if config.encoder == 'baseline': 37 | bert_config = BertConfig(self.bert.model.config.vocab_size, 38 | hidden_size=config.ext_hidden_size, 39 | num_hidden_layers=config.ext_layers, 40 | num_attention_heads=config.ext_heads, 41 | intermediate_size=config.ext_ff_size) 42 | self.bert.model = BertModel(bert_config) 43 | self.ext_layer = Classifier(self.bert.model.config.hidden_size) 44 | 45 | if checkpoint is not None: 46 | self.load_state_dict(checkpoint, strict=True) 47 | 48 | if config.bert_max_pos > 512: 49 | my_pos_embeddings = nn.Embedding(config.bert_max_pos, self.bert.model.config.hidden_size) 50 | my_pos_embeddings.weight.data[:512] = self.bert.model.embeddings.position_embeddings.weight.data 51 | my_pos_embeddings.weight.data[512:] = \ 52 | self.bert.model.embeddings.position_embeddings.weight.data[-1][None, :].repeat( 53 | config.bert_max_pos - 512, 1) 54 | self.bert.model.embeddings.position_embeddings = my_pos_embeddings 55 | 56 | def forward(self, src, segs, clss, mask_src, mask_cls, src_embeds=None, 57 | output_logits=False): 58 | top_vec = self.bert(src, segs, mask_src, src_embeds) 59 | sents_vec = top_vec[torch.arange(top_vec.size(0)).unsqueeze(1), clss] 60 | sents_vec = sents_vec * mask_cls[:, :, None].float() 61 | sent_scores = self.ext_layer(sents_vec, mask_cls, output_logits).squeeze(-1) 62 | return sent_scores 63 | -------------------------------------------------------------------------------- /models/presumm/neural.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def aeq(*args): 9 | """ 10 | Assert all arguments have the same value 11 | """ 12 | arguments = (arg for arg in args) 13 | first = next(arguments) 14 | assert all(arg == first for arg in arguments), \ 15 | "Not all arguments have the same value: " + str(args) 16 | 17 | 18 | def sequence_mask(lengths, max_len=None): 19 | """ 20 | Creates a boolean mask from sequence lengths. 21 | """ 22 | batch_size = lengths.numel() 23 | max_len = max_len or lengths.max() 24 | return (torch.arange(0, max_len) 25 | .type_as(lengths) 26 | .repeat(batch_size, 1) 27 | .lt(lengths.unsqueeze(1))) 28 | 29 | 30 | def gelu(x): 31 | return 0.5 * x * (1 + torch.tanh( 32 | math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 33 | 34 | 35 | class GlobalAttention(nn.Module): 36 | """ 37 | Global attention takes a matrix and a query vector. It 38 | then computes a parameterized convex combination of the matrix 39 | based on the input query. 40 | 41 | Constructs a unit mapping a query `q` of size `dim` 42 | and a source matrix `H` of size `n x dim`, to an output 43 | of size `dim`. 44 | 45 | 46 | .. 
mermaid:: 47 | 48 | graph BT 49 | A[Query] 50 | subgraph RNN 51 | C[H 1] 52 | D[H 2] 53 | E[H N] 54 | end 55 | F[Attn] 56 | G[Output] 57 | A --> F 58 | C --> F 59 | D --> F 60 | E --> F 61 | C -.-> G 62 | D -.-> G 63 | E -.-> G 64 | F --> G 65 | 66 | All models compute the output as 67 | :math:`c = sum_{j=1}^{SeqLength} a_j H_j` where 68 | :math:`a_j` is the softmax of a score function. 69 | Then then apply a projection layer to [q, c]. 70 | 71 | However they 72 | differ on how they compute the attention score. 73 | 74 | * Luong Attention (dot, general): 75 | * dot: :math:`score(H_j,q) = H_j^T q` 76 | * general: :math:`score(H_j, q) = H_j^T W_a q` 77 | 78 | 79 | * Bahdanau Attention (mlp): 80 | * :math:`score(H_j, q) = v_a^T tanh(W_a q + U_a h_j)` 81 | 82 | 83 | Args: 84 | dim (int): dimensionality of query and key 85 | coverage (bool): use coverage term 86 | attn_type (str): type of attention to use, options [dot,general,mlp] 87 | 88 | """ 89 | 90 | def __init__(self, dim, attn_type="dot"): 91 | super(GlobalAttention, self).__init__() 92 | 93 | self.dim = dim 94 | assert attn_type in ["dot", "general", "mlp"], ( 95 | "Please select a valid attention type.") 96 | self.attn_type = attn_type 97 | 98 | if self.attn_type == "general": 99 | self.linear_in = nn.Linear(dim, dim, bias=False) 100 | elif self.attn_type == "mlp": 101 | self.linear_context = nn.Linear(dim, dim, bias=False) 102 | self.linear_query = nn.Linear(dim, dim, bias=True) 103 | self.v = nn.Linear(dim, 1, bias=False) 104 | # mlp wants it with bias 105 | out_bias = self.attn_type == "mlp" 106 | self.linear_out = nn.Linear(dim * 2, dim, bias=out_bias) 107 | 108 | def score(self, h_t, h_s): 109 | """ 110 | Args: 111 | h_t (`FloatTensor`): sequence of queries `[batch x tgt_len x dim]` 112 | h_s (`FloatTensor`): sequence of sources `[batch x src_len x dim]` 113 | 114 | Returns: 115 | :obj:`FloatTensor`: 116 | raw attention scores (unnormalized) for each src index 117 | `[batch x tgt_len x src_len]` 118 | 119 | """ 120 | 121 | # Check input sizes 122 | src_batch, src_len, src_dim = h_s.size() 123 | tgt_batch, tgt_len, tgt_dim = h_t.size() 124 | 125 | if self.attn_type in ["general", "dot"]: 126 | if self.attn_type == "general": 127 | h_t_ = h_t.view(tgt_batch * tgt_len, tgt_dim) 128 | h_t_ = self.linear_in(h_t_) 129 | h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim) 130 | h_s_ = h_s.transpose(1, 2) 131 | # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len) 132 | return torch.bmm(h_t, h_s_) 133 | else: 134 | dim = self.dim 135 | wq = self.linear_query(h_t.view(-1, dim)) 136 | wq = wq.view(tgt_batch, tgt_len, 1, dim) 137 | wq = wq.expand(tgt_batch, tgt_len, src_len, dim) 138 | 139 | uh = self.linear_context(h_s.contiguous().view(-1, dim)) 140 | uh = uh.view(src_batch, 1, src_len, dim) 141 | uh = uh.expand(src_batch, tgt_len, src_len, dim) 142 | 143 | # (batch, t_len, s_len, d) 144 | wquh = torch.tanh(wq + uh) 145 | 146 | return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len) 147 | 148 | def forward(self, source, memory_bank, memory_lengths=None, 149 | memory_masks=None): 150 | """ 151 | 152 | Args: 153 | source (`FloatTensor`): query vectors `[batch x tgt_len x dim]` 154 | memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]` 155 | memory_lengths (`LongTensor`): the source context lengths `[batch]` 156 | coverage (`FloatTensor`): None (not supported yet) 157 | 158 | Returns: 159 | (`FloatTensor`, `FloatTensor`): 160 | 161 | * Computed vector `[tgt_len x batch x dim]` 162 | * Attention distribtutions 
for each query 163 | `[tgt_len x batch x src_len]` 164 | """ 165 | 166 | # one step input 167 | if source.dim() == 2: 168 | one_step = True 169 | source = source.unsqueeze(1) 170 | else: 171 | one_step = False 172 | 173 | batch, source_l, dim = memory_bank.size() 174 | batch_, target_l, dim_ = source.size() 175 | 176 | # compute attention scores, as in Luong et al. 177 | align = self.score(source, memory_bank) 178 | 179 | if memory_masks is not None: 180 | memory_masks = memory_masks.transpose(0, 1) 181 | memory_masks = memory_masks.transpose(1, 2) 182 | align.masked_fill_(1 - memory_masks.byte(), -float('inf')) 183 | 184 | if memory_lengths is not None: 185 | mask = sequence_mask(memory_lengths, max_len=align.size(-1)) 186 | mask = mask.unsqueeze(1) # Make it broadcastable. 187 | align.masked_fill_(1 - mask, -float('inf')) 188 | 189 | align_vectors = F.softmax(align.view(batch * target_l, source_l), -1) 190 | align_vectors = align_vectors.view(batch, target_l, source_l) 191 | 192 | c = torch.bmm(align_vectors, memory_bank) 193 | 194 | # concatenate 195 | concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2) 196 | attn_h = self.linear_out(concat_c).view(batch, target_l, dim) 197 | if self.attn_type in ["general", "dot"]: 198 | attn_h = torch.tanh(attn_h) 199 | 200 | if one_step: 201 | attn_h = attn_h.squeeze(1) 202 | align_vectors = align_vectors.squeeze(1) 203 | 204 | 205 | else: 206 | attn_h = attn_h.transpose(0, 1).contiguous() 207 | align_vectors = align_vectors.transpose(0, 1).contiguous() 208 | 209 | return attn_h, align_vectors 210 | 211 | 212 | class PositionwiseFeedForward(nn.Module): 213 | """ A two-layer Feed-Forward-Network with residual layer norm. 214 | 215 | Args: 216 | d_model (int): the size of input for the first-layer of the FFN. 217 | d_ff (int): the hidden layer size of the second-layer 218 | of the FNN. 219 | dropout (float): dropout probability in :math:`[0, 1)`. 220 | """ 221 | 222 | def __init__(self, d_model, d_ff, dropout=0.1): 223 | super(PositionwiseFeedForward, self).__init__() 224 | self.w_1 = nn.Linear(d_model, d_ff) 225 | self.w_2 = nn.Linear(d_ff, d_model) 226 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 227 | self.actv = gelu 228 | self.dropout_1 = nn.Dropout(dropout) 229 | self.dropout_2 = nn.Dropout(dropout) 230 | 231 | def forward(self, x): 232 | inter = self.dropout_1(self.actv(self.w_1(self.layer_norm(x)))) 233 | output = self.dropout_2(self.w_2(inter)) 234 | return output + x 235 | 236 | 237 | class MultiHeadedAttention(nn.Module): 238 | """ 239 | Multi-Head Attention module from 240 | "Attention is All You Need" 241 | :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`. 242 | 243 | Similar to standard `dot` attention but uses 244 | multiple attention distributions simulataneously 245 | to select relevant items. 246 | 247 | .. mermaid:: 248 | 249 | graph BT 250 | A[key] 251 | B[value] 252 | C[query] 253 | O[output] 254 | subgraph Attn 255 | D[Attn 1] 256 | E[Attn 2] 257 | F[Attn N] 258 | end 259 | A --> D 260 | C --> D 261 | A --> E 262 | C --> E 263 | A --> F 264 | C --> F 265 | D --> O 266 | E --> O 267 | F --> O 268 | B --> O 269 | 270 | Also includes several additional tricks. 
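    (Here these include an optional `layer_cache` that stores projected
    keys/values for incremental decoding, an optional `predefined_graph_1`
    that re-weights the last attention head, and a `use_final_linear`
    switch that skips the output projection.)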
271 | 272 | Args: 273 | head_count (int): number of parallel heads 274 | model_dim (int): the dimension of keys/values/queries, 275 | must be divisible by head_count 276 | dropout (float): dropout parameter 277 | """ 278 | 279 | def __init__(self, head_count, model_dim, dropout=0.1, use_final_linear=True): 280 | assert model_dim % head_count == 0 281 | self.dim_per_head = model_dim // head_count 282 | self.model_dim = model_dim 283 | 284 | super(MultiHeadedAttention, self).__init__() 285 | self.head_count = head_count 286 | 287 | self.linear_keys = nn.Linear(model_dim, 288 | head_count * self.dim_per_head) 289 | self.linear_values = nn.Linear(model_dim, 290 | head_count * self.dim_per_head) 291 | self.linear_query = nn.Linear(model_dim, 292 | head_count * self.dim_per_head) 293 | self.softmax = nn.Softmax(dim=-1) 294 | self.dropout = nn.Dropout(dropout) 295 | self.use_final_linear = use_final_linear 296 | if (self.use_final_linear): 297 | self.final_linear = nn.Linear(model_dim, model_dim) 298 | 299 | def forward(self, key, value, query, mask=None, 300 | layer_cache=None, type=None, predefined_graph_1=None): 301 | """ 302 | Compute the context vector and the attention vectors. 303 | 304 | Args: 305 | key (`FloatTensor`): set of `key_len` 306 | key vectors `[batch, key_len, dim]` 307 | value (`FloatTensor`): set of `key_len` 308 | value vectors `[batch, key_len, dim]` 309 | query (`FloatTensor`): set of `query_len` 310 | query vectors `[batch, query_len, dim]` 311 | mask: binary mask indicating which keys have 312 | non-zero attention `[batch, query_len, key_len]` 313 | Returns: 314 | (`FloatTensor`, `FloatTensor`) : 315 | 316 | * output context vectors `[batch, query_len, dim]` 317 | * one of the attention vectors `[batch, query_len, key_len]` 318 | """ 319 | 320 | # CHECKS 321 | # batch, k_len, d = key.size() 322 | # batch_, k_len_, d_ = value.size() 323 | # aeq(batch, batch_) 324 | # aeq(k_len, k_len_) 325 | # aeq(d, d_) 326 | # batch_, q_len, d_ = query.size() 327 | # aeq(batch, batch_) 328 | # aeq(d, d_) 329 | # aeq(self.model_dim % 8, 0) 330 | # if mask is not None: 331 | # batch_, q_len_, k_len_ = mask.size() 332 | # aeq(batch_, batch) 333 | # aeq(k_len_, k_len) 334 | # aeq(q_len_ == q_len) 335 | # END CHECKS 336 | 337 | batch_size = key.size(0) 338 | dim_per_head = self.dim_per_head 339 | head_count = self.head_count 340 | key_len = key.size(1) 341 | query_len = query.size(1) 342 | 343 | def shape(x): 344 | """ projection """ 345 | return x.view(batch_size, -1, head_count, dim_per_head) \ 346 | .transpose(1, 2) 347 | 348 | def unshape(x): 349 | """ compute context """ 350 | return x.transpose(1, 2).contiguous() \ 351 | .view(batch_size, -1, head_count * dim_per_head) 352 | 353 | # 1) Project key, value, and query. 
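        #    When a layer_cache is provided (incremental decoding), "self"
        #    attention projects the current query step into q/k/v and appends
        #    the new key/value to the cache, while "context" attention projects
        #    the encoder memory only once and reuses the cached projections on
        #    later steps. Without a cache, key, value and query are all
        #    projected from scratch on every call.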
354 | if layer_cache is not None: 355 | if type == "self": 356 | query, key, value = self.linear_query(query), \ 357 | self.linear_keys(query), \ 358 | self.linear_values(query) 359 | 360 | key = shape(key) 361 | value = shape(value) 362 | 363 | if layer_cache is not None: 364 | device = key.device 365 | if layer_cache["self_keys"] is not None: 366 | key = torch.cat( 367 | (layer_cache["self_keys"].to(device), key), 368 | dim=2) 369 | if layer_cache["self_values"] is not None: 370 | value = torch.cat( 371 | (layer_cache["self_values"].to(device), value), 372 | dim=2) 373 | layer_cache["self_keys"] = key 374 | layer_cache["self_values"] = value 375 | elif type == "context": 376 | query = self.linear_query(query) 377 | if layer_cache is not None: 378 | if layer_cache["memory_keys"] is None: 379 | key, value = self.linear_keys(key), \ 380 | self.linear_values(value) 381 | key = shape(key) 382 | value = shape(value) 383 | else: 384 | key, value = layer_cache["memory_keys"], \ 385 | layer_cache["memory_values"] 386 | layer_cache["memory_keys"] = key 387 | layer_cache["memory_values"] = value 388 | else: 389 | key, value = self.linear_keys(key), \ 390 | self.linear_values(value) 391 | key = shape(key) 392 | value = shape(value) 393 | else: 394 | key = self.linear_keys(key) 395 | value = self.linear_values(value) 396 | query = self.linear_query(query) 397 | key = shape(key) 398 | value = shape(value) 399 | 400 | query = shape(query) 401 | 402 | key_len = key.size(2) 403 | query_len = query.size(2) 404 | 405 | # 2) Calculate and scale scores. 406 | query = query / math.sqrt(dim_per_head) 407 | scores = torch.matmul(query, key.transpose(2, 3)) 408 | 409 | if mask is not None: 410 | mask = mask.unsqueeze(1).expand_as(scores) 411 | scores = scores.masked_fill(mask.bool(), -10000.0) 412 | 413 | # 3) Apply attention dropout and compute context vectors. 414 | 415 | attn = self.softmax(scores) 416 | 417 | if (not predefined_graph_1 is None): 418 | attn_masked = attn[:, -1] * predefined_graph_1 419 | attn_masked = attn_masked / ( 420 | torch.sum(attn_masked, 2).unsqueeze(2) + 1e-9) 421 | 422 | attn = torch.cat([attn[:, :-1], attn_masked.unsqueeze(1)], 1) 423 | 424 | drop_attn = self.dropout(attn) 425 | if (self.use_final_linear): 426 | context = unshape(torch.matmul(drop_attn, value)) 427 | output = self.final_linear(context) 428 | return output 429 | else: 430 | context = torch.matmul(drop_attn, value) 431 | return context 432 | 433 | # CHECK 434 | # batch_, q_len_, d_ = output.size() 435 | # aeq(q_len, q_len_) 436 | # aeq(batch, batch_) 437 | # aeq(d, d_) 438 | 439 | # Return one attn 440 | 441 | 442 | class DecoderState(object): 443 | """Interface for grouping together the current state of a recurrent 444 | decoder. In the simplest case just represents the hidden state of 445 | the model. But can also be used for implementing various forms of 446 | input_feeding and non-recurrent models. 447 | 448 | Modules need to implement this to utilize beam search decoding. 
449 | """ 450 | 451 | def detach(self): 452 | """ Need to document this """ 453 | self.hidden = tuple([_.detach() for _ in self.hidden]) 454 | self.input_feed = self.input_feed.detach() 455 | 456 | def beam_update(self, idx, positions, beam_size): 457 | """ Need to document this """ 458 | for e in self._all: 459 | sizes = e.size() 460 | br = sizes[1] 461 | if len(sizes) == 3: 462 | sent_states = e.view(sizes[0], beam_size, br // beam_size, 463 | sizes[2])[:, :, idx] 464 | else: 465 | sent_states = e.view(sizes[0], beam_size, 466 | br // beam_size, 467 | sizes[2], 468 | sizes[3])[:, :, idx] 469 | 470 | sent_states.data.copy_( 471 | sent_states.data.index_select(1, positions)) 472 | 473 | def map_batch_fn(self, fn): 474 | raise NotImplementedError() 475 | -------------------------------------------------------------------------------- /models/scorer.py: -------------------------------------------------------------------------------- 1 | from transformers import GPT2LMHeadModel, GPT2Tokenizer 2 | import torch 3 | 4 | 5 | class GPT2PerpModel(GPT2LMHeadModel): 6 | def forward( 7 | self, 8 | input_ids=None, 9 | past=None, 10 | attention_mask=None, 11 | token_type_ids=None, 12 | position_ids=None, 13 | head_mask=None, 14 | inputs_embeds=None, 15 | labels=None, 16 | ): 17 | transformer_outputs = self.transformer( 18 | input_ids, 19 | past=past, 20 | attention_mask=attention_mask, 21 | token_type_ids=token_type_ids, 22 | position_ids=position_ids, 23 | head_mask=head_mask, 24 | inputs_embeds=inputs_embeds, 25 | ) 26 | hidden_states = transformer_outputs[0] 27 | 28 | lm_logits = self.lm_head(hidden_states) 29 | 30 | outputs = (lm_logits,) + transformer_outputs[1:] 31 | if labels is not None: 32 | # Shift so that tokens < n predict n 33 | shift_logits = lm_logits[..., :-1, :].contiguous() 34 | shift_labels = labels[..., 1:].contiguous() 35 | shift_masks = attention_mask[..., 1:].contiguous() 36 | # Flatten the tokens 37 | loss_fct = torch.nn.CrossEntropyLoss(reduction='none') 38 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 39 | loss = loss.view(shift_labels.shape[0], -1) * shift_masks 40 | loss = torch.sum(loss, -1) / torch.sum(shift_masks, -1) 41 | outputs = (loss,) + outputs 42 | return outputs # (loss), lm_logits, presents, (all hidden_states), (attentions) 43 | 44 | 45 | class SentenceScorer(object): 46 | def __init__(self, device): 47 | self.lm_model = GPT2PerpModel.from_pretrained('gpt2') 48 | self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2') 49 | self.tokenizer.pad_token = self.tokenizer.eos_token 50 | self.lm_model.eval() 51 | self.lm_model.to(device) 52 | self.device = device 53 | for param in self.lm_model.parameters(): 54 | param.requires_grad = False 55 | 56 | def perplexity(self, inputs): 57 | if isinstance(inputs, str): 58 | inputs = [inputs] 59 | inputs = self.tokenizer.batch_encode_plus(inputs, pad_to_max_length=True) 60 | attention_mask = torch.tensor(inputs['attention_mask'], device=self.device) 61 | inputs = torch.tensor(inputs['input_ids'], device=self.device) 62 | loss = torch.scalar_tensor(20.0) 63 | if inputs.shape[1] > 1: 64 | loss = self.lm_model(inputs, attention_mask=attention_mask, labels=inputs)[0].squeeze() 65 | return torch.exp(loss) 66 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | apex 2 | numpy 3 | tqdm 4 | nltk 5 | torch==1.4.0 6 | transformers==2.8.0 7 | 
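`models/scorer.py` above provides a GPT-2 based `SentenceScorer` whose `perplexity()` method returns one perplexity value per input sentence (lower means more fluent text). Below is a minimal usage sketch, not part of the repo: it assumes the repository root is on `PYTHONPATH`, a working `transformers==2.8.0` install, and that the `gpt2` weights can be downloaded; the example sentences are made up.

```
# Usage sketch for models/scorer.py: score a couple of illustrative sentences
# with GPT-2 perplexity.
import torch
from models.scorer import SentenceScorer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
scorer = SentenceScorer(device)

# perplexity() accepts a single string or a list of strings and returns a
# tensor with one perplexity per input (inputs of one token or fewer fall
# back to a constant loss of 20 inside the class).
ppl = scorer.perplexity(['the cat sat on the mat', 'mat the on sat cat the'])
print(ppl)
```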
-------------------------------------------------------------------------------- /scripts/ft_bert_lm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import logging 4 | import os 5 | import pickle 6 | import random 7 | import re 8 | import shutil 9 | from typing import Dict, List, Tuple 10 | 11 | import numpy as np 12 | import torch 13 | from torch.nn.utils.rnn import pad_sequence 14 | from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler 15 | from constant import BOS_TOKEN, EOS_TOKEN 16 | from models.bert_models import BertForLM 17 | from tqdm import tqdm, trange 18 | from transformers import ( 19 | AdamW, 20 | BertConfig, 21 | BertTokenizer, 22 | get_linear_schedule_with_warmup, 23 | ) 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | class TextDataset(Dataset): 29 | def __init__(self, tokenizer, args, file_path: str, 30 | block_size=512): 31 | assert os.path.isfile(file_path) 32 | 33 | block_size = block_size - (tokenizer.max_len - tokenizer.max_len_single_sentence) 34 | 35 | directory, filename = os.path.split(file_path) 36 | cached_features_file = os.path.join( 37 | directory, "bert_cached_lm_" + str(block_size) + "_" + filename) 38 | 39 | if os.path.exists(cached_features_file) and not args.overwrite_cache: 40 | logger.info("Loading features from cached file %s", cached_features_file) 41 | with open(cached_features_file, "rb") as handle: 42 | self.examples = pickle.load(handle) 43 | else: 44 | logger.info("Creating features from dataset file at %s", directory) 45 | 46 | self.examples = [] 47 | with open(file_path, encoding="utf-8") as f: 48 | text = f.read() 49 | 50 | bos_token_id = tokenizer.vocab[BOS_TOKEN] 51 | eos_token_id = tokenizer.vocab[EOS_TOKEN] 52 | logger.info("BOS %d, EOS %d", bos_token_id, eos_token_id) 53 | tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)) 54 | # Truncate in block of block_size 55 | for i in range(0, len(tokenized_text) - block_size + 1, block_size): 56 | self.examples.append([bos_token_id] + 57 | tokenized_text[i: i + block_size] + 58 | [eos_token_id]) 59 | 60 | logger.info("Saving features into cached file %s", cached_features_file) 61 | with open(cached_features_file, "wb") as handle: 62 | pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL) 63 | 64 | def __len__(self): 65 | return len(self.examples) 66 | 67 | def __getitem__(self, item): 68 | return torch.tensor(self.examples[item], dtype=torch.long) 69 | 70 | 71 | def load_and_cache_examples(args, tokenizer, evaluating=False): 72 | file_path = args.eval_data_file if evaluating else args.train_data_file 73 | return TextDataset(tokenizer, args, file_path=file_path, 74 | block_size=args.block_size) 75 | 76 | 77 | def set_seed(args): 78 | random.seed(args.seed) 79 | np.random.seed(args.seed) 80 | torch.manual_seed(args.seed) 81 | 82 | 83 | def _sorted_checkpoints(args, checkpoint_prefix="checkpoint", 84 | use_mtime=False) -> List[str]: 85 | ordering_and_checkpoint_path = [] 86 | 87 | glob_checkpoints = glob.glob( 88 | os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix))) 89 | 90 | for path in glob_checkpoints: 91 | if use_mtime: 92 | ordering_and_checkpoint_path.append((os.path.getmtime(path), path)) 93 | else: 94 | regex_match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path) 95 | if regex_match and regex_match.groups(): 96 | ordering_and_checkpoint_path.append( 97 | (int(regex_match.groups()[0]), path)) 98 | 99 | 
checkpoints_sorted = sorted(ordering_and_checkpoint_path) 100 | checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted] 101 | return checkpoints_sorted 102 | 103 | 104 | def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", 105 | use_mtime=False) -> None: 106 | if not args.save_total_limit: 107 | return 108 | if args.save_total_limit <= 0: 109 | return 110 | 111 | # Check if we should delete older checkpoint(s) 112 | checkpoints_sorted = _sorted_checkpoints(args, checkpoint_prefix, use_mtime) 113 | if len(checkpoints_sorted) <= args.save_total_limit: 114 | return 115 | 116 | number_of_checkpoints_to_delete = max(0, len( 117 | checkpoints_sorted) - args.save_total_limit) 118 | checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete] 119 | for checkpoint in checkpoints_to_be_deleted: 120 | logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint)) 121 | shutil.rmtree(checkpoint) 122 | 123 | 124 | def train(args, train_dataset, model, tokenizer) -> Tuple[int, float]: 125 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 126 | 127 | def collate(examples: List[torch.Tensor]): 128 | if tokenizer.pad_token is None: 129 | return pad_sequence(examples, batch_first=True) 130 | return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id) 131 | 132 | train_sampler = RandomSampler(train_dataset) 133 | train_dataloader = DataLoader( 134 | train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, 135 | collate_fn=collate 136 | ) 137 | 138 | if args.max_steps > 0: 139 | t_total = args.max_steps 140 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 141 | else: 142 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 143 | 144 | # Prepare optimizer and schedule (linear warmup and decay) 145 | no_decay = ["bias", "LayerNorm.weight"] 146 | optimizer_grouped_parameters = [ 147 | { 148 | "params": [p for n, p in model.named_parameters() if 149 | not any(nd in n for nd in no_decay) and p.requires_grad], 150 | "weight_decay": args.weight_decay, 151 | }, 152 | {"params": [p for n, p in model.named_parameters() if 153 | any(nd in n for nd in no_decay) and p.requires_grad], 154 | "weight_decay": 0.0 155 | }, 156 | ] 157 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 158 | scheduler = get_linear_schedule_with_warmup( 159 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total 160 | ) 161 | 162 | # Train! 
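    # Note: t_total above counts optimizer updates (batches divided by
    # gradient_accumulation_steps); the linear schedule warms the learning
    # rate up for args.warmup_steps updates and then decays it to zero at
    # t_total.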
163 | logger.info("***** Running training *****") 164 | logger.info(" Num examples = %d", len(train_dataset)) 165 | logger.info(" Num Epochs = %d", args.num_train_epochs) 166 | logger.info(" Instantaneous batch size per GPU = %d", 167 | args.per_gpu_train_batch_size) 168 | logger.info(" Gradient Accumulation steps = %d", 169 | args.gradient_accumulation_steps) 170 | logger.info(" Total optimization steps = %d", t_total) 171 | 172 | global_step = 0 173 | epochs_trained = 0 174 | steps_trained_in_current_epoch = 0 175 | # Check if continuing training from a checkpoint 176 | if args.model_name_or_path and os.path.exists(args.model_name_or_path): 177 | try: 178 | # set global_step to gobal_step of last saved checkpoint from model path 179 | checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0] 180 | global_step = int(checkpoint_suffix) 181 | epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps) 182 | steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps) 183 | 184 | logger.info(" Continuing training from checkpoint, will skip to saved global_step") 185 | logger.info(" Continuing training from epoch %d", epochs_trained) 186 | logger.info(" Continuing training from global step %d", global_step) 187 | logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) 188 | except ValueError: 189 | logger.info(" Starting fine-tuning.") 190 | 191 | tr_loss, logging_loss = 0.0, 0.0 192 | 193 | model_to_resize = model.module if hasattr(model, "module") else model 194 | model_to_resize.resize_token_embeddings(len(tokenizer)) 195 | 196 | model.zero_grad() 197 | train_iterator = trange( 198 | epochs_trained, int(args.num_train_epochs), desc="Epoch") 199 | set_seed(args) # Added here for reproducibility 200 | for _ in train_iterator: 201 | epoch_iterator = tqdm(train_dataloader, desc="Iteration") 202 | for step, batch in enumerate(epoch_iterator): 203 | 204 | # Skip past any already trained steps if resuming training 205 | if steps_trained_in_current_epoch > 0: 206 | steps_trained_in_current_epoch -= 1 207 | continue 208 | 209 | inputs, labels = (batch, batch) 210 | inputs = inputs.to(args.device) 211 | labels = labels.to(args.device) 212 | model.train() 213 | outputs = model(inputs, labels=labels) 214 | loss = outputs[0] 215 | 216 | if args.n_gpu > 1: 217 | loss = loss.mean() # mean() to average on multi-gpu parallel training 218 | if args.gradient_accumulation_steps > 1: 219 | loss = loss / args.gradient_accumulation_steps 220 | loss.backward() 221 | 222 | tr_loss += loss.item() 223 | if (step + 1) % args.gradient_accumulation_steps == 0: 224 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 225 | optimizer.step() 226 | scheduler.step() # Update learning rate schedule 227 | model.zero_grad() 228 | global_step += 1 229 | 230 | if args.logging_steps > 0 and global_step % args.logging_steps == 0: 231 | # Log metrics 232 | perplexity = evaluate(args, model, tokenizer)['perplexity'] 233 | logging_loss = tr_loss / global_step 234 | logger.info(f'Step={global_step}, train loss={logging_loss:.4f}, eval perplexity={perplexity:.4f}') 235 | 236 | if 0 < args.max_steps < global_step: 237 | epoch_iterator.close() 238 | break 239 | 240 | if args.save_steps > 0: 241 | checkpoint_prefix = "checkpoint" 242 | # Save model checkpoint 243 | output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step)) 244 | os.makedirs(output_dir, 
exist_ok=True) 245 | model_to_save = ( 246 | model.module if hasattr(model, "module") else model 247 | ) # Take care of distributed/parallel training 248 | model_to_save.save_pretrained(output_dir) 249 | tokenizer.save_pretrained(output_dir) 250 | 251 | torch.save(args, os.path.join(output_dir, "training_args.bin")) 252 | logger.info("Saving model checkpoint to %s", output_dir) 253 | 254 | _rotate_checkpoints(args, checkpoint_prefix) 255 | 256 | if 0 < args.max_steps < global_step: 257 | train_iterator.close() 258 | break 259 | 260 | return global_step, tr_loss / global_step 261 | 262 | 263 | def evaluate(args, model, tokenizer: BertTokenizer, prefix="") -> Dict: 264 | # Loop to handle MNLI double evaluation (matched, mis-matched) 265 | eval_output_dir = args.output_dir 266 | 267 | eval_dataset = load_and_cache_examples(args, tokenizer, evaluating=True) 268 | 269 | os.makedirs(eval_output_dir, exist_ok=True) 270 | 271 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 272 | 273 | def collate(examples: List[torch.Tensor]): 274 | if tokenizer.pad_token is None: 275 | return pad_sequence(examples, batch_first=True) 276 | return pad_sequence(examples, batch_first=True, 277 | padding_value=tokenizer.pad_token_id) 278 | 279 | eval_sampler = SequentialSampler(eval_dataset) 280 | eval_dataloader = DataLoader( 281 | eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, 282 | collate_fn=collate 283 | ) 284 | 285 | # Eval! 286 | logger.info("***** Running evaluation {} *****".format(prefix)) 287 | logger.info(" Num examples = %d", len(eval_dataset)) 288 | logger.info(" Batch size = %d", args.eval_batch_size) 289 | eval_loss = 0.0 290 | nb_eval_steps = 0 291 | model.eval() 292 | 293 | for batch in tqdm(eval_dataloader, desc="Evaluating"): 294 | inputs, labels = (batch, batch) 295 | inputs = inputs.to(args.device) 296 | labels = labels.to(args.device) 297 | 298 | with torch.no_grad(): 299 | outputs = model(inputs, labels=labels) 300 | lm_loss = outputs[0] 301 | eval_loss += lm_loss.mean().item() 302 | nb_eval_steps += 1 303 | 304 | eval_loss = eval_loss / nb_eval_steps 305 | perplexity = torch.exp(torch.tensor(eval_loss)) 306 | 307 | result = {"perplexity": perplexity} 308 | return result 309 | 310 | 311 | def main(): 312 | parser = argparse.ArgumentParser() 313 | 314 | # Required parameters 315 | parser.add_argument( 316 | "--train_data_file", default=None, type=str, 317 | help="The input training data file (a text file)." 318 | ) 319 | parser.add_argument( 320 | "--output_dir", 321 | type=str, 322 | required=True, 323 | help="The output directory checkpoints will be written.", 324 | ) 325 | # Other parameters 326 | parser.add_argument( 327 | "--eval_data_file", 328 | default=None, 329 | type=str, 330 | help="An optional input evaluation data file to evaluate" 331 | " the perplexity on (a text file).", 332 | ) 333 | parser.add_argument( 334 | "--should_continue", action="store_true", 335 | help="Whether to continue from latest checkpoint in output_dir" 336 | ) 337 | parser.add_argument( 338 | "--model_name_or_path", 339 | default=None, 340 | type=str, 341 | required=True, 342 | help="The model checkpoint for weights initialization. 
" 343 | "Leave None if you want to train a model from scratch.", 344 | ) 345 | parser.add_argument( 346 | "--model_type", 347 | default=None, 348 | type=str, 349 | required=True, 350 | help="Bert or DistilBert", 351 | ) 352 | 353 | parser.add_argument( 354 | "--config_name", 355 | default=None, 356 | type=str, 357 | help="Optional pretrained config name or path if not the same as " 358 | "model_name_or_path. If both are None, initialize a new config.", 359 | ) 360 | parser.add_argument( 361 | "--tokenizer_name", 362 | default=None, 363 | type=str, 364 | help="Optional pretrained tokenizer name or path if not the same as" 365 | " model_name_or_path. If both are None, initialize a new tokenizer.", 366 | ) 367 | parser.add_argument( 368 | "--cache_dir", 369 | default=None, 370 | type=str, 371 | help="Optional directory to store the pre-trained models" 372 | " downloaded from s3 (instead of the default one)", 373 | ) 374 | parser.add_argument( 375 | "--block_size", 376 | default=-1, 377 | type=int, 378 | help="Optional input sequence length after tokenization." 379 | "The training dataset will be truncated in block of" 380 | " this size for training." 381 | "Default to the model max input length for single sentence" 382 | " inputs (take into account special tokens).", 383 | ) 384 | parser.add_argument("--freeze", action="store_true", 385 | help="Whether to freeze the transformer weights.") 386 | parser.add_argument("--do_train", action="store_true", 387 | help="Whether to run training.") 388 | parser.add_argument("--do_eval", action="store_true", 389 | help="Whether to run eval on the dev set.") 390 | parser.add_argument("--per_gpu_train_batch_size", default=4, type=int, 391 | help="Batch size per GPU/CPU for training.") 392 | parser.add_argument( 393 | "--per_gpu_eval_batch_size", default=4, type=int, 394 | help="Batch size per GPU/CPU for evaluation." 395 | ) 396 | parser.add_argument( 397 | "--gradient_accumulation_steps", 398 | type=int, 399 | default=1, 400 | help="Number of updates steps to accumulate " 401 | "before performing a backward/update pass.", 402 | ) 403 | parser.add_argument("--learning_rate", default=2e-5, type=float, 404 | help="The initial learning rate for Adam.") 405 | parser.add_argument("--weight_decay", default=0.0, type=float, 406 | help="Weight decay if we apply some.") 407 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 408 | help="Epsilon for Adam optimizer.") 409 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 410 | help="Max gradient norm.") 411 | parser.add_argument( 412 | "--num_train_epochs", default=1.0, type=float, 413 | help="Total number of training epochs to perform." 414 | ) 415 | parser.add_argument( 416 | "--max_steps", 417 | default=-1, 418 | type=int, 419 | help="If > 0: set total number of training steps " 420 | "to perform. 
Override num_train_epochs.", 421 | ) 422 | parser.add_argument("--warmup_steps", default=0, type=int, 423 | help="Linear warmup over warmup_steps.") 424 | parser.add_argument("--logging_steps", type=int, default=1000, 425 | help="Log every X updates steps.") 426 | parser.add_argument("--save_steps", type=int, default=1, 427 | help="Save checkpoint every X updates steps.") 428 | parser.add_argument( 429 | "--save_total_limit", 430 | type=int, 431 | default=5, 432 | help="Limit the total amount of checkpoints, delete the older checkpoints" 433 | " in the output_dir, does not delete by default", 434 | ) 435 | parser.add_argument("--no_cuda", action="store_true", 436 | help="Avoid using CUDA when available") 437 | parser.add_argument( 438 | "--overwrite_output_dir", action="store_true", 439 | help="Overwrite the content of the output directory" 440 | ) 441 | parser.add_argument( 442 | "--overwrite_cache", action="store_true", 443 | help="Overwrite the cached training and evaluation sets" 444 | ) 445 | parser.add_argument("--seed", type=int, default=42, 446 | help="random seed for initialization") 447 | 448 | args = parser.parse_args() 449 | args.n_gpu = 1 450 | 451 | # Setup CUDA, GPU & distributed training 452 | device = torch.device( 453 | "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 454 | args.device = device 455 | 456 | # Setup logging 457 | logging.basicConfig( 458 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 459 | datefmt="%m/%d/%Y %H:%M:%S", 460 | level=logging.INFO, 461 | ) 462 | 463 | # Set seed 464 | set_seed(args) 465 | config_class, tokenizer_class, model_class = BertConfig, BertTokenizer, BertForLM 466 | 467 | if args.config_name: 468 | config = config_class.from_pretrained(args.config_name, cache_dir=args.cache_dir) 469 | elif args.model_name_or_path: 470 | config = config_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) 471 | else: 472 | config = config_class() 473 | 474 | if args.tokenizer_name: 475 | tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir) 476 | elif args.model_name_or_path: 477 | tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) 478 | else: 479 | raise ValueError("You are instantiating a new {} tokenizer.".format(tokenizer_class.__name__)) 480 | 481 | if args.block_size <= 0: 482 | args.block_size = tokenizer.max_len 483 | # Our input block size will be the max possible for the model 484 | else: 485 | args.block_size = min(args.block_size, tokenizer.max_len) 486 | 487 | if args.model_name_or_path: 488 | model = model_class.from_pretrained( 489 | args.model_name_or_path, 490 | from_tf=bool(".ckpt" in args.model_name_or_path), 491 | config=config, 492 | cache_dir=args.cache_dir, 493 | ) 494 | else: 495 | logger.info("Training new model from scratch") 496 | model = model_class(config=config) 497 | 498 | model.to(args.device) 499 | if args.freeze: 500 | logger.info("Freezing bert weights") 501 | for param in model.bert.parameters(): 502 | param.requires_grad = False 503 | 504 | logger.info("Training/evaluation parameters %s", args) 505 | 506 | # Training 507 | if args.do_train: 508 | train_dataset = load_and_cache_examples(args, tokenizer, evaluating=False) 509 | global_step, tr_loss = train(args, train_dataset, model, tokenizer) 510 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) 511 | 512 | # Saving best-practices: if you use save_pretrained for the 513 | # model and tokenizer, you can 
reload them using from_pretrained() 514 | if args.do_train: 515 | os.makedirs(args.output_dir, exist_ok=True) 516 | 517 | logger.info("Saving model checkpoint to %s", args.output_dir) 518 | # Save a trained model, configuration and tokenizer using 519 | # `save_pretrained()`. They can then be reloaded using `from_pretrained()` 520 | model_to_save = ( 521 | model.module if hasattr(model, "module") else model 522 | ) # Take care of distributed/parallel training 523 | model_to_save.save_pretrained(args.output_dir) 524 | tokenizer.save_pretrained(args.output_dir) 525 | 526 | # Good practice: save your training arguments 527 | # together with the trained model 528 | torch.save(args, os.path.join(args.output_dir, "training_args.bin")) 529 | 530 | # Evaluation 531 | results = {} 532 | if args.do_eval: 533 | checkpoints = [args.output_dir] 534 | logger.info("Evaluate the following checkpoints: %s", checkpoints) 535 | tokenizer = tokenizer_class.from_pretrained(args.output_dir) 536 | 537 | for checkpoint in checkpoints: 538 | global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else "" 539 | prefix = checkpoint.split("/")[-1] if checkpoint.find( 540 | "checkpoint") != -1 else "" 541 | 542 | model = model_class.from_pretrained(checkpoint) 543 | model.to(args.device) 544 | result = evaluate(args, model, tokenizer, prefix=prefix) 545 | logger.info((global_step, result)) 546 | 547 | return results 548 | 549 | 550 | if __name__ == "__main__": 551 | main() 552 | -------------------------------------------------------------------------------- /scripts/ft_polyencoder_lm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import logging 4 | import os 5 | import pickle 6 | import random 7 | import re 8 | import shutil 9 | from typing import Dict, List, Tuple 10 | 11 | import numpy as np 12 | import torch 13 | from torch.nn.utils.rnn import pad_sequence 14 | from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler 15 | from models.polyencoder import PolyEncoderTokenizer, PolyEncoderLM 16 | from tqdm import tqdm, trange 17 | from transformers import ( 18 | AdamW, 19 | get_linear_schedule_with_warmup, 20 | ) 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class TextDataset(Dataset): 26 | def __init__(self, tokenizer: PolyEncoderTokenizer, args, file_path: str, block_size=512): 27 | assert os.path.isfile(file_path) 28 | 29 | block_size = block_size - 2 30 | directory, filename = os.path.split(file_path) 31 | cached_features_file = os.path.join(directory, "polyencoder_cached_lm_" + str(block_size) + "_" + filename) 32 | 33 | if os.path.exists(cached_features_file) and not args.overwrite_cache: 34 | logger.info("Loading features from cached file %s", cached_features_file) 35 | with open(cached_features_file, "rb") as handle: 36 | self.examples = pickle.load(handle) 37 | else: 38 | logger.info("Creating features from dataset file at %s", directory) 39 | 40 | self.examples = [] 41 | with open(file_path, encoding="utf-8") as f: 42 | text = f.read() 43 | 44 | bos_token_id = tokenizer.dict[tokenizer.dict.start_token] 45 | eos_token_id = tokenizer.dict[tokenizer.dict.end_token] 46 | 47 | logger.info("BOS %d, EOS %d", bos_token_id, eos_token_id) 48 | tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)) 49 | # Truncate in block of block_size 50 | for i in range(0, len(tokenized_text) - block_size + 1, block_size): 51 | self.examples.append([bos_token_id] + 52 | tokenized_text[i: 
i + block_size] + 53 | [eos_token_id]) 54 | 55 | logger.info("Saving features into cached file %s", cached_features_file) 56 | with open(cached_features_file, "wb") as handle: 57 | pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL) 58 | 59 | def __len__(self): 60 | return len(self.examples) 61 | 62 | def __getitem__(self, item): 63 | return torch.tensor(self.examples[item], dtype=torch.long) 64 | 65 | 66 | def load_and_cache_examples(args, tokenizer, evaluating=False): 67 | file_path = args.eval_data_file if evaluating else args.train_data_file 68 | return TextDataset(tokenizer, args, file_path=file_path, block_size=args.block_size) 69 | 70 | 71 | def set_seed(args): 72 | random.seed(args.seed) 73 | np.random.seed(args.seed) 74 | torch.manual_seed(args.seed) 75 | 76 | 77 | def _sorted_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> List[str]: 78 | ordering_and_checkpoint_path = [] 79 | 80 | glob_checkpoints = glob.glob( 81 | os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix))) 82 | 83 | for path in glob_checkpoints: 84 | if use_mtime: 85 | ordering_and_checkpoint_path.append((os.path.getmtime(path), path)) 86 | else: 87 | regex_match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path) 88 | if regex_match and regex_match.groups(): 89 | ordering_and_checkpoint_path.append( 90 | (int(regex_match.groups()[0]), path)) 91 | 92 | checkpoints_sorted = sorted(ordering_and_checkpoint_path) 93 | checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted] 94 | return checkpoints_sorted 95 | 96 | 97 | def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> None: 98 | if not args.save_total_limit: 99 | return 100 | if args.save_total_limit <= 0: 101 | return 102 | 103 | # Check if we should delete older checkpoint(s) 104 | checkpoints_sorted = _sorted_checkpoints(args, checkpoint_prefix, use_mtime) 105 | if len(checkpoints_sorted) <= args.save_total_limit: 106 | return 107 | 108 | number_of_checkpoints_to_delete = max(0, len( 109 | checkpoints_sorted) - args.save_total_limit) 110 | checkpoints_to_be_deleted = checkpoints_sorted[ 111 | :number_of_checkpoints_to_delete] 112 | for checkpoint in checkpoints_to_be_deleted: 113 | logger.info( 114 | "Deleting older checkpoint [{}] due to args.save_total_limit".format( 115 | checkpoint)) 116 | shutil.rmtree(checkpoint) 117 | 118 | 119 | def train(args, train_dataset, model, tokenizer) -> Tuple[int, float]: 120 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 121 | 122 | def collate(examples: List[torch.Tensor]): 123 | if tokenizer.pad_token is None: 124 | return pad_sequence(examples, batch_first=True) 125 | return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id) 126 | 127 | train_sampler = RandomSampler(train_dataset) 128 | train_dataloader = DataLoader( 129 | train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, 130 | collate_fn=collate 131 | ) 132 | 133 | if args.max_steps > 0: 134 | t_total = args.max_steps 135 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 136 | else: 137 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 138 | 139 | # Prepare optimizer and schedule (linear warmup and decay) 140 | no_decay = ["bias", "LayerNorm.weight"] 141 | optimizer_grouped_parameters = [ 142 | { 143 | "params": [p for n, p in model.named_parameters() if 144 | not any(nd in n for nd in no_decay) 
and p.requires_grad], 145 | "weight_decay": args.weight_decay, 146 | }, 147 | {"params": [p for n, p in model.named_parameters() if 148 | any(nd in n for nd in no_decay) and p.requires_grad], 149 | "weight_decay": 0.0 150 | }, 151 | ] 152 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 153 | scheduler = get_linear_schedule_with_warmup( 154 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total 155 | ) 156 | 157 | # Train! 158 | logger.info("***** Running training *****") 159 | logger.info(" Num examples = %d", len(train_dataset)) 160 | logger.info(" Num Epochs = %d", args.num_train_epochs) 161 | logger.info(" Instantaneous batch size per GPU = %d", 162 | args.per_gpu_train_batch_size) 163 | logger.info(" Gradient Accumulation steps = %d", 164 | args.gradient_accumulation_steps) 165 | logger.info(" Total optimization steps = %d", t_total) 166 | 167 | global_step = 0 168 | epochs_trained = 0 169 | steps_trained_in_current_epoch = 0 170 | tr_loss, logging_loss = 0.0, 0.0 171 | 172 | model.zero_grad() 173 | train_iterator = trange( 174 | epochs_trained, int(args.num_train_epochs), desc="Epoch") 175 | set_seed(args) # Added here for reproducibility 176 | for _ in train_iterator: 177 | epoch_iterator = tqdm(train_dataloader, desc="Iteration") 178 | for step, batch in enumerate(epoch_iterator): 179 | 180 | # Skip past any already trained steps if resuming training 181 | if steps_trained_in_current_epoch > 0: 182 | steps_trained_in_current_epoch -= 1 183 | continue 184 | 185 | inputs, labels = (batch, batch) 186 | inputs = inputs.to(args.device) 187 | labels = labels.to(args.device) 188 | model.train() 189 | outputs = model(inputs, labels=labels) 190 | loss = outputs[0] 191 | 192 | if args.n_gpu > 1: 193 | loss = loss.mean() # mean() to average on multi-gpu parallel training 194 | if args.gradient_accumulation_steps > 1: 195 | loss = loss / args.gradient_accumulation_steps 196 | loss.backward() 197 | 198 | tr_loss += loss.item() 199 | if (step + 1) % args.gradient_accumulation_steps == 0: 200 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 201 | optimizer.step() 202 | scheduler.step() # Update learning rate schedule 203 | model.zero_grad() 204 | global_step += 1 205 | 206 | if args.logging_steps > 0 and global_step % args.logging_steps == 0: 207 | # Log metrics 208 | perplexity = evaluate(args, model, tokenizer)['perplexity'] 209 | logging_loss = tr_loss / global_step 210 | logger.info(f'Step={global_step}, train loss={logging_loss:.4f}, eval perplexity={perplexity:.4f}') 211 | 212 | if 0 < args.max_steps < global_step: 213 | epoch_iterator.close() 214 | break 215 | 216 | checkpoint_prefix = "checkpoint" 217 | # Save model checkpoint 218 | output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step)) 219 | os.makedirs(output_dir, exist_ok=True) 220 | model_to_save = ( 221 | model.module if hasattr(model, "module") else model 222 | ) # Take care of distributed/parallel training 223 | model_to_save.save_pretrained(output_dir) 224 | 225 | torch.save(args, os.path.join(output_dir, "training_args.bin")) 226 | logger.info("Saving model checkpoint to %s", output_dir) 227 | 228 | _rotate_checkpoints(args, checkpoint_prefix) 229 | 230 | if 0 < args.max_steps < global_step: 231 | train_iterator.close() 232 | break 233 | 234 | return global_step, tr_loss / global_step 235 | 236 | 237 | def evaluate(args, model, tokenizer: PolyEncoderTokenizer, prefix="") -> Dict: 238 | # Loop to handle 
MNLI double evaluation (matched, mis-matched) 239 | eval_output_dir = args.output_dir 240 | 241 | eval_dataset = load_and_cache_examples(args, tokenizer, evaluating=True) 242 | 243 | os.makedirs(eval_output_dir, exist_ok=True) 244 | 245 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 246 | 247 | def collate(examples: List[torch.Tensor]): 248 | if tokenizer.pad_token is None: 249 | return pad_sequence(examples, batch_first=True) 250 | return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id) 251 | 252 | eval_sampler = SequentialSampler(eval_dataset) 253 | eval_dataloader = DataLoader( 254 | eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, 255 | collate_fn=collate 256 | ) 257 | 258 | # Eval! 259 | logger.info("***** Running evaluation {} *****".format(prefix)) 260 | logger.info(" Num examples = %d", len(eval_dataset)) 261 | logger.info(" Batch size = %d", args.eval_batch_size) 262 | eval_loss = 0.0 263 | nb_eval_steps = 0 264 | model.eval() 265 | 266 | for batch in tqdm(eval_dataloader, desc="Evaluating"): 267 | inputs, labels = (batch, batch) 268 | inputs = inputs.to(args.device) 269 | labels = labels.to(args.device) 270 | 271 | with torch.no_grad(): 272 | outputs = model(inputs, labels=labels) 273 | lm_loss = outputs[0] 274 | eval_loss += lm_loss.mean().item() 275 | nb_eval_steps += 1 276 | 277 | eval_loss = eval_loss / nb_eval_steps 278 | perplexity = torch.exp(torch.tensor(eval_loss)) 279 | 280 | result = {"perplexity": perplexity} 281 | return result 282 | 283 | 284 | def main(): 285 | parser = argparse.ArgumentParser() 286 | 287 | # Required parameters 288 | parser.add_argument( 289 | "--train_data_file", default=None, type=str, 290 | help="The input training data file (a text file)." 291 | ) 292 | parser.add_argument( 293 | "--output_dir", 294 | type=str, 295 | required=True, 296 | help="The output directory checkpoints will be written.", 297 | ) 298 | # Other parameters 299 | parser.add_argument( 300 | "--eval_data_file", 301 | default=None, 302 | type=str, 303 | help="An optional input evaluation data file to evaluate" 304 | " the perplexity on (a text file).", 305 | ) 306 | parser.add_argument( 307 | "--should_continue", action="store_true", 308 | help="Whether to continue from latest checkpoint in output_dir" 309 | ) 310 | parser.add_argument( 311 | "--cache_dir", 312 | default=None, 313 | type=str, 314 | help="Optional directory to store the pre-trained models" 315 | " downloaded from s3 (instead of the default one)", 316 | ) 317 | parser.add_argument( 318 | "--block_size", 319 | default=512, 320 | type=int, 321 | help="Optional input sequence length after tokenization." 322 | "The training dataset will be truncated in block of" 323 | " this size for training." 324 | "Default to the model max input length for single sentence" 325 | " inputs (take into account special tokens).", 326 | ) 327 | parser.add_argument("--do_train", action="store_true", 328 | help="Whether to run training.") 329 | parser.add_argument("--do_eval", action="store_true", 330 | help="Whether to run eval on the dev set.") 331 | parser.add_argument("--per_gpu_train_batch_size", default=4, type=int, 332 | help="Batch size per GPU/CPU for training.") 333 | parser.add_argument( 334 | "--per_gpu_eval_batch_size", default=4, type=int, 335 | help="Batch size per GPU/CPU for evaluation." 
336 | ) 337 | parser.add_argument( 338 | "--gradient_accumulation_steps", 339 | type=int, 340 | default=1, 341 | help="Number of updates steps to accumulate " 342 | "before performing a backward/update pass.", 343 | ) 344 | parser.add_argument("--learning_rate", default=2e-5, type=float, 345 | help="The initial learning rate for Adam.") 346 | parser.add_argument("--weight_decay", default=0.0, type=float, 347 | help="Weight decay if we apply some.") 348 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 349 | help="Epsilon for Adam optimizer.") 350 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 351 | help="Max gradient norm.") 352 | parser.add_argument( 353 | "--num_train_epochs", default=2.0, type=float, 354 | help="Total number of training epochs to perform." 355 | ) 356 | parser.add_argument( 357 | "--max_steps", 358 | default=-1, 359 | type=int, 360 | help="If > 0: set total number of training steps " 361 | "to perform. Override num_train_epochs.", 362 | ) 363 | parser.add_argument("--warmup_steps", default=100, type=int, 364 | help="Linear warmup over warmup_steps.") 365 | parser.add_argument("--logging_steps", type=int, default=1000, 366 | help="Log every X updates steps.") 367 | parser.add_argument("--save_steps", type=int, default=0, 368 | help="Save checkpoint every X updates steps.") 369 | parser.add_argument( 370 | "--save_total_limit", 371 | type=int, 372 | default=5, 373 | help="Limit the total amount of checkpoints, delete the older checkpoints" 374 | " in the output_dir, does not delete by default", 375 | ) 376 | parser.add_argument("--no_cuda", action="store_true", 377 | help="Avoid using CUDA when available") 378 | parser.add_argument( 379 | "--overwrite_output_dir", action="store_true", 380 | help="Overwrite the content of the output directory" 381 | ) 382 | parser.add_argument( 383 | "--overwrite_cache", action="store_true", 384 | help="Overwrite the cached training and evaluation sets" 385 | ) 386 | parser.add_argument("--seed", type=int, default=42, 387 | help="random seed for initialization") 388 | 389 | args = parser.parse_args() 390 | args.n_gpu = 1 391 | 392 | # Setup CUDA, GPU & distributed training 393 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 394 | args.device = device 395 | 396 | # Setup logging 397 | logging.basicConfig( 398 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 399 | datefmt="%m/%d/%Y %H:%M:%S", 400 | level=logging.INFO, 401 | ) 402 | 403 | # Set seed 404 | set_seed(args) 405 | 406 | tokenizer = PolyEncoderTokenizer.from_pretrained() 407 | 408 | if args.block_size <= 0: 409 | args.block_size = tokenizer.max_len 410 | # Our input block size will be the max possible for the model 411 | else: 412 | args.block_size = min(args.block_size, tokenizer.max_len) 413 | 414 | model = PolyEncoderLM.from_pretrained() 415 | 416 | model.to(args.device) 417 | logger.info("Training/evaluation parameters %s", args) 418 | 419 | # Training 420 | if args.do_train: 421 | train_dataset = load_and_cache_examples(args, tokenizer, evaluating=False) 422 | global_step, tr_loss = train(args, train_dataset, model, tokenizer) 423 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) 424 | 425 | # Saving best-practices: if you use save_pretrained for the 426 | # model and tokenizer, you can reload them using from_pretrained() 427 | if args.do_train: 428 | os.makedirs(args.output_dir, exist_ok=True) 429 | 430 | logger.info("Saving model checkpoint to %s", 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csong27/collision-bert/43eda087bf6d632bdb150d98e934206327f8d082/utils/__init__.py
--------------------------------------------------------------------------------
/utils/constraints_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from nltk.corpus import stopwords
3 | from transformers import BertTokenizer, GPT2Tokenizer
4 | 
5 | COMMON_WORDS = ['the', 'of', 'and', 'a', 'to', 'in', 'is', 'you', 'that', 'it']
6 | STOPWORDS = set(stopwords.words('english'))
7 | 
8 | 
9 | def get_inputs_filter_ids(inputs, tokenizer):
10 |     tokens = [w for w in tokenizer.tokenize(inputs) if w.isalpha() and w not in STOPWORDS]
11 |     return tokenizer.convert_tokens_to_ids(tokens)
12 | 
13 | 
14 | def get_sub_masks(tokenizer, device, prob=False):
15 |     # masking for all subwords in the vocabulary
16 |     vocab = tokenizer.get_vocab()
17 | 
18 |     def is_special_token(w):
19 |         if isinstance(tokenizer, BertTokenizer) and w.startswith('##'):
20 |             return True
21 |         if isinstance(tokenizer, GPT2Tokenizer) and not w.startswith('Ġ'):
22 |             return True
23 |         if w[0] == '[' and w[-1] == ']':
24 |             return True
25 |         if w[0] == '<' and w[-1] == '>':
26 |             return True
27 |         if w in ['=', '@', 'Ġ=', 'Ġ@'] and w in vocab:
28 |             return True
29 |         return False
30 | 
31 |     filter_ids = [vocab[w] for w in vocab if is_special_token(w)]
32 |     if prob:
33 |         prob_mask = torch.ones(tokenizer.vocab_size, device=device)
34 |         prob_mask[filter_ids] = 0.
35 |     else:
36 |         prob_mask = torch.zeros(tokenizer.vocab_size, device=device)
37 |         prob_mask[filter_ids] = -1e9
38 |     return prob_mask
39 | 
40 | 
41 | def get_poly_sub_masks(tokenizer, device, prob=False):
42 |     filter_ids = [tokenizer.dict[w] for w in tokenizer.dict.tok2ind
43 |                   if not w.isalnum()]
44 |     if prob:
45 |         prob_mask = torch.ones(tokenizer.vocab_size, device=device)
46 |         prob_mask[filter_ids] = 0.
47 |     else:
48 |         prob_mask = torch.zeros(tokenizer.vocab_size, device=device)
49 |         prob_mask[filter_ids] = -1e9
50 |     return prob_mask
51 | 
52 | 
53 | def create_constraints(seq_len, tokenizer, device, prob=False):
54 |     stopword_ids = [tokenizer.vocab[w] for w in COMMON_WORDS[:5] if w in tokenizer.vocab]
55 |     if prob:
56 |         masks = torch.zeros(seq_len, tokenizer.vocab_size, device=device)
57 |     else:
58 |         masks = torch.zeros(seq_len, tokenizer.vocab_size, device=device) - 1e9
59 | 
60 |     for t in range(seq_len):
61 |         if t >= seq_len // 2:
62 |             masks[t, stopword_ids] = 1.0 if prob else 0.0
63 |         else:
64 |             masks[t] = 1.0 if prob else 0.
65 |     return masks
66 | 
67 | 
68 | def create_poly_constraints(seq_len, tokenizer, device, prob=False):
69 |     stopword_ids = [tokenizer.dict[w] for w in COMMON_WORDS[:5] if w in tokenizer.dict.tok2ind]
70 |     if prob:
71 |         masks = torch.zeros(seq_len, tokenizer.vocab_size, device=device)
72 |     else:
73 |         masks = torch.zeros(seq_len, tokenizer.vocab_size, device=device) - 1e9
74 | 
75 |     for t in range(seq_len):
76 |         if t >= seq_len // 3:
77 |             masks[t, stopword_ids] = 1.0 if prob else 0.0
78 |         else:
79 |             masks[t] = 1.0 if prob else 0.
80 |     return masks
81 | 
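The helpers in `constraints_utils.py` build additive masks over the vocabulary: `get_sub_masks` returns a `[vocab_size]` tensor that pushes subword pieces and special tokens to -1e9, and `create_constraints` returns a `[seq_len, vocab_size]` tensor that restricts later positions to a few common words (the `prob=True` variants return 0/1 masks instead). A minimal sketch of applying them, assuming a stock `bert-base-uncased` tokenizer, an available NLTK stopwords corpus, and a made-up logits tensor:

```
import torch
from transformers import BertTokenizer

from utils.constraints_utils import create_constraints, get_sub_masks

# assumes nltk.download('stopwords') has been run once
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

seq_len = 20
sub_mask = get_sub_masks(tokenizer, device)                # [vocab_size], 0 or -1e9
pos_mask = create_constraints(seq_len, tokenizer, device)  # [seq_len, vocab_size]

# made-up logits for a single collision candidate of length seq_len
logits = torch.randn(1, seq_len, tokenizer.vocab_size, device=device)
probs = torch.softmax(logits + sub_mask + pos_mask, dim=-1)  # masked entries get ~0 mass
```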
--------------------------------------------------------------------------------
/utils/logging_utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | 
3 | 
4 | def log(msg):
5 |     if msg[-1] != '\n':
6 |         msg += '\n'
7 |     sys.stderr.write(msg)
8 |     sys.stderr.flush()
9 | 
--------------------------------------------------------------------------------
/utils/optimization_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from utils.logging_utils import log
4 | 
5 | SMALL_CONST = 1e-15
6 | 
7 | 
8 | def to_var(p, device):
9 |     return torch.tensor(p, requires_grad=True, device=device)
10 | 
11 | 
12 | def kl_loss(probs, unpert_probs):
13 |     unpert_probs = unpert_probs + SMALL_CONST * (unpert_probs <= SMALL_CONST).float().detach()
14 |     correction = SMALL_CONST * (probs <= SMALL_CONST).float().detach()
15 |     corrected_probs = probs + correction.detach()
16 |     kl_loss = (corrected_probs * (corrected_probs / unpert_probs).log()).sum()
17 |     return kl_loss
18 | 
19 | 
20 | def perturb_logits(
21 |         unpert_logits,
22 |         stepsize=0.01,
23 |         target_model_wrapper=None,
24 |         num_iterations=3,
25 |         kl_scale=0.01,
26 |         temperature=1.0,
27 |         device="cuda",
28 |         verbose=False,
29 |         logit_mask=0.,
30 | ):
31 |     # Initialize the additive perturbation on the logits (PPLM-style)
32 |     grad_accumulator = np.zeros(unpert_logits.shape, dtype=np.float32)
33 |     perturbation = to_var(grad_accumulator, device=device)
34 |     optimizer = torch.optim.Adam([perturbation], lr=stepsize)
35 | 
36 |     # accumulate perturbations for num_iterations
37 |     for i in range(num_iterations):
38 |         optimizer.zero_grad()
39 |         # Compute token probabilities from the perturbed logits
40 |         logits = unpert_logits * temperature + perturbation + logit_mask
41 |         probs = torch.softmax(logits / temperature, -1)
42 |         unpert_probs = torch.softmax(unpert_logits, -1)
43 | 
44 |         loss = torch.scalar_tensor(0.0).to(device)
45 |         loss_list = []
46 | 
47 |         if target_model_wrapper is not None:
48 |             discrim_loss = target_model_wrapper(probs)
49 |             if verbose and i % 2 == 0:
50 |                 log(f"Iteration {i + 1}, pplm_discrim_loss: {discrim_loss.data.cpu().numpy()}")
51 |             loss += discrim_loss
52 |             loss_list.append(discrim_loss)
53 | 
54 |         if kl_scale > 0.0:
55 |             unpert_probs = unpert_probs + SMALL_CONST * (unpert_probs <= SMALL_CONST).float().to(device).detach()
56 |             correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
57 |             corrected_probs = probs + correction.detach()
58 |             kl_loss = kl_scale * (corrected_probs * (corrected_probs / unpert_probs).log()).sum()
59 |             loss += kl_loss
60 | 
61 |         # compute gradients
62 |         loss.backward()
63 |         optimizer.step()
64 | 
65 |     # apply the accumulated perturbation to the logits
66 |     pert_logits = unpert_logits * temperature + perturbation
67 |     return pert_logits
68 | 
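`perturb_logits` performs a PPLM-style update: it optimizes an additive perturbation on a language model's logits with Adam so that the differentiable scorer passed as `target_model_wrapper` decreases, while the KL term keeps the perturbed token distribution close to the unperturbed one. A toy sketch under stated assumptions — the scorer, embedding table, and logits below are stand-ins for illustration, not the repo's actual target models:

```
import torch

from utils.optimization_utils import perturb_logits

device = 'cuda' if torch.cuda.is_available() else 'cpu'
seq_len, vocab_size, hidden = 20, 30522, 768

# stand-in for a target model: embed the relaxed (soft) tokens and score them
fake_embeddings = torch.randn(vocab_size, hidden, device=device)

def toy_scorer(probs):
    # probs: [batch, seq_len, vocab_size]; any differentiable scalar works here
    soft_inputs = probs @ fake_embeddings
    return soft_inputs.mean()

unpert_logits = torch.randn(1, seq_len, vocab_size, device=device)
pert_logits = perturb_logits(
    unpert_logits,
    stepsize=1e-3,
    target_model_wrapper=toy_scorer,
    num_iterations=3,
    device=device,
)
print(pert_logits.shape)  # same shape as the input logits
```

In the collision scripts the wrapper presumably wraps the target model's relevance score on the soft inputs; here it is only a placeholder.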
--------------------------------------------------------------------------------
/utils/tokenizer_utils.py:
--------------------------------------------------------------------------------
1 | from transformers import BertTokenizer
2 | from utils.logging_utils import log
3 | 
4 | 
5 | def tokenize_adversarial_example(input_ids, tokenizer):
6 |     if not isinstance(input_ids, list):
7 |         input_ids = input_ids.squeeze().cpu().tolist()
8 | 
9 |     # make sure decoded string can be tokenized to the same tokens
10 |     sep_indices = []
11 |     for i, token_id in enumerate(input_ids):
12 |         if token_id == tokenizer.sep_token_id:
13 |             sep_indices.append(i)
14 | 
15 |     if len(sep_indices) == 1 or tokenizer.sep_token_id == tokenizer.cls_token_id:
16 |         # input is a single text
17 |         decoded = tokenizer.decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
18 |         encoded_ids = tokenizer.encode(decoded)
19 |     else:
20 |         # input is a pair of texts
21 |         assert len(sep_indices) == 2, sep_indices
22 |         a_input_ids = input_ids[1:sep_indices[0]]
23 |         b_input_ids = input_ids[sep_indices[0] + 1: sep_indices[1]]
24 |         a_decoded = tokenizer.decode(a_input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
25 |         b_decoded = tokenizer.decode(b_input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
26 |         encoded_ids = tokenizer.encode(a_decoded, b_decoded)
27 | 
28 |     return encoded_ids
29 | 
30 | 
31 | def valid_tokenization(input_ids, tokenizer: BertTokenizer, verbose=False):
32 |     input_ids = input_ids.squeeze().cpu().tolist()
33 | 
34 |     if input_ids[0] != tokenizer.cls_token_id:
35 |         input_ids = [tokenizer.cls_token_id] + input_ids
36 |     if input_ids[-1] != tokenizer.sep_token_id:
37 |         input_ids = input_ids + [tokenizer.sep_token_id]
38 | 
39 |     # make sure decoded string can be tokenized to the same tokens
40 |     encoded_ids = tokenize_adversarial_example(input_ids, tokenizer)
41 |     valid = len(input_ids) == len(encoded_ids) and all(i == j for i, j in zip(input_ids, encoded_ids))
42 |     if verbose and not valid:
43 |         log(f'Inputs: {tokenizer.convert_ids_to_tokens(input_ids)}')
44 |         log(f'Re-encoded: {tokenizer.convert_ids_to_tokens(encoded_ids)}')
45 |     return valid, encoded_ids
46 | 
--------------------------------------------------------------------------------
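`valid_tokenization` guards against adversarial token sequences that decode to a string which re-tokenizes differently, since such a sequence would silently change the input actually fed to the target model. A small round-trip check, assuming a stock `bert-base-uncased` tokenizer and an arbitrary example sentence:

```
import torch
from transformers import BertTokenizer

from utils.tokenizer_utils import valid_tokenization

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# ordinary text should survive the decode -> re-encode round trip unchanged
input_ids = torch.tensor(tokenizer.encode('a perfectly ordinary sentence'))
valid, encoded_ids = valid_tokenization(input_ids, tokenizer, verbose=True)
print(valid)  # expected: True; adversarial id sequences may come back False
```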