├── LICENSE ├── README.md ├── generate_keyphrase_universe.py ├── model ├── __init__.py ├── model.py └── outputs.py ├── model_cards ├── KBIR.md └── KeyBART.md ├── pretrain_runner.py ├── run_pretrain_kp_infill_replacement_bart_kg_oagkx.sh ├── run_pretrain_kp_infill_replacement_oagkx.sh ├── trainer ├── __init__.py └── trainer.py └── utils ├── __init__.py ├── data_collators.py ├── logger.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Description 2 | This repository contains the experimental code used in pre-training the KBIR and KeyBART models as described in Learning Rich Representation for Keyphrases (https://arxiv.org/pdf/2112.08547.pdf) and to appear in Findings of NAACL 2022. 
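The released checkpoints (see "Accessing Pre-trained models" below) can be loaded directly from the HuggingFace Hub. A minimal sketch, assuming the `transformers` library is installed; the model identifiers are the ones listed later in this README:

```python
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM

# KBIR is a RoBERTa-style encoder; its pre-training heads are discarded in the
# released checkpoint, so any AutoModel* class that works with RoBERTa works here.
kbir_tokenizer = AutoTokenizer.from_pretrained("bloomberg/KBIR")
kbir_encoder = AutoModel.from_pretrained("bloomberg/KBIR")

# KeyBART is a BART-style encoder-decoder trained to generate keyphrase sequences.
keybart_tokenizer = AutoTokenizer.from_pretrained("bloomberg/KeyBART")
keybart_model = AutoModelForSeq2SeqLM.from_pretrained("bloomberg/KeyBART")
```

See the model cards under `model_cards/` for task-specific usage.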
3 | 4 | Some of the code builds on top of code from HuggingFace Transformers (https://github.com/huggingface/transformers) and also takes inspiration from SpanBERT (https://github.com/facebookresearch/SpanBERT) 5 | 6 | # Running the pre-training 7 | Use the two bash scripts for running pre-training for KBIR and KeyBART respectively. 8 | 9 | # Accessing Pre-trained models 10 | Models are uploaded to HuggingFace along with Model Cards describing usage. 11 | 12 | KBIR: https://huggingface.co/bloomberg/KBIR 13 | 14 | KeyBART: https://huggingface.co/bloomberg/KeyBART 15 | 16 | ## Citation 17 | ``` 18 | @article{kulkarni2021kbirkeybart, 19 | title={Learning Rich Representation of Keyphrases from Text}, 20 | author={Mayank Kulkarni and Debanjan Mahata and Ravneet Arora and Rajarshi Bhowmik}, 21 | journal={arXiv preprint arXiv:2112.08547}, 22 | year={2021} 23 | } 24 | ``` 25 | 26 | ## License 27 | KBIR and KeyBART are Apache 2.0. The license applies to the pre-trained models as well. 28 | 29 | # Contact 30 | For any questions reach out to mkulkarni24@bloomberg.net 31 | -------------------------------------------------------------------------------- /generate_keyphrase_universe.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | import json 3 | import os 4 | import argparse 5 | import random 6 | 7 | 8 | def parse_keyphrases(text, keywords): 9 | keyphrases = [] 10 | for keyphrase in keywords: 11 | keyphrase_index = text.find(keyphrase) 12 | if keyphrase_index == -1: 13 | keyphrase_index = text.lower().find(keyphrase) 14 | # Can't find keyphrase in text 15 | if keyphrase_index == -1: 16 | continue 17 | keyphrase = text[keyphrase_index : keyphrase_index + len(keyphrase)] 18 | # Decide whether a space is required before for the tokenizer to have a consistent behavior 19 | if keyphrase_index > 0: 20 | if text[keyphrase_index - 1] == " ": 21 | keyphrase = " " + keyphrase 22 | keyphrases.append(keyphrase) 23 | 24 | return keyphrases 25 | 26 | 27 | def main(args): 28 | random.seed(42) 29 | keyphrase_universe = [] 30 | max_keyphrase_pairs = 0 31 | corpus_dirs = [args.train_data_dir] 32 | for corpus_dir in corpus_dirs: 33 | data_files = glob(corpus_dir + "/*") 34 | 35 | for fname in data_files: 36 | with open(fname) as f: 37 | for line in f: 38 | data = json.loads(line) 39 | if "keywords" in data: 40 | title = data["title"] 41 | abstract = data["abstract"] 42 | text = title + ". 
" + abstract 43 | keywords = data["keywords"].split(" , ") 44 | keyphrases = parse_keyphrases(text, keywords) 45 | if len(keyphrases) > max_keyphrase_pairs: 46 | max_keyphrase_pairs = len(keyphrases) 47 | keyphrase_universe += keyphrases 48 | 49 | print("Max Keyphrase Pairs: ", max_keyphrase_pairs) 50 | keyphrase_universe = list(set(keyphrase_universe)) 51 | random.shuffle(keyphrase_universe) 52 | with open(os.path.join(args.output_dir, "keyphrase_universe.txt"), "w+") as outf: 53 | for keyphrase in keyphrase_universe: 54 | outf.write(keyphrase) 55 | outf.write("\n") 56 | 57 | 58 | def parse_args(): 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument( 61 | "--train-data-dir", 62 | type=str, 63 | help="Train files from which the keyphrase universe should be computed", 64 | ) 65 | parser.add_argument( 66 | "--output-dir", 67 | type=str, 68 | help="Output file containing all keyphrases from the corpus", 69 | ) 70 | 71 | return parser.parse_args() 72 | 73 | 74 | if __name__ == "__main__": 75 | args = parse_args() 76 | 77 | main(args) 78 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import * 2 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from transformers import PreTrainedModel, RobertaConfig, RobertaModel 3 | from transformers.activations import gelu, ACT2FN 4 | from transformers.models.roberta.modeling_roberta import ( 5 | RobertaPreTrainedModel, 6 | RobertaLMHead, 7 | ) 8 | import torch 9 | from torch import nn 10 | from torch.nn import CrossEntropyLoss, BCELoss 11 | from klm.model.outputs import KLMForReplacementAndMaskedLMOutput 12 | from typing import Optional 13 | 14 | from transformers.configuration_utils import PretrainedConfig 15 | from transformers.file_utils import ( 16 | add_start_docstrings, 17 | add_start_docstrings_to_model_forward, 18 | replace_return_docstrings, 19 | ) 20 | from transformers.modeling_outputs import Seq2SeqLMOutput, TokenClassifierOutput 21 | from transformers.models.encoder_decoder.configuration_encoder_decoder import ( 22 | EncoderDecoderConfig, 23 | ) 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | class ReplacementClassificationHead(nn.Module): 29 | def __init__(self, config, use_doc_emb=True): 30 | super(ReplacementClassificationHead, self).__init__() 31 | self.use_doc_emb = use_doc_emb 32 | classifier_hidden_size = 2 * config.hidden_size 33 | if self.use_doc_emb: 34 | classifier_hidden_size += config.hidden_size 35 | self.num_labels = 2 36 | self.classifier = nn.Linear(classifier_hidden_size, self.num_labels) 37 | self.bias = nn.Parameter(torch.zeros(self.num_labels)) 38 | 39 | def forward(self, hidden_states, pooled_states, pairs): 40 | bs, num_pairs, _ = pairs.size() 41 | bs, seq_len, dim = hidden_states.size() 42 | bs, dim = pooled_states.size() 43 | # pair indices: (bs, num_pairs) 44 | left, right = pairs[:, :, 0], pairs[:, :, 1] 45 | # (bs, num_pairs, dim) 46 | left_hidden = torch.gather( 47 | hidden_states, 1, left.unsqueeze(2).repeat(1, 1, dim) 48 | ) 49 | # bs * num_pairs, dim 50 | left_hidden = left_hidden.contiguous().view(bs * num_pairs, dim) 51 | # (bs, num_pairs, dim) 52 | right_hidden = torch.gather( 53 | hidden_states, 1, right.unsqueeze(2).repeat(1, 1, dim) 54 | ) 55 | # bs * num_pairs, dim 56 | 
right_hidden = right_hidden.contiguous().view(bs * num_pairs, dim) 57 | # bs * num_pairs, 2*dim 58 | hidden_states = torch.cat((left_hidden, right_hidden), -1) 59 | 60 | if self.use_doc_emb: 61 | # bs * num_pairs, dim 62 | pooled_states = ( 63 | pooled_states.unsqueeze(1) 64 | .repeat(1, num_pairs, 1) 65 | .view(bs * num_pairs, dim) 66 | ) 67 | hidden_states = torch.cat((pooled_states, hidden_states), -1) 68 | 69 | # target scores : bs * num_pairs, num_labels 70 | target_scores = self.classifier(hidden_states) + self.bias 71 | target_scores = torch.reshape(target_scores, (bs, num_pairs, self.num_labels)) 72 | return target_scores 73 | 74 | 75 | class BertLayerNorm(nn.Module): 76 | def __init__(self, hidden_size, eps=1e-12): 77 | super(BertLayerNorm, self).__init__() 78 | self.gamma = nn.Parameter(torch.ones(hidden_size)) 79 | self.beta = nn.Parameter(torch.zeros(hidden_size)) 80 | self.variance_epsilon = eps 81 | 82 | def forward(self, x): 83 | u = x.mean(-1, keepdim=True) 84 | s = (x - u).pow(2).mean(-1, keepdim=True) 85 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 86 | return self.gamma * x + self.beta 87 | 88 | 89 | class MLPWithLayerNorm(nn.Module): 90 | def __init__(self, config, input_size): 91 | super(MLPWithLayerNorm, self).__init__() 92 | self.config = config 93 | self.linear1 = nn.Linear(input_size, config.hidden_size) 94 | self.non_lin1 = ( 95 | ACT2FN[config.hidden_act] 96 | if isinstance(config.hidden_act, str) 97 | else config.hidden_act 98 | ) 99 | self.layer_norm1 = BertLayerNorm(config.hidden_size, eps=1e-12) 100 | self.linear2 = nn.Linear(config.hidden_size, config.hidden_size) 101 | self.non_lin2 = ( 102 | ACT2FN[config.hidden_act] 103 | if isinstance(config.hidden_act, str) 104 | else config.hidden_act 105 | ) 106 | self.layer_norm2 = BertLayerNorm(config.hidden_size, eps=1e-12) 107 | 108 | def forward(self, hidden): 109 | return self.layer_norm2( 110 | self.non_lin2( 111 | self.linear2(self.layer_norm1(self.non_lin1(self.linear1(hidden)))) 112 | ) 113 | ) 114 | 115 | 116 | class InfillingHead(nn.Module): 117 | def __init__( 118 | self, 119 | config, 120 | roberta_model_embedding_weights, 121 | kp_max_seq_len=10, 122 | position_embedding_size=200, 123 | ): 124 | super(InfillingHead, self).__init__() 125 | classifier_hidden_size = 2 * config.hidden_size 126 | self.num_labels = kp_max_seq_len 127 | self.num_tok_classifier = nn.Linear(classifier_hidden_size, self.num_labels) 128 | self.bias = nn.Parameter(torch.zeros(self.num_labels)) 129 | self.position_embeddings = nn.Embedding(kp_max_seq_len, position_embedding_size) 130 | self.mlp_layer_norm = MLPWithLayerNorm( 131 | config, config.hidden_size * 2 + position_embedding_size 132 | ) 133 | # The output weights are the same as the input embeddings, but there is 134 | # an output-only bias for each token. 
135 | self.decoder = nn.Linear( 136 | roberta_model_embedding_weights.size(1), 137 | roberta_model_embedding_weights.size(0), 138 | bias=False, 139 | ) 140 | self.decoder.weight = roberta_model_embedding_weights 141 | self.bias = nn.Parameter(torch.zeros(roberta_model_embedding_weights.size(0))) 142 | self.kp_max_seq_len = kp_max_seq_len 143 | 144 | def forward(self, hidden_states, pairs): 145 | bs, num_pairs, _ = pairs.size() 146 | bs, seq_len, dim = hidden_states.size() 147 | # pair indices: (bs, num_pairs) 148 | left, right = pairs[:, :, 0], pairs[:, :, 1] 149 | # (bs, num_pairs, dim) 150 | left_hidden = torch.gather( 151 | hidden_states, 1, left.unsqueeze(2).repeat(1, 1, dim) 152 | ) 153 | # pair states: bs * num_pairs, kp_max_seq_len, dim 154 | kp_left_hidden = ( 155 | left_hidden.contiguous() 156 | .view(bs * num_pairs, dim) 157 | .unsqueeze(1) 158 | .repeat(1, self.kp_max_seq_len, 1) 159 | ) 160 | # bs * num_pairs, dim 161 | num_tok_left_hidden = left_hidden.contiguous().view(bs * num_pairs, dim) 162 | # (bs, num_pairs, dim) 163 | right_hidden = torch.gather( 164 | hidden_states, 1, right.unsqueeze(2).repeat(1, 1, dim) 165 | ) 166 | # pair states: bs * num_pairs, kp_max_seq_len, dim 167 | kp_right_hidden = ( 168 | right_hidden.contiguous() 169 | .view(bs * num_pairs, dim) 170 | .unsqueeze(1) 171 | .repeat(1, self.kp_max_seq_len, 1) 172 | ) 173 | # bs * num_pairs, dim 174 | num_tok_right_hidden = right_hidden.contiguous().view(bs * num_pairs, dim) 175 | # bs * num_pairs, 2*dim 176 | hidden_states = torch.cat((num_tok_left_hidden, num_tok_right_hidden), -1) 177 | 178 | # target scores : bs * num_pairs, num_labels 179 | num_tok_scores = self.num_tok_classifier(hidden_states) 180 | num_tok_scores = torch.reshape(num_tok_scores, (bs, num_pairs, self.num_labels)) 181 | 182 | # (max_targets, dim) 183 | position_embeddings = self.position_embeddings.weight 184 | hidden_states = self.mlp_layer_norm( 185 | torch.cat( 186 | ( 187 | kp_left_hidden, 188 | kp_right_hidden, 189 | position_embeddings.unsqueeze(0).repeat(bs * num_pairs, 1, 1), 190 | ), 191 | -1, 192 | ) 193 | ) 194 | # target scores : bs * num_pairs, kp_max_seq_len, vocab_size 195 | kp_logits = self.decoder(hidden_states) + self.bias 196 | return kp_logits, num_tok_scores 197 | 198 | 199 | class KLMForReplacementAndMaskedLM(RobertaPreTrainedModel): 200 | _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] 201 | _keys_to_ignore_on_load_unexpected = [] 202 | 203 | def __init__( 204 | self, 205 | config, 206 | use_doc_emb=False, 207 | kp_max_seq_len=10, 208 | mlm_loss_weight=1.0, 209 | replacement_loss_weight=1.0, 210 | keyphrase_infill_loss_weight=1.0, 211 | infill_num_tok_loss_weight=1.0, 212 | ): 213 | super().__init__(config) 214 | 215 | if config.is_decoder: 216 | logger.warning( 217 | "If you want to use `RobertaForMaskedLM` make sure `config.is_decoder=False` for " 218 | "bi-directional self-attention." 
219 | ) 220 | 221 | self.roberta = RobertaModel(config, add_pooling_layer=True) 222 | self.lm_head = RobertaLMHead(config) 223 | 224 | self.init_weights() 225 | self.replacement_classification_head = ReplacementClassificationHead( 226 | config, use_doc_emb 227 | ) 228 | self.infilling_head = InfillingHead( 229 | config, self.roberta.embeddings.word_embeddings.weight, kp_max_seq_len 230 | ) 231 | self.mlm_loss_weight = mlm_loss_weight 232 | self.replacement_loss_weight = replacement_loss_weight 233 | self.keyphrase_infill_loss_weight = keyphrase_infill_loss_weight 234 | self.infill_num_tok_loss_weight = infill_num_tok_loss_weight 235 | self.kp_max_seq_len = kp_max_seq_len 236 | 237 | def get_output_embeddings(self): 238 | return self.lm_head.decoder 239 | 240 | def set_output_embeddings(self, new_embeddings): 241 | self.lm_head.decoder = new_embeddings 242 | 243 | def forward( 244 | self, 245 | input_ids=None, 246 | attention_mask=None, 247 | token_type_ids=None, 248 | position_ids=None, 249 | head_mask=None, 250 | inputs_embeds=None, 251 | encoder_hidden_states=None, 252 | encoder_attention_mask=None, 253 | labels=None, 254 | output_attentions=None, 255 | output_hidden_states=None, 256 | return_dict=None, 257 | keyphrases_input_ids=None, 258 | keyphrase_pairs=None, 259 | replacement_labels=None, 260 | masked_keyphrase_pairs=None, 261 | masked_keyphrase_labels=None, 262 | keyphrase_mask_num_tok_labels=None, 263 | ): 264 | 265 | return_dict = ( 266 | return_dict if return_dict is not None else self.config.use_return_dict 267 | ) 268 | outputs = self.roberta( 269 | input_ids, 270 | attention_mask=attention_mask, 271 | token_type_ids=token_type_ids, 272 | position_ids=position_ids, 273 | head_mask=head_mask, 274 | inputs_embeds=inputs_embeds, 275 | encoder_hidden_states=encoder_hidden_states, 276 | encoder_attention_mask=encoder_attention_mask, 277 | output_attentions=output_attentions, 278 | output_hidden_states=output_hidden_states, 279 | return_dict=return_dict, 280 | ) 281 | sequence_output = outputs[0] 282 | 283 | prediction_scores = self.lm_head(sequence_output) 284 | pooled_output = outputs[1] 285 | 286 | masked_lm_loss = None 287 | if labels is not None: 288 | loss_fct = CrossEntropyLoss() 289 | masked_lm_loss = loss_fct( 290 | prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) 291 | ) 292 | masked_lm_loss = self.mlm_loss_weight * masked_lm_loss 293 | 294 | replacement_logits = None 295 | if keyphrase_pairs is not None and replacement_labels is not None: 296 | replacement_logits = self.replacement_classification_head( 297 | sequence_output, pooled_output, keyphrase_pairs 298 | ) 299 | if replacement_labels is not None: 300 | loss_fct = CrossEntropyLoss() 301 | # As this is a binary classification num_classes is fixed at 2 302 | num_class = 2 303 | replacement_classification_loss = loss_fct( 304 | replacement_logits.view(-1, num_class), replacement_labels.view(-1) 305 | ) 306 | masked_lm_loss += ( 307 | self.replacement_loss_weight * replacement_classification_loss 308 | ) 309 | 310 | if ( 311 | masked_keyphrase_pairs is not None 312 | and masked_keyphrase_labels is not None 313 | and keyphrase_mask_num_tok_labels is not None 314 | ): 315 | label_logits, num_toks_logits = self.infilling_head( 316 | sequence_output, masked_keyphrase_pairs 317 | ) 318 | loss_fct = CrossEntropyLoss() 319 | masked_keyphrase_loss = loss_fct( 320 | label_logits.view(-1, self.config.vocab_size), 321 | masked_keyphrase_labels.view(-1), 322 | ) 323 | masked_lm_loss += 
self.keyphrase_infill_loss_weight * masked_keyphrase_loss 324 | num_tok_loss_fct = CrossEntropyLoss() 325 | num_tok_loss = num_tok_loss_fct( 326 | num_toks_logits.view(-1, self.kp_max_seq_len), 327 | keyphrase_mask_num_tok_labels.view(-1), 328 | ) 329 | masked_lm_loss += self.infill_num_tok_loss_weight * num_tok_loss 330 | 331 | if not return_dict: 332 | output = (prediction_scores,) + outputs[2:] 333 | return ( 334 | ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 335 | ) 336 | 337 | return KLMForReplacementAndMaskedLMOutput( 338 | loss=masked_lm_loss, 339 | logits=prediction_scores, 340 | hidden_states=outputs.hidden_states, 341 | attentions=outputs.attentions, 342 | replacement_logits=replacement_logits, 343 | ) 344 | -------------------------------------------------------------------------------- /model/outputs.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from transformers.modeling_outputs import MaskedLMOutput 3 | from typing import Optional, Tuple 4 | import torch 5 | 6 | 7 | @dataclass 8 | class KLMForReplacementAndMaskedLMOutput(MaskedLMOutput): 9 | replacement_logits: torch.FloatTensor = None 10 | replacement_hidden_states: Optional[Tuple[torch.FloatTensor]] = None 11 | replacement_attentions: Optional[Tuple[torch.FloatTensor]] = None 12 | -------------------------------------------------------------------------------- /model_cards/KBIR.md: -------------------------------------------------------------------------------- 1 | # Keyphrase Boundary Infilling with Replacement (KBIR) 2 | The KBIR model as described in Learning Rich Representations of Keyphrases from Text (https://arxiv.org/pdf/2112.08547.pdf) builds on top of the RoBERTa architecture by adding an Infilling head and a Replacement Classification head that is used during pre-training. However, these heads are not used during the downstream evaluation of the model and we only leverage the pre-trained embeddings. Discarding the heads thereby allows us to be compatible with all AutoModel classes that RoBERTa supports. 3 | 4 | We provide examples on how to perform downstream evaluation on some of the tasks reported in the paper. 
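Since only the pre-trained encoder weights are leveraged downstream, the checkpoint can also be loaded as a plain RoBERTa-style encoder to inspect contextual embeddings. A minimal sketch (the task-specific examples below use the corresponding `AutoModelFor*` classes instead):

```python
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("bloomberg/KBIR")
model = AutoModel.from_pretrained("bloomberg/KBIR")

inputs = tokenizer("Keyphrase Boundary Infilling with Replacement", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Token-level representations of shape (batch_size, sequence_length, hidden_size)
embeddings = outputs.last_hidden_state
```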
5 | ## Downstream Evaluation 6 | 7 | ### Keyphrase Extraction 8 | ``` 9 | from transformers import AutoTokenizer, AutoModelForTokenClassification 10 | 11 | tokenizer = AutoTokenizer.from_pretrained("bloomberg/KBIR") 12 | model = AutoModelForTokenClassification.from_pretrained("bloomberg/KBIR") 13 | 14 | from datasets import load_dataset 15 | 16 | dataset = load_dataset("midas/semeval2017_ke_tagged") 17 | ``` 18 | 19 | Reported Results: 20 | 21 | | Model | Inspec | SE10 | SE17 | 22 | |-----------------------|--------|-------|-------| 23 | | RoBERTa+BiLSTM-CRF | 59.5 | 27.8 | 50.8 | 24 | | RoBERTa+TG-CRF | 60.4 | 29.7 | 52.1 | 25 | | SciBERT+Hypernet-CRF | 62.1 | 36.7 | 54.4 | 26 | | RoBERTa+Hypernet-CRF | 62.3 | 34.8 | 53.3 | 27 | | RoBERTa-extended-CRF* | 62.09 | 40.61 | 52.32 | 28 | | KBI-CRF* | 62.61 | 40.81 | 59.7 | 29 | | KBIR-CRF* | 62.72 | 40.15 | 62.56 | 30 | 31 | ### Named Entity Recognition 32 | ``` 33 | from transformers import AutoTokenizer, AutoModelForTokenClassification 34 | 35 | tokenizer = AutoTokenizer.from_pretrained("bloomberg/KBIR") 36 | model = AutoModelForTokenClassification.from_pretrained("bloomberg/KBIR") 37 | 38 | from datasets import load_dataset 39 | 40 | dataset = load_dataset("conll2003") 41 | ``` 42 | 43 | Reported Results: 44 | 45 | | Model | F1 | 46 | |---------------------------------|-------| 47 | | LSTM-CRF (Lample et al., 2016) | 91.0 | 48 | | ELMo (Peters et al., 2018) | 92.2 | 49 | | BERT (Devlin et al., 2018) | 92.8 | 50 | | (Akbik et al., 2019) | 93.1 | 51 | | (Baevski et al., 2019) | 93.5 | 52 | | LUKE (Yamada et al., 2020) | 94.3 | 53 | | LUKE w/o entity attention | 94.1 | 54 | | RoBERTa (Yamada et al., 2020) | 92.4 | 55 | | RoBERTa-extended* | 92.54 | 56 | | KBI* | 92.73 | 57 | | KBIR* | 92.97 | 58 | 59 | ### Question Answering 60 | ``` 61 | from transformers import AutoTokenizer, AutoModelForQuestionAnswering 62 | 63 | tokenizer = AutoTokenizer.from_pretrained("bloomberg/KBIR") 64 | model = AutoModelForQuestionAnswering.from_pretrained("bloomberg/KBIR") 65 | 66 | from datasets import load_dataset 67 | 68 | dataset = load_dataset("squad") 69 | ``` 70 | Reported Results: 71 | 72 | | Model | EM | F1 | 73 | |------------------------|-------|-------| 74 | | BERT | 84.2 | 91.1 | 75 | | XLNet | 89.0 | 94.5 | 76 | | ALBERT | 89.3 | 94.8 | 77 | | LUKE | 89.8 | 95.0 | 78 | | LUKE w/o entity attention | 89.2 | 94.7 | 79 | | RoBERTa | 88.9 | 94.6 | 80 | | RoBERTa-extended* | 88.88 | 94.55 | 81 | | KBI* | 88.97 | 94.7 | 82 | | KBIR* | 89.04 | 94.75 | 83 | 84 | ## Any other classification task 85 | As mentioned above since KBIR is built on top of the RoBERTa architecture, it is compatible with any AutoModel setting that RoBERTa is also compatible with. 86 | 87 | We encourage you to try fine-tuning KBIR on different datasets and report the downstream results. 88 | 89 | ## Contact 90 | For any questions contact mkulkarni24@bloomberg.net 91 | -------------------------------------------------------------------------------- /model_cards/KeyBART.md: -------------------------------------------------------------------------------- 1 | # KeyBART 2 | KeyBART as described in Learning Rich Representations of Keyphrase from Text (https://arxiv.org/pdf/2112.08547.pdf), pre-trains a BART-based architecture to produce a concatenated sequence of keyphrases in the CatSeqD format. 3 | 4 | We provide some examples on Downstream Evaluations setups and and also how it can be used for Text-to-Text Generation in a zero-shot setting. 
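For instance, zero-shot keyphrase generation can be run with the standard `generate()` API. This is a minimal sketch: the decoding parameters below (beam size, maximum length) are illustrative assumptions, not the settings used in the paper.

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("bloomberg/KeyBART")
model = AutoModelForSeq2SeqLM.from_pretrained("bloomberg/KeyBART")

text = (
    "In this work, we explore how to learn task specific language models "
    "aimed towards learning rich representation of keyphrases from text documents."
)
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
generated = model.generate(**inputs, max_length=64, num_beams=4)

# KeyBART emits keyphrases as a single ';'-separated sequence (CatSeq-style output).
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```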
5 | 6 | ## Downstream Evaluation 7 | 8 | ### Keyphrase Generation 9 | ``` 10 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 11 | 12 | tokenizer = AutoTokenizer.from_pretrained("bloomberg/KeyBART") 13 | model = AutoModelForSeq2SeqLM.from_pretrained("bloomberg/KeyBART") 14 | 15 | from datasets import load_dataset 16 | 17 | dataset = load_dataset("midas/kp20k") 18 | ``` 19 | 20 | Reported Results: 21 | 22 | #### Present Keyphrase Generation 23 | | | Inspec | | NUS | | Krapivin | | SemEval | | KP20k | | 24 | |---------------|--------|-------|-------|-------|----------|-------|---------|-------|-------|-------| 25 | | Model | F1@5 | F1@M | F1@5 | F1@M | F1@5 | F1@M | F1@5 | F1@M | F1@5 | F1@M | 26 | | catSeq | 22.5 | 26.2 | 32.3 | 39.7 | 26.9 | 35.4 | 24.2 | 28.3 | 29.1 | 36.7 | 27 | | catSeqTG | 22.9 | 27 | 32.5 | 39.3 | 28.2 | 36.6 | 24.6 | 29.0 | 29.2 | 36.6 | 28 | | catSeqTG-2RF1 | 25.3 | 30.1 | 37.5 | 43.3 | 30 | 36.9 | 28.7 | 32.9 | 32.1 | 38.6 | 29 | | GANMR | 25.8 | 29.9 | 34.8 | 41.7 | 28.8 | 36.9 | N/A | N/A | 30.3 | 37.8 | 30 | | ExHiRD-h | 25.3 | 29.1 | N/A | N/A | 28.6 | 34.7 | 28.4 | 33.5 | 31.1 | 37.4 | 31 | | Transformer (Ye et al., 2021) | 28.15 | 32.56 | 37.07 | 41.91 | 31.58 | 36.55 | 28.71 | 32.52 | 33.21 | 37.71 | 32 | | BART* | 23.59 | 28.46 | 35.00 | 42.65 | 26.91 | 35.37 | 26.72 | 31.91 | 29.25 | 37.51 | 33 | | KeyBART-DOC* | 24.42 | 29.57 | 31.37 | 39.24 | 24.21 | 32.60 | 24.69 | 30.50 | 28.82 | 37.59 | 34 | | KeyBART* | 24.49 | 29.69 | 34.77 | 43.57 | 29.24 | 38.62 | 27.47 | 33.54 | 30.71 | 39.76 | 35 | | KeyBART* (Zero-shot) | 30.72 | 36.89 | 18.86 | 21.67 | 18.35 | 20.46 | 20.25 | 25.82 | 12.57 | 15.41 | 36 | 37 | #### Absent Keyphrase Generation 38 | | | Inspec | | NUS | | Krapivin | | SemEval | | KP20k | | 39 | |---------------|--------|------|------|------|----------|------|---------|------|-------|------| 40 | | Model | F1@5 | F1@M | F1@5 | F1@M | F1@5 | F1@M | F1@5 | F1@M | F1@5 | F1@M | 41 | | catSeq | 0.4 | 0.8 | 1.6 | 2.8 | 1.8 | 3.6 | 1.6 | 2.8 | 1.5 | 3.2 | 42 | | catSeqTG | 0.5 | 1.1 | 1.1 | 1.8 | 1.8 | 3.4 | 1.1 | 1.8 | 1.5 | 3.2 | 43 | | catSeqTG-2RF1 | 1.2 | 2.1 | 1.9 | 3.1 | 3.0 | 5.3 | 2.1 | 3.0 | 2.7 | 5.0 | 44 | | GANMR | 1.3 | 1.9 | 2.6 | 3.8 | 4.2 | 5.7 | N/A | N/A | 3.2 | 4.5 | 45 | | ExHiRD-h | 1.1 | 2.2 | N/A | N/A | 2.2 | 4.3 | 1.7 | 2.5 | 1.6 | 3.2 | 46 | | Transformer (Ye et al., 2021) | 1.02 | 1.94 | 2.82 | 4.82 | 3.21 | 6.04 | 2.05 | 2.33 | 2.31 | 4.61 | 47 | | BART* | 1.08 | 1.96 | 1.80 | 2.75 | 2.59 | 4.91 | 1.34 | 1.75 | 1.77 | 3.56 | 48 | | KeyBART-DOC* | 0.99 | 2.03 | 1.39 | 2.74 | 2.40 | 4.58 | 1.07 | 1.39 | 1.69 | 3.38 | 49 | | KeyBART* | 0.95 | 1.81 | 1.23 | 1.90 | 3.09 | 6.08 | 1.96 | 2.65 | 2.03 | 4.26 | 50 | | KeyBART* (Zero-shot) | 1.83 | 2.92 | 1.46 | 2.19 | 1.29 | 2.09 | 1.12 | 1.45 | 0.70 | 1.14 | 51 | 52 | 53 | ### Abstractive Summarization 54 | ``` 55 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 56 | 57 | tokenizer = AutoTokenizer.from_pretrained("bloomberg/KeyBART") 58 | model = AutoModelForSeq2SeqLM.from_pretrained("bloomberg/KeyBART") 59 | 60 | from datasets import load_dataset 61 | 62 | dataset = load_dataset("cnn_dailymail") 63 | ``` 64 | 65 | Reported Results: 66 | 67 | | Model | R1 | R2 | RL | 68 | |--------------|-------|-------|-------| 69 | | BART (Lewis et al., 2019) | 44.16 | 21.28 | 40.9 | 70 | | BART* | 42.93 | 20.12 | 39.72 | 71 | | KeyBART-DOC* | 42.92 | 20.07 | 39.69 | 72 | | KeyBART* | 43.10 | 20.26 | 39.90 | 73 | 74 | ## Zero-shot settings 75 | ``` 76 | from 
transformers import AutoTokenizer, AutoModelForSeq2SeqLM 77 | 78 | tokenizer = AutoTokenizer.from_pretrained("bloomberg/KeyBART") 79 | model = AutoModelForSeq2SeqLM.from_pretrained("bloomberg/KeyBART") 80 | ``` 81 | 82 | Alternatively use the Hosted Inference API console provided in https://huggingface.co/bloomberg/KeyBART 83 | 84 | Sample Zero Shot result: 85 | 86 | ``` 87 | Input: In this work, we explore how to learn task specific language models aimed towards learning rich representation of keyphrases from text documents. 88 | We experiment with different masking strategies for pre-training transformer language models (LMs) in discriminative as well as generative settings. 89 | In the discriminative setting, we introduce a new pre-training objective - Keyphrase Boundary Infilling with Replacement (KBIR), 90 | showing large gains in performance (upto 9.26 points in F1) over SOTA, when LM pre-trained using KBIR is fine-tuned for the task of keyphrase extraction. 91 | In the generative setting, we introduce a new pre-training setup for BART - KeyBART, that reproduces the keyphrases related to the input text in the CatSeq 92 | format, instead of the denoised original input. This also led to gains in performance (upto 4.33 points in F1@M) over SOTA for keyphrase generation. 93 | Additionally, we also fine-tune the pre-trained language models on named entity recognition (NER), question answering (QA), relation extraction (RE), 94 | abstractive summarization and achieve comparable performance with that of the SOTA, showing that learning rich representation of keyphrases is indeed beneficial 95 | for many other fundamental NLP tasks. 96 | 97 | Output: language model;keyphrase generation;new pre-training objective;pre-training setup; 98 | 99 | ``` 100 | -------------------------------------------------------------------------------- /pretrain_runner.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | import sys 5 | import torch 6 | import logging 7 | import time 8 | import traceback 9 | from datasets import load_dataset 10 | from glob import glob 11 | import transformers 12 | from transformers import RobertaForMaskedLM 13 | from transformers import RobertaConfig 14 | from transformers import RobertaTokenizer, RobertaTokenizerFast 15 | from transformers import Trainer, TrainingArguments 16 | from transformers import set_seed 17 | from transformers import PrinterCallback 18 | from transformers import DataCollatorForWholeWordMask 19 | 20 | from transformers import BartTokenizer, BartTokenizerFast 21 | from transformers import BartForConditionalGeneration 22 | 23 | from utils.data_collators import DataCollatorForKLM 24 | from model.model import KLMForReplacementAndMaskedLM, EncoderDecoderModel 25 | from trainer.trainer import TrainerWithEvalCollator 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | def is_main_process(local_rank): 31 | """ 32 | Whether or not the current process is the local process, based on `local_rank`. 
33 | """ 34 | return local_rank in [-1, 0] 35 | 36 | 37 | def main(args): 38 | 39 | if args.do_eval: 40 | assert args.eval_data_dir is not None 41 | 42 | training_args = TrainingArguments( 43 | no_cuda=args.no_cuda, 44 | output_dir=args.model_dir, 45 | overwrite_output_dir=True, 46 | num_train_epochs=args.epochs, 47 | per_device_train_batch_size=args.train_batch_size, 48 | per_device_eval_batch_size=args.eval_batch_size, 49 | save_steps=args.save_steps, 50 | learning_rate=args.learning_rate, 51 | local_rank=args.local_rank, 52 | adam_epsilon=args.adam_epsilon, 53 | warmup_steps=args.warmup_steps, 54 | max_steps=args.max_steps, 55 | do_train=args.do_train, 56 | logging_steps=args.logging_steps, 57 | disable_tqdm=True, 58 | evaluation_strategy="steps", 59 | eval_steps=args.eval_steps, 60 | remove_unused_columns=False, 61 | gradient_accumulation_steps=args.gradient_accumulation_steps, 62 | fp16=args.fp16, 63 | fp16_opt_level=args.fp16_opt_level, 64 | ) 65 | 66 | # Setup logging 67 | logging.basicConfig( 68 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 69 | datefmt="%m/%d/%Y %H:%M:%S", 70 | level=logging.INFO 71 | if is_main_process(training_args.local_rank) 72 | else logging.WARN, 73 | ) 74 | 75 | # Log on each process the small summary: 76 | logger.warning( 77 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 78 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 79 | + f"16-bits optimization level {training_args.fp16_opt_level}" 80 | ) 81 | logger.warning(f"Process rank from env: {os.getenv('RANK')}") 82 | logger.info("There are %d GPU(s) available.", torch.cuda.device_count()) 83 | # Set the verbosity to info of the Transformers logger (on main process only): 84 | if is_main_process(training_args.local_rank): 85 | transformers.utils.logging.set_verbosity_info() 86 | transformers.utils.logging.enable_default_handler() 87 | transformers.utils.logging.enable_explicit_format() 88 | logger.info("Training/evaluation parameters %s", training_args) 89 | 90 | # Set seed before initializing model. 91 | set_seed(training_args.seed) 92 | 93 | # Choose this when trying to train from scratch 94 | # config 95 | # config = RobertaConfig.from_pretrained(args.roberta_mlm_model_dir) 96 | # model roberta 97 | # model = RobertaForMaskedLM(config=config) 98 | 99 | # Choose this when starting from an already pretrained model 100 | if ( 101 | args.do_keyphrase_replacement or args.do_keyphrase_infilling 102 | ) and not args.do_generation: 103 | logger.info("Loading KLMForReplacementAndMaskedLM") 104 | model = KLMForReplacementAndMaskedLM.from_pretrained( 105 | args.roberta_mlm_model_dir, 106 | use_doc_emb=args.use_doc_emb, 107 | kp_max_seq_len=args.kp_max_seq_len, 108 | mlm_loss_weight=args.mlm_loss_weight, 109 | replacement_loss_weight=args.replacement_loss_weight, 110 | keyphrase_infill_loss_weight=args.keyphrase_infill_loss_weight, 111 | infill_num_tok_loss_weight=args.infill_num_tok_loss_weight, 112 | ) 113 | elif args.do_generation: 114 | if args.use_bart: 115 | logger.info("Loading BART") 116 | model = BartForConditionalGeneration.from_pretrained(args.bart_model_dir) 117 | else: 118 | raise NotImplementedError( 119 | "Generation models outside of BART are currently not supported!" 
120 | ) 121 | else: 122 | logger.info("Loading RobertaForMaskedLM") 123 | model = RobertaForMaskedLM.from_pretrained(args.roberta_mlm_model_dir) 124 | 125 | logger.info("Loaded pre-trained model") 126 | 127 | freeze_layers = list(range(args.num_frozen_layers)) 128 | if freeze_layers: 129 | for name, param in model.roberta.encoder.layer.named_children(): 130 | if int(name) in freeze_layers: 131 | logger.info("Freezing Layer: ", name) 132 | param.requires_grad = False 133 | 134 | if args.do_generation and args.use_bart: 135 | tokenizer = BartTokenizer.from_pretrained( 136 | args.bart_model_dir, use_fast=True, max_length=512 137 | ) 138 | # this tokenizer needs to be downloaded and we need to point to the path 139 | else: 140 | tokenizer = RobertaTokenizer.from_pretrained( 141 | args.roberta_tokenizer_dir, use_fast=True, max_length=512 142 | ) 143 | 144 | keyphrase_universe = set() 145 | keyphrase_universe_ids = None 146 | if args.do_keyphrase_replacement: 147 | logger.info("Loading Keyphrase Universe") 148 | with open(args.keyphrase_universe) as f: 149 | for line in f: 150 | keyphrase_universe.add(line) 151 | # Restrict size of keyphrase universe for computational reasons 152 | if ( 153 | args.keyphrase_universe_size != -1 154 | and len(keyphrase_universe) == args.keyphrase_universe_size 155 | ): 156 | break 157 | keyphrase_universe = list(keyphrase_universe) 158 | keyphrase_universe_ids = tokenizer( 159 | keyphrase_universe, truncation=True, add_special_tokens=False, 160 | )["input_ids"] 161 | assert len(keyphrase_universe) == len(keyphrase_universe_ids) 162 | 163 | def parse_keyphrases(text, keywords): 164 | keyphrases = [] 165 | for keyphrase in keywords.split(" , "): 166 | if not keyphrase.strip(): 167 | continue 168 | keyphrase_index = text.find(keyphrase) 169 | if keyphrase_index == -1: 170 | keyphrase_index = text.lower().find(keyphrase) 171 | # Can't find keyphrase in text 172 | if keyphrase_index == -1: 173 | continue 174 | keyphrase = text[keyphrase_index : keyphrase_index + len(keyphrase)] 175 | # Decide whether a space is required before for the tokenizer to have a consistent behavior 176 | if keyphrase_index > 0: 177 | if text[keyphrase_index - 1] == " ": 178 | keyphrase = " " + keyphrase 179 | keyphrases.append(keyphrase) 180 | 181 | return keyphrases 182 | 183 | def get_catseq_keyphrases(keywords): 184 | keyphrases = ";".join(list(set(keywords.split(" , ")))) 185 | 186 | return keyphrases 187 | 188 | def tokenize_klm_function(examples): 189 | try: 190 | text_data = [ 191 | title + ". 
" + abstract 192 | for title, abstract in zip(examples["title"], examples["abstract"]) 193 | ] 194 | examples_batch_encoding = tokenizer( 195 | text_data, 196 | truncation=True, 197 | return_special_tokens_mask=True, 198 | max_length=512, 199 | ) 200 | keyphrases_input_ids = [] 201 | catseq_keyphrases_input_ids = [] 202 | for text, keyphrase_list in zip(text_data, examples["keywords"]): 203 | keyphrases = parse_keyphrases(text, keyphrase_list) 204 | keyphrase_input_ids = tokenizer( 205 | keyphrases, truncation=True, add_special_tokens=False, 206 | )["input_ids"] 207 | keyphrases_input_ids.append(keyphrase_input_ids) 208 | catseq_keyphrases = get_catseq_keyphrases(keyphrase_list) 209 | catseq_keyphrase_input_ids = tokenizer( 210 | catseq_keyphrases, truncation=True, add_special_tokens=False, 211 | )["input_ids"] 212 | catseq_keyphrases_input_ids.append(catseq_keyphrase_input_ids) 213 | examples_batch_encoding["keyphrases_input_ids"] = keyphrases_input_ids 214 | examples_batch_encoding[ 215 | "catseq_keyphrase_input_ids" 216 | ] = catseq_keyphrases_input_ids 217 | return examples_batch_encoding 218 | except Exception as e: 219 | logger.info("Skipping batch due to errors") 220 | logger.info(e) 221 | 222 | def tokenize_mlm_function(examples): 223 | text = [ 224 | title + ". " + abstract 225 | for title, abstract in zip(examples["title"], examples["abstract"]) 226 | ] 227 | return tokenizer(text, truncation=True, max_length=512,) 228 | 229 | def filter_empty_keyphrases(example): 230 | text = example["title"] + ". " + example["abstract"] 231 | try: 232 | if ( 233 | not example["keywords"] 234 | or not example["title"] 235 | or not example["abstract"] 236 | or not parse_keyphrases(text, example["keywords"]) 237 | ): 238 | return False 239 | return True 240 | except: 241 | return False 242 | 243 | if training_args.do_eval: 244 | logger.info("Initializing Eval Dataset") 245 | eval_dataset = load_dataset( 246 | "json", data_files=glob(args.eval_data_dir + "/*.txt") 247 | ) 248 | 249 | logger.info("Initializing Eval DataCollator") 250 | if args.eval_task == "KLM": 251 | logger.info("Filter Eval Dataset") 252 | eval_dataset = eval_dataset.filter(filter_empty_keyphrases,) 253 | 254 | logger.info("Tokenize Eval Dataset") 255 | tokenized_eval_dataset = eval_dataset.map( 256 | tokenize_klm_function, 257 | batched=True, 258 | remove_columns=["title", "abstract", "keywords"], 259 | writer_batch_size=3_000, 260 | load_from_cache_file=False, 261 | ) 262 | logger.info("Setting up Eval DataCollator") 263 | eval_data_collator = DataCollatorForKLM( 264 | tokenizer=tokenizer, 265 | keyphrase_universe_ids=keyphrase_universe_ids, 266 | mlm_probability=args.mlm_probability, 267 | kp_mask_percentage=args.keyphrase_mask_percentage, 268 | kp_replace_percentage=args.keyphrase_replace_percentage, 269 | max_keyphrase_pairs=args.max_keyphrase_pairs, 270 | max_seq_len=args.max_seq_len, 271 | do_generation=args.do_generation, 272 | use_bart=args.use_bart, 273 | do_keyphrase_generation=args.do_keyphrase_generation, 274 | do_keyphrase_infilling=args.do_keyphrase_infilling, 275 | kp_max_seq_len=args.kp_max_seq_len, 276 | max_mask_keyphrase_pairs=args.max_mask_keyphrase_pairs, 277 | ) 278 | 279 | elif args.eval_task == "MLM": 280 | logger.info("Tokenize Eval Dataset") 281 | tokenized_eval_dataset = eval_dataset.map( 282 | tokenize_mlm_function, 283 | batched=True, 284 | remove_columns=["title", "abstract"], 285 | writer_batch_size=3_000, 286 | load_from_cache_file=False, 287 | ) 288 | logger.info("Setting up Eval DataCollator") 
289 | eval_data_collator = DataCollatorForWholeWordMask(tokenizer=tokenizer) 290 | 291 | logger.info(f"Task = {args.task}") 292 | if training_args.do_train: 293 | logger.info("Initializing Train Dataset") 294 | train_dataset = load_dataset( 295 | "json", data_files=glob(args.train_data_dir + "/*.txt") 296 | ) 297 | 298 | logger.info("Initializing Train DataCollator") 299 | if args.task == "KLM": 300 | logger.info("Filter Train Dataset") 301 | train_dataset = train_dataset.filter(filter_empty_keyphrases,) 302 | logger.info(train_dataset) 303 | logger.info("Tokenize Train Dataset") 304 | tokenized_train_dataset = train_dataset.map( 305 | tokenize_klm_function, 306 | batched=True, 307 | remove_columns=["title", "abstract", "keywords"], 308 | writer_batch_size=3_000, 309 | load_from_cache_file=False, 310 | ) 311 | logger.info(tokenized_train_dataset) 312 | logger.info("Setting up Train DataCollator") 313 | train_data_collator = DataCollatorForKLM( 314 | tokenizer=tokenizer, 315 | keyphrase_universe_ids=keyphrase_universe_ids, 316 | mlm_probability=args.mlm_probability, 317 | kp_mask_percentage=args.keyphrase_mask_percentage, 318 | kp_replace_percentage=args.keyphrase_replace_percentage, 319 | max_keyphrase_pairs=args.max_keyphrase_pairs, 320 | max_seq_len=args.max_seq_len, 321 | do_generation=args.do_generation, 322 | use_bart=args.use_bart, 323 | do_keyphrase_generation=args.do_keyphrase_generation, 324 | do_keyphrase_infilling=args.do_keyphrase_infilling, 325 | kp_max_seq_len=args.kp_max_seq_len, 326 | max_mask_keyphrase_pairs=args.max_mask_keyphrase_pairs, 327 | ) 328 | 329 | elif args.task == "MLM": 330 | logger.info("Tokenize Train Dataset") 331 | tokenized_train_dataset = train_dataset.map( 332 | tokenize_mlm_function, 333 | batched=True, 334 | remove_columns=["title", "abstract"], 335 | writer_batch_size=3_000, 336 | load_from_cache_file=False, 337 | ) 338 | logger.info("Setting up Train DataCollator") 339 | train_data_collator = DataCollatorForWholeWordMask(tokenizer=tokenizer) 340 | 341 | logger.info("Initializing Trainer") 342 | if args.do_eval: 343 | trainer = TrainerWithEvalCollator( 344 | model=model, 345 | args=training_args, 346 | data_collator=train_data_collator if training_args.do_train else None, 347 | eval_data_collator=eval_data_collator if training_args.do_eval else None, 348 | train_dataset=tokenized_train_dataset if training_args.do_train else None, 349 | eval_dataset=tokenized_eval_dataset if training_args.do_eval else None, 350 | tokenizer=tokenizer, 351 | ) 352 | else: 353 | trainer = Trainer( 354 | model=model, 355 | args=training_args, 356 | data_collator=train_data_collator if training_args.do_train else None, 357 | train_dataset=tokenized_train_dataset if training_args.do_train else None, 358 | tokenizer=tokenizer, 359 | ) 360 | trainer.add_callback(PrinterCallbackWithFlush) 361 | 362 | if training_args.do_train: 363 | logger.info("Training Model") 364 | checkpoint = None 365 | if args.is_checkpoint: 366 | logger.info("Loading from checkpoint") 367 | checkpoint = ( 368 | args.roberta_mlm_model_dir 369 | if args.roberta_mlm_model_dir 370 | else args.bart_model_dir 371 | ) 372 | trainer.train(resume_from_checkpoint=checkpoint) 373 | 374 | results = {} 375 | if training_args.do_eval: 376 | logger.info("*** Evaluate ***") 377 | 378 | eval_output = trainer.evaluate() 379 | 380 | perplexity = math.exp(eval_output["eval_loss"]) 381 | results["perplexity"] = perplexity 382 | 383 | output_eval_file = os.path.join( 384 | training_args.output_dir, "eval_results_mlm_wwm.txt" 
385 | ) 386 | if trainer.is_world_process_zero(): 387 | with open(output_eval_file, "w") as writer: 388 | logger.info("***** Eval results *****") 389 | for key, value in sorted(results.items()): 390 | logger.info(f" {key} = {value}") 391 | writer.write(f"{key} = {value}\n") 392 | 393 | 394 | class PrinterCallbackWithFlush(PrinterCallback): 395 | def __init__(self): 396 | self.prev_steps = None 397 | self.prev_time = None 398 | 399 | def on_log(self, args, state, control, logs=None, **kwargs): 400 | curr_steps = state.global_step 401 | curr_time = time.time() 402 | if self.prev_steps: 403 | print( 404 | f"Steps since last log: {curr_steps - self.prev_steps}, " 405 | f"Global steps: {curr_steps}, " 406 | f"Max steps: {state.max_steps}, " 407 | f"Time since last log: {curr_time - self.prev_time}", 408 | flush=True, 409 | ) 410 | else: 411 | print("Starting steps and time history", flush=True) 412 | self.prev_steps = curr_steps 413 | self.prev_time = curr_time 414 | 415 | def on_evaluate(self, args, state, control, metrics=None, **kwargs): 416 | perplexity = math.exp(metrics["eval_loss"]) 417 | print({"perplexity": perplexity}) 418 | 419 | 420 | def world_size(): 421 | """Returns the total number of processes in a distributed job (num_nodes x gpus_per_node). 422 | Returns 1 in a non-distributed job. 423 | """ 424 | return int(os.environ.get("WORLD_SIZE", "1")) 425 | 426 | 427 | def is_distributed(): 428 | """Returns True iff this is a distributed job (more than one process).""" 429 | return world_size() > 1 430 | 431 | 432 | if __name__ == "__main__": 433 | parser = argparse.ArgumentParser() 434 | 435 | parser.add_argument( 436 | "--train-data-dir", 437 | type=str, 438 | required=True, 439 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.", 440 | ) 441 | parser.add_argument( 442 | "--eval-data-dir", 443 | type=str, 444 | required=False, 445 | help="The eval data dir. 
Required if --do-eval is set.", 446 | ) 447 | parser.add_argument( 448 | "--keyphrase-universe", 449 | type=str, 450 | required=False, 451 | help="File containing all the keyphrases in across the train and dev set used for keyphrase replacement.", 452 | ) 453 | parser.add_argument( 454 | "--model-dir", 455 | default="./", 456 | type=str, 457 | required=True, 458 | help="The output directory where the model predictions and checkpoints will be written.", 459 | ) 460 | parser.add_argument( 461 | "--roberta-tokenizer-dir", 462 | type=str, 463 | required=False, 464 | default=None, 465 | help="The directory from where the RoBERTa tokenizer is loaded.", 466 | ) 467 | parser.add_argument( 468 | "--roberta-mlm-model-dir", 469 | type=str, 470 | required=False, 471 | default=None, 472 | help="The directory from where the pre-trained RoBERTa model is loaded.", 473 | ) 474 | parser.add_argument( 475 | "--bart-model-dir", 476 | type=str, 477 | required=False, 478 | default=None, 479 | help="The directory from where the pre-trained RoBERTa model is loaded.", 480 | ) 481 | parser.add_argument( 482 | "--epochs", default=3.0, type=float, help="Number of epochs to train the model" 483 | ) 484 | parser.add_argument( 485 | "--learning-rate", 486 | default=5e-5, 487 | type=float, 488 | help="Learning rate to use for training the model", 489 | ) 490 | parser.add_argument( 491 | "--num-frozen-layers", 492 | default=0, 493 | type=int, 494 | help="Number of RoBERTa encoder layers to freeze during training", 495 | ) 496 | parser.add_argument( 497 | "--adam-epsilon", 498 | default=1e-8, 499 | type=float, 500 | help="The epsilon hyperparameter for the Adam optimizer", 501 | ) 502 | parser.add_argument( 503 | "--train-batch-size", 504 | default=64, 505 | type=int, 506 | help="Training dataset batch size per GPU", 507 | ) 508 | parser.add_argument( 509 | "--eval-batch-size", 510 | default=128, 511 | type=int, 512 | help="Eval dataset batch size per GPU", 513 | ) 514 | parser.add_argument( 515 | "--warmup-steps", 516 | default=0, 517 | type=int, 518 | help="Number of steps used for a linear warmup from 0 to learning_rate.", 519 | ) 520 | parser.add_argument( 521 | "--max-steps", 522 | default=-1, 523 | type=int, 524 | help="If set to a positive number, the total number of training steps to perform. Overrides num_train_epochs.", 525 | ) 526 | parser.add_argument( 527 | "--save-steps", 528 | default=500, 529 | type=int, 530 | help="Number of updates steps before two checkpoint saves.", 531 | ) 532 | parser.add_argument( 533 | "--logging-steps", 534 | default=500, 535 | type=int, 536 | help="Number of updates steps before logs", 537 | ) 538 | parser.add_argument( 539 | "--no-cuda", action="store_true", help="Whether not to use CUDA when available" 540 | ) 541 | parser.add_argument( 542 | "--do-train", action="store_true", help="Whether to run training or not." 543 | ) 544 | parser.add_argument( 545 | "--do-eval", action="store_true", help="Whether to run evaluation or not." 
546 | ) 547 | parser.add_argument( 548 | "--is-checkpoint", 549 | action="store_true", 550 | help="Whether to treat the model path as a checkpoint to resume training from.", 551 | ) 552 | parser.add_argument( 553 | "--eval-steps", 554 | type=int, 555 | default=500, 556 | help="Run evaluation every this many steps", 557 | ) 558 | parser.add_argument( 559 | "--dataloader-num-workers", type=int, default=0, help="Number of worker processes for the dataloader" 560 | ) 561 | 562 | parser.add_argument( 563 | "--local-rank", 564 | type=int, 565 | default=-1, 566 | help="Local rank when doing distributed training, set to -1 if running non-distributed", 567 | ) 568 | parser.add_argument( 569 | "--rank", 570 | type=int, 571 | default=int(os.getenv("RANK", 0)), 572 | help="Rank when doing distributed training, doesn't matter for non-distributed", 573 | ) 574 | parser.add_argument( 575 | "--world-size", 576 | type=int, 577 | default=world_size(), 578 | help="World size when doing distributed training, set to -1 if running non-distributed", 579 | ) 580 | parser.add_argument( 581 | "--task", 582 | choices=["KLM", "MLM"], 583 | default="KLM", 584 | help="KLM training or whole-word-masking (MLM) training", 585 | ) 586 | parser.add_argument( 587 | "--eval-task", 588 | choices=["KLM", "MLM"], 589 | default="MLM", 590 | help="What masking to use for eval: KLM/MLM", 591 | ) 592 | parser.add_argument( 593 | "--gradient-accumulation-steps", type=int, default=1, help="Number of update steps to accumulate gradients over before performing an optimizer update", 594 | ) 595 | 596 | parser.add_argument( 597 | "--do-keyphrase-replacement", 598 | action="store_true", 599 | help="Whether to enable keyphrase replacement during KLM.", 600 | ) 601 | parser.add_argument( 602 | "--do-keyphrase-generation", 603 | action="store_true", 604 | help="Whether to have the generation labels correspond to the CatSeq representation of keyphrases.", 605 | ) 606 | parser.add_argument( 607 | "--do-keyphrase-infilling", 608 | action="store_true", 609 | help="Whether to use the text in-filling pre-training setup for BART.", 610 | ) 611 | parser.add_argument( 612 | "--keyphrase-universe-size", 613 | type=int, 614 | default=-1, 615 | help="Size of the keyphrase universe used for keyphrase replacement during KLM.", 616 | ) 617 | parser.add_argument( 618 | "--max-seq-len", 619 | type=int, 620 | default=512, 621 | help="Max sequence length considered for an input", 622 | ) 623 | parser.add_argument( 624 | "--kp-max-seq-len", 625 | type=int, 626 | default=10, 627 | help="Max sequence length considered for a keyphrase during infilling", 628 | ) 629 | parser.add_argument( 630 | "--mlm-probability", 631 | type=float, 632 | default=0.15, 633 | help="Probability of masking a token in the input during pre-training", 634 | ) 635 | parser.add_argument( 636 | "--keyphrase-mask-percentage", 637 | type=float, 638 | default=0.4, 639 | help="If training on the KLM objective, the percentage of keyphrase tokens to mask as a fraction of the total input tokens. 
When used with infilling, this is the percentage of keyphrases to mask.", 640 | ) 641 | parser.add_argument( 642 | "--keyphrase-replace-percentage", 643 | type=float, 644 | default=0.1, 645 | help="If training on the KLM objective and replacing keyphrases, the percentage of keyphrases to replace as a fraction of the total keyphrases", 646 | ) 647 | parser.add_argument( 648 | "--max-keyphrase-pairs", 649 | type=int, 650 | default=20, 651 | help="If training on the KLM objective and replacing keyphrases, the max number of keyphrases to consider in the replacement task", 652 | ) 653 | parser.add_argument( 654 | "--max-mask-keyphrase-pairs", 655 | type=int, 656 | default=10, 657 | help="If training on the KLM objective and infilling keyphrases, the max number of keyphrases to consider for masking", 658 | ) 659 | parser.add_argument( 660 | "--use-doc-emb", 661 | action="store_true", 662 | help="Whether to use the [CLS] document embedding during keyphrase replacement.", 663 | ) 664 | parser.add_argument( 665 | "--do-generation", 666 | action="store_true", 667 | help="Whether to set up pre-training as a denoising autoencoder with a decoder head.", 668 | ) 669 | parser.add_argument( 670 | "--use-bart", 671 | action="store_true", 672 | help="Whether to use BART when --do-generation is set; requires --bart-model-dir to be set", 673 | ) 674 | parser.add_argument( 675 | "--mlm-loss-weight", 676 | type=float, 677 | default=1.0, 678 | help="Coefficient for the masked language modelling loss in the overall loss", 679 | ) 680 | parser.add_argument( 681 | "--replacement-loss-weight", 682 | type=float, 683 | default=1.0, 684 | help="Coefficient for the keyphrase replacement classification loss in the overall loss", 685 | ) 686 | parser.add_argument( 687 | "--keyphrase-infill-loss-weight", 688 | type=float, 689 | default=1.0, 690 | help="Coefficient for the keyphrase infilling loss in the overall loss", 691 | ) 692 | parser.add_argument( 693 | "--infill-num-tok-loss-weight", 694 | type=float, 695 | default=1.0, 696 | help="Coefficient for the keyphrase number-of-tokens classification loss in the overall loss", 697 | ) 698 | parser.add_argument("--fp16", action="store_true", default=False) 699 | parser.add_argument("--fp16_opt_level", type=str, default="O2") 700 | 701 | args = parser.parse_args() 702 | 703 | if is_distributed(): 704 | print("Process is distributed") 705 | # args.local_rank = int(os.getenv('LOCAL_RANK')) 706 | # To avoid deadlocks on the tokenizer 707 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 708 | 709 | print(f"args.local_rank = {args.local_rank}") 710 | print(f"args.rank = {args.rank}") 711 | print(f"args.world_size = {args.world_size}") 712 | print(f"os.getenv('RANK') = {os.getenv('RANK')}") 713 | print(f"os.getenv('LOCAL_RANK') = {os.getenv('LOCAL_RANK')}") 714 | print(f"os.getenv('WORLD_SIZE') = {os.getenv('WORLD_SIZE')}") 715 | print(args) 716 | 717 | for _ in range(5): 718 | try: 719 | main(args) 720 | break 721 | except Exception as e: 722 | print(e) 723 | traceback.print_exc() 724 | print("Trying to reconnect to master") 725 | time.sleep(1) 726 | else: 727 | raise RuntimeError("Could not successfully finish job") 728 | -------------------------------------------------------------------------------- /run_pretrain_kp_infill_replacement_bart_kg_oagkx.sh: -------------------------------------------------------------------------------- 1 | BASE_RESOURCES= 2 | 3 | run pretrain_runner.py \ 4 | --bart-model-dir ${BASE_RESOURCES}/resources/bart_large \ 5 | --train-data-dir ${BASE_RESOURCES}/resources/oagkx \ 6 | --eval-data-dir 
${BASE_RESOURCES}/resources/oagkx_eval \ 7 | --keyphrase-universe ${BASE_RESOURCES}/resources/oagkx_keyphrase_universe/keyphrase_universe.txt \ 8 | --keyphrase-universe-size 500000 \ 9 | --train-batch-size 4 \ 10 | --eval-batch-size 2 \ 11 | --learning-rate 1e-5 \ 12 | --adam-epsilon 1e-6 \ 13 | --max-steps 130000 \ 14 | --save-steps 10000 \ 15 | --eval-steps 140000 \ 16 | --logging-steps 10000 \ 17 | --warmup-steps 2500 \ 18 | --mlm-probability 0.05 \ 19 | --keyphrase-mask-percentage 0.2 \ 20 | --keyphrase-replace-percentage 0.4 \ 21 | --do-train \ 22 | --do-eval \ 23 | --do-generation \ 24 | --use-bart \ 25 | --do-keyphrase-infilling \ 26 | --do-keyphrase-replacement \ 27 | --do-keyphrase-generation \ 28 | --task KLM \ 29 | --eval-task KLM \ 30 | --max-mask-keyphrase-pairs 10 \ 31 | --max-keyphrase-pairs 20 \ 32 | --kp-max-seq-len 10 \ 33 | --model-dir ${BASE_RESOURCES}/bart-kg-infill-replacement-models/ \ 34 | --dataloader-num-workers 5 \ 35 | -------------------------------------------------------------------------------- /run_pretrain_kp_infill_replacement_oagkx.sh: -------------------------------------------------------------------------------- 1 | BASE_RESOURCES= 2 | 3 | run pretrain_runner.py \ 4 | --roberta-tokenizer-dir ${BASE_RESOURCES}/infill-replacement-models/checkpoint-130000 \ 5 | --roberta-mlm-model-dir ${BASE_RESOURCES}/infill-replacement-models/checkpoint-130000 \ 6 | --train-data-dir ${BASE_RESOURCES}/resources/oagkx \ 7 | --eval-data-dir ${BASE_RESOURCES}/resources/oagkx_eval \ 8 | --keyphrase-universe ${BASE_RESOURCES}/resources/oagkx_keyphrase_universe/keyphrase_universe.txt \ 9 | --keyphrase-universe-size 500000 \ 10 | --train-batch-size 2 \ 11 | --eval-batch-size 2 \ 12 | --learning-rate 1e-5 \ 13 | --adam-epsilon 1e-6 \ 14 | --max-steps 260000 \ 15 | --save-steps 10000 \ 16 | --eval-steps 340000 \ 17 | --logging-steps 1000 \ 18 | --warmup-steps 2500 \ 19 | --mlm-probability 0.05 \ 20 | --keyphrase-mask-percentage 0.2 \ 21 | --keyphrase-replace-percentage 0.4 \ 22 | --do-train \ 23 | --do-eval \ 24 | --do-keyphrase-infilling \ 25 | --do-keyphrase-replacement \ 26 | --task KLM \ 27 | --eval-task KLM \ 28 | --max-mask-keyphrase-pairs 10 \ 29 | --max-keyphrase-pairs 20 \ 30 | --kp-max-seq-len 10 \ 31 | --mlm-loss-weight 1.0 \ 32 | --keyphrase-infill-loss-weight 0.3 \ 33 | --infill-num-tok-loss-weight 1.0 \ 34 | --replacement-loss-weight 2.0 \ 35 | --model-dir ${BASE_RESOURCES}/infill-replacement-models/ \ 36 | --dataloader-num-workers 5 \ 37 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from torch.utils.data.dataloader import DataLoader 3 | from torch.utils.data.dataset import Dataset 4 | from transformers import Trainer 5 | from typing import Optional 6 | 7 | class TrainerWithEvalCollator(Trainer): 8 | def __init__(self, *args, **kwargs): 9 | if "eval_data_collator" in kwargs: 10 | self.eval_data_collator = kwargs["eval_data_collator"] 11 | del kwargs["eval_data_collator"] 12 | else: 13 | self.eval_data_collator = None 14 | 15 | super(TrainerWithEvalCollator, self).__init__(*args, **kwargs) 16 | 17 | def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: 18 | if self.eval_data_collator 
is None: 19 | return super().get_eval_dataloader(eval_dataset) 20 | 21 | # Guards borrowed from base Trainer 22 | if eval_dataset is None and self.eval_dataset is None: 23 | raise ValueError("TrainerWithEvalCollator: evaluation requires an eval_dataset.") 24 | elif eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized): 25 | raise ValueError("eval_dataset must implement __len__") 26 | 27 | eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset 28 | eval_sampler = self._get_eval_sampler(eval_dataset) 29 | 30 | return DataLoader( 31 | eval_dataset, 32 | sampler=eval_sampler, 33 | batch_size=self.args.eval_batch_size, 34 | collate_fn=self.eval_data_collator, 35 | drop_last=self.args.dataloader_drop_last, 36 | num_workers=self.args.dataloader_num_workers, 37 | ) 38 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bloomberg/kbir_keybart/e06f9acbb6232e791f9be2c085d682a17d689eb2/utils/__init__.py -------------------------------------------------------------------------------- /utils/data_collators.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | import torch 4 | import random 5 | from dataclasses import dataclass 6 | from typing import Dict 7 | import time 8 | 9 | from transformers import DataCollatorForLanguageModeling 10 | from transformers.data.data_collator import _collate_batch, tolist 11 | from transformers import PreTrainedTokenizer 12 | from transformers.tokenization_utils_base import BatchEncoding 13 | 14 | logger = logging.getLogger() 15 | 16 | 17 | @dataclass 18 | class DataCollatorForKLM(DataCollatorForLanguageModeling): 19 | def __init__( 20 | self, 21 | tokenizer, 22 | mlm_probability=0.15, 23 | kp_mask_percentage=0.4, 24 | kp_replace_percentage=0.4, 25 | keyphrase_universe_ids=None, 26 | max_keyphrase_pairs=20, 27 | max_seq_len=512, 28 | label_ignore_index=-100, 29 | do_generation=False, 30 | use_bart=False, 31 | do_keyphrase_generation=False, 32 | do_keyphrase_infilling=False, 33 | kp_max_seq_len=10, 34 | max_mask_keyphrase_pairs=10, 35 | ): 36 | self.tokenizer = tokenizer 37 | self.mlm_probability = mlm_probability 38 | self.kp_mask_percentage = kp_mask_percentage 39 | self.kp_replace_percentage = kp_replace_percentage 40 | self.keyphrase_universe_ids = keyphrase_universe_ids 41 | self.max_keyphrase_pairs = max_keyphrase_pairs 42 | self.max_mask_keyphrase_pairs = max_mask_keyphrase_pairs 43 | self.max_seq_len = max_seq_len 44 | self.kp_max_seq_len = kp_max_seq_len 45 | self.label_ignore_index = label_ignore_index 46 | self.pad_start_index = 1 47 | self.pad_end_index = 0 48 | self.do_generation = do_generation 49 | self.use_bart = use_bart 50 | self.do_keyphrase_generation = do_keyphrase_generation 51 | self.do_keyphrase_infilling = do_keyphrase_infilling 52 | if self.do_keyphrase_infilling: 53 | self.kp_mask_percentage = 0.0 54 | self.kp_infill_percentage = kp_mask_percentage 55 | else: 56 | self.kp_infill_percentage = 0.0 57 | 58 | def __call__(self, examples): 59 | original_input_ids = [] 60 | updated_input_ids = [] 61 | mask_labels = [] 62 | kp_mask_labels = [] 63 | overall_keyphrase_indexes = [] 64 | overall_keyphrase_replacement_labels = [] 65 | overall_catseq_keyphrase_input_ids = [] 66 | overall_catseq_keyphrase_decoder_input_ids = [] 67 | overall_masked_keyphrase_indexes = [] 68 | 
overall_masked_keyphrase_labels = [] 69 | overall_keyphrase_mask_num_tok_labels = [] 70 | for idx, e in enumerate(examples): 71 | input_ids = e["input_ids"] 72 | keyphrases_input_ids = e["keyphrases_input_ids"] 73 | ( 74 | input_ids, 75 | keyphrases_input_ids, 76 | keyphrase_indexes, 77 | replaced_keyphrase_indexes, 78 | keyphrase_replacement_labels, 79 | masked_keyphrase_indexes, 80 | keyphrase_mask_labels, 81 | keyphrase_mask_num_tok_labels, 82 | ) = self.replace_keyphrases(input_ids, keyphrases_input_ids) 83 | # Truncate input ids to max seq len 84 | input_ids = input_ids[: self.max_seq_len] 85 | mask_res = self.kp_and_whole_word_mask( 86 | input_ids, 87 | keyphrases_input_ids, 88 | replaced_kp_indexes=replaced_keyphrase_indexes, 89 | ) 90 | 91 | if self.keyphrase_universe_ids is not None: 92 | # Skip predictions for replacement on Keyphrases that were masked 93 | mask_skipped_keyphrase_indexes = [] 94 | mask_skipped_replacement_labels = [] 95 | assert len(keyphrase_indexes) == len(keyphrase_replacement_labels) 96 | for keyphrase_index, label in zip( 97 | keyphrase_indexes, keyphrase_replacement_labels 98 | ): 99 | # Skip keyphrases that are out of scope 100 | if ( 101 | not keyphrase_index 102 | or len(keyphrase_index) == 0 103 | or keyphrase_index[0] >= self.max_seq_len 104 | ): 105 | continue 106 | if ( 107 | not self.is_keyphrase_masked(mask_res[1], keyphrase_index) 108 | and len(mask_skipped_keyphrase_indexes) 109 | < self.max_keyphrase_pairs 110 | ): 111 | # Find boundary word indexes 112 | keyphrase_start_idx = ( 113 | keyphrase_index[0] - 1 114 | if (keyphrase_index[0] - 1) >= 0 115 | else 0 116 | ) 117 | keyphrase_end_idx = ( 118 | keyphrase_index[-1] + 1 119 | if (keyphrase_index[-1] + 1) < self.max_seq_len 120 | else self.max_seq_len - 1 121 | ) 122 | mask_skipped_keyphrase_indexes.append( 123 | (keyphrase_start_idx, keyphrase_end_idx) 124 | ) 125 | mask_skipped_replacement_labels.append(label) 126 | # Truncate to max length 127 | mask_skipped_keyphrase_indexes = mask_skipped_keyphrase_indexes[ 128 | : self.max_keyphrase_pairs 129 | ] 130 | mask_skipped_replacement_labels = mask_skipped_replacement_labels[ 131 | : self.max_keyphrase_pairs 132 | ] 133 | # Pad pairs to an equal length 134 | pair_padding_length = self.max_keyphrase_pairs - len( 135 | mask_skipped_keyphrase_indexes 136 | ) 137 | mask_skipped_keyphrase_indexes += [ 138 | (self.pad_start_index, self.pad_end_index) 139 | ] * pair_padding_length 140 | mask_skipped_replacement_labels += [ 141 | self.label_ignore_index 142 | ] * pair_padding_length 143 | 144 | # Add to batch 145 | overall_keyphrase_indexes.append(mask_skipped_keyphrase_indexes) 146 | overall_keyphrase_replacement_labels.append( 147 | mask_skipped_replacement_labels 148 | ) 149 | 150 | if self.do_generation: 151 | # Force same length batches 152 | catseq_keyphrase_input_ids = e["catseq_keyphrase_input_ids"][ 153 | : self.max_seq_len 154 | ] 155 | catseq_keyphrase_decoder_input_ids = e[ 156 | "catseq_keyphrase_input_ids" 157 | ][: self.max_seq_len] 158 | catseq_padding_length = self.max_seq_len - len( 159 | catseq_keyphrase_input_ids 160 | ) 161 | catseq_keyphrase_input_ids += [ 162 | self.label_ignore_index 163 | ] * catseq_padding_length 164 | overall_catseq_keyphrase_input_ids.append( 165 | catseq_keyphrase_input_ids 166 | ) 167 | catseq_keyphrase_decoder_input_ids += [ 168 | self.tokenizer.pad_token_id 169 | ] * catseq_padding_length 170 | overall_catseq_keyphrase_decoder_input_ids.append( 171 | catseq_keyphrase_decoder_input_ids 172 | ) 173 | 174 | 
# Truncate to max length 175 | input_ids = input_ids[: self.max_seq_len] 176 | pad_mask_label = mask_res[0][: self.max_seq_len] 177 | pad_kp_mask_label = mask_res[1][: self.max_seq_len] 178 | # Add padding 179 | input_padding_length = self.max_seq_len - len(input_ids) 180 | input_ids += [self.tokenizer.pad_token_id] * input_padding_length 181 | pad_mask_label += [1] * input_padding_length 182 | pad_kp_mask_label += [1] * input_padding_length 183 | 184 | # Update input ids 185 | original_input_ids.append(e["input_ids"]) 186 | updated_input_ids.append(input_ids) 187 | mask_labels.append(pad_mask_label) 188 | kp_mask_labels.append(pad_kp_mask_label) 189 | 190 | if self.do_keyphrase_infilling: 191 | # Truncate to max length 192 | masked_keyphrase_indexes = masked_keyphrase_indexes[ 193 | : self.max_mask_keyphrase_pairs 194 | ] 195 | keyphrase_mask_labels = keyphrase_mask_labels[ 196 | : self.max_mask_keyphrase_pairs 197 | ] 198 | keyphrase_mask_num_tok_labels = keyphrase_mask_num_tok_labels[ 199 | : self.max_mask_keyphrase_pairs 200 | ] 201 | # Add padding if required 202 | pair_padding_length = self.max_mask_keyphrase_pairs - len( 203 | masked_keyphrase_indexes 204 | ) 205 | masked_keyphrase_indexes += [ 206 | (self.pad_start_index, self.pad_end_index) 207 | ] * pair_padding_length 208 | keyphrase_mask_labels += [ 209 | ([self.label_ignore_index] * self.kp_max_seq_len) 210 | for _ in range(pair_padding_length) 211 | ] 212 | keyphrase_mask_num_tok_labels += [ 213 | self.label_ignore_index 214 | ] * pair_padding_length 215 | 216 | overall_masked_keyphrase_indexes.append(masked_keyphrase_indexes) 217 | overall_masked_keyphrase_labels.append(keyphrase_mask_labels) 218 | overall_keyphrase_mask_num_tok_labels.append( 219 | keyphrase_mask_num_tok_labels 220 | ) 221 | 222 | # collate 223 | # batches with pad token defined in tokenizer 224 | batch_input = _collate_batch(updated_input_ids, self.tokenizer) 225 | batch_mask = _collate_batch(mask_labels, self.tokenizer) 226 | kp_batch_mask = _collate_batch(kp_mask_labels, self.tokenizer) 227 | # mask 228 | inputs, labels = self.mask_tokens_and_kp(batch_input, batch_mask, kp_batch_mask) 229 | if self.keyphrase_universe_ids is not None: 230 | # batches for keyphrase replacement 231 | batch_keyphrase_indexes = _collate_batch( 232 | overall_keyphrase_indexes, self.tokenizer 233 | ) 234 | batch_keyphrase_replacement_labels = _collate_batch( 235 | overall_keyphrase_replacement_labels, self.tokenizer 236 | ) 237 | 238 | if self.do_keyphrase_infilling: 239 | batch_masked_keyphrase_indexes = _collate_batch( 240 | overall_masked_keyphrase_indexes, self.tokenizer 241 | ) 242 | batch_masked_keyphrase_labels = self._collate_label_batch( 243 | overall_masked_keyphrase_labels, self.tokenizer 244 | ) 245 | batch_keyphrase_mask_num_tok_labels = _collate_batch( 246 | overall_keyphrase_mask_num_tok_labels, self.tokenizer 247 | ) 248 | 249 | if self.do_generation: 250 | batch_catseq_keyphrase_decoder_inputs = _collate_batch( 251 | overall_catseq_keyphrase_decoder_input_ids, self.tokenizer 252 | ) 253 | batch_catseq_keyphrase_labels = self._collate_label_batch( 254 | overall_catseq_keyphrase_input_ids, self.tokenizer 255 | ) 256 | batch_original_labels = self._collate_label_batch( 257 | original_input_ids, self.tokenizer 258 | ) 259 | if self.use_bart: 260 | if self.do_keyphrase_generation: 261 | return { 262 | "input_ids": inputs, 263 | "labels": batch_catseq_keyphrase_labels, 264 | } 265 | else: 266 | return {"input_ids": inputs, "labels": batch_original_labels} 267 | 268 | 
return { 269 | "input_ids": inputs, 270 | "decoder_input_ids": batch_catseq_keyphrase_decoder_inputs, 271 | "labels": batch_catseq_keyphrase_labels, 272 | "keyphrase_pairs": batch_keyphrase_indexes, 273 | "replacement_labels": batch_keyphrase_replacement_labels, 274 | } 275 | if self.do_keyphrase_infilling: 276 | if self.keyphrase_universe_ids: 277 | return { 278 | "input_ids": inputs, 279 | "labels": labels, 280 | "keyphrase_pairs": batch_keyphrase_indexes, 281 | "replacement_labels": batch_keyphrase_replacement_labels, 282 | "masked_keyphrase_pairs": batch_masked_keyphrase_indexes, 283 | "masked_keyphrase_labels": batch_masked_keyphrase_labels, 284 | "keyphrase_mask_num_tok_labels": batch_keyphrase_mask_num_tok_labels, 285 | } 286 | else: 287 | return { 288 | "input_ids": inputs, 289 | "labels": labels, 290 | "masked_keyphrase_pairs": batch_masked_keyphrase_indexes, 291 | "masked_keyphrase_labels": batch_masked_keyphrase_labels, 292 | "keyphrase_mask_num_tok_labels": batch_keyphrase_mask_num_tok_labels, 293 | } 294 | if self.keyphrase_universe_ids: 295 | return { 296 | "input_ids": inputs, 297 | "labels": labels, 298 | "keyphrase_pairs": batch_keyphrase_indexes, 299 | "replacement_labels": batch_keyphrase_replacement_labels, 300 | } 301 | else: 302 | return {"input_ids": inputs, "labels": labels} 303 | 304 | def _collate_label_batch(self, examples, tokenizer): 305 | """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary.""" 306 | # Tensorize if necessary. 307 | if isinstance(examples[0], (list, tuple)): 308 | examples = [torch.tensor(e, dtype=torch.long) for e in examples] 309 | 310 | # Check if padding is necessary. 311 | length_of_first = examples[0].size(0) 312 | are_tensors_same_length = all(x.size(0) == length_of_first for x in examples) 313 | if are_tensors_same_length: 314 | return torch.stack(examples, dim=0) 315 | 316 | # If yes, check if we have a `pad_token`. 317 | if tokenizer._pad_token is None: 318 | raise ValueError( 319 | "You are attempting to pad samples but the tokenizer you are using" 320 | f" ({tokenizer.__class__.__name__}) does not have a pad token." 321 | ) 322 | 323 | # Creating the full tensor and filling it with our data. 
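        # Illustrative example, added annotation not present in the original code: with
        # label_ignore_index = -100 and right-side padding, two label sequences
        # [[5, 6, 7], [8, 9]] would be collated by the lines below into
        #   tensor([[   5,    6,    7],
        #           [   8,    9, -100]])
        # so that no loss is computed on the padded positions.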
324 | max_length = max(x.size(0) for x in examples) 325 | result = examples[0].new_full( 326 | [len(examples), max_length], self.label_ignore_index 327 | ) 328 | for i, example in enumerate(examples): 329 | if tokenizer.padding_side == "right": 330 | result[i, : example.shape[0]] = example 331 | else: 332 | result[i, -example.shape[0] :] = example 333 | return result 334 | 335 | @staticmethod 336 | def is_keyphrase_masked(keyphrase_masks, keyphrase_indexes): 337 | for keyphrase_index in keyphrase_indexes: 338 | if ( 339 | keyphrase_index < len(keyphrase_masks) 340 | and keyphrase_masks[keyphrase_index] == 0 341 | ): 342 | return False 343 | return True 344 | 345 | def replace_keyphrases(self, input_ids, keyphrases_input_ids): 346 | """ 347 | Replace a defined percentage of keyphrases with another keyphrase 348 | from the keyphrase universe 349 | """ 350 | if ( 351 | self.keyphrase_universe_ids is None or self.kp_replace_percentage == 0 352 | ) and self.do_keyphrase_infilling == False: 353 | return input_ids, keyphrases_input_ids, [], [], [], [], [], [] 354 | new_input_ids = [] 355 | new_keyphrase_input_ids = [] 356 | keyphrase_replacement_labels = [] 357 | replaced_keyphrase_indexes = [] 358 | keyphrase_indexes = [] 359 | masked_keyphrase_indexes = [] 360 | keyphrase_mask_labels = [] 361 | keyphrase_mask_num_tok_labels = [] 362 | # Replace Keyphrases in Input 363 | input_idx = 0 364 | while input_idx < len(input_ids): 365 | input_id = input_ids[input_idx] 366 | found_keyphrase = False 367 | for keyphrase in keyphrases_input_ids: 368 | keyphrase_idx = input_idx + len(keyphrase) 369 | if input_ids[input_idx:keyphrase_idx] == keyphrase: 370 | found_keyphrase = True 371 | curr_new_input_idx = len(new_input_ids) 372 | if ( 373 | self.do_keyphrase_infilling 374 | and random.random() < self.kp_infill_percentage 375 | ): 376 | # Choose a random keyphrase from the keyphrase universe 377 | replaced_keyphrase = [self.tokenizer.mask_token_id] 378 | # Compute Indexes 379 | start_idx = ( 380 | curr_new_input_idx - 1 381 | if (curr_new_input_idx - 1) >= 0 382 | else 0 383 | ) 384 | if start_idx >= self.max_seq_len: 385 | continue 386 | end_idx = ( 387 | curr_new_input_idx + 1 388 | if (curr_new_input_idx + 1) < self.max_seq_len 389 | else self.max_seq_len - 1 390 | ) 391 | masked_keyphrase_indexes.append([start_idx, end_idx]) 392 | # Update Input IDs with replaced keyphrase 393 | new_input_ids.extend(replaced_keyphrase) 394 | # Capture the number of tokens label 395 | kp_num_tok = ( 396 | len(keyphrase) 397 | if len(keyphrase) < self.kp_max_seq_len 398 | else self.kp_max_seq_len - 1 399 | ) 400 | keyphrase_mask_num_tok_labels.append(kp_num_tok) 401 | # Update Keyphrase Input IDs 402 | kp_pad_len = self.kp_max_seq_len - len(keyphrase) 403 | kp_label = keyphrase 404 | if kp_pad_len < 0: 405 | kp_label = keyphrase[: self.kp_max_seq_len] 406 | else: 407 | kp_label += [self.label_ignore_index] * kp_pad_len 408 | keyphrase_mask_labels.append(kp_label) 409 | # Consolidate input IDs uptil now 410 | elif ( 411 | self.keyphrase_universe_ids is not None 412 | and random.random() < self.kp_replace_percentage 413 | ): 414 | # Choose a random keyphrase from the keyphrase universe 415 | replaced_keyphrase = random.choice(self.keyphrase_universe_ids) 416 | # Make sure we aren't replacing with the same keyphrase 417 | while replaced_keyphrase == keyphrase: 418 | replaced_keyphrase = random.choice( 419 | self.keyphrase_universe_ids 420 | ) 421 | # Compute Indexes 422 | indexes = [ 423 | idx 424 | for idx in range( 425 | 
curr_new_input_idx, 426 | curr_new_input_idx + len(replaced_keyphrase), 427 | ) 428 | ] 429 | keyphrase_indexes.append(indexes) 430 | replaced_keyphrase_indexes.append(indexes) 431 | # Update Input IDs with replaced keyphrase 432 | new_input_ids.extend(replaced_keyphrase) 433 | # Update Keyphrase Input IDs 434 | new_keyphrase_input_ids.append(replaced_keyphrase) 435 | # Capture the replacement label 436 | keyphrase_replacement_labels.append(1) 437 | else: 438 | # Compute Indexes 439 | indexes = [ 440 | idx 441 | for idx in range( 442 | curr_new_input_idx, curr_new_input_idx + len(keyphrase) 443 | ) 444 | ] 445 | keyphrase_indexes.append(indexes) 446 | # Update Input IDs with original keyphrase 447 | new_input_ids.extend(keyphrase) 448 | # Update Keyphrase Input IDs 449 | new_keyphrase_input_ids.append(keyphrase) 450 | # Capture the non-replacement label 451 | keyphrase_replacement_labels.append(0) 452 | # Skip Input to after Keyphrase 453 | input_idx = keyphrase_idx 454 | break 455 | if not found_keyphrase: 456 | new_input_ids.append(input_id) 457 | input_idx += 1 458 | 459 | return ( 460 | new_input_ids, 461 | new_keyphrase_input_ids, 462 | keyphrase_indexes, 463 | replaced_keyphrase_indexes, 464 | keyphrase_replacement_labels, 465 | masked_keyphrase_indexes, 466 | keyphrase_mask_labels, 467 | keyphrase_mask_num_tok_labels, 468 | ) 469 | 470 | def kp_and_whole_word_mask( 471 | self, 472 | input_tokens, 473 | kp_tokens_list, 474 | max_predictions=512, 475 | replaced_kp_indexes=None, 476 | ): 477 | """ 478 | Get 0/1 labels for masked tokens with whole word mask proxy 479 | """ 480 | 481 | cand_indexes = [] 482 | kp_indexes = [] 483 | for (i, token) in enumerate(input_tokens): 484 | if ( 485 | token == self.tokenizer.cls_token_id 486 | or token == self.tokenizer.sep_token_id 487 | or token == self.tokenizer.mask_token_id 488 | ): 489 | continue 490 | kp_flag = False 491 | for kp in kp_tokens_list: # kp = ["KP1-T1", "KP1-T2"] 492 | j = i + len(kp) 493 | if j < len(input_tokens): 494 | if input_tokens[i:j] == kp: # input_tokens = ["KP1-T1", "KP1-T2"] 495 | kp_indexes.append( 496 | [x for x in range(i, j)] 497 | ) # kp_indexes = ["index of KP1-T1", "index of KP1-T2"] 498 | i = j - 1 499 | kp_flag = True 500 | break 501 | if ( 502 | kp_flag 503 | ): # if token is included in kp mask then don't include in random token mask 504 | continue 505 | if self.tokenizer._convert_id_to_token(token).startswith("Ġ"): 506 | cand_indexes.append([i]) 507 | else: 508 | if len(cand_indexes) >= 1: 509 | cand_indexes[-1].append(i) 510 | else: 511 | cand_indexes.append([i]) 512 | 513 | if replaced_kp_indexes: 514 | filtered_kp_indexes = [ 515 | kp_index 516 | for kp_index in kp_indexes 517 | if kp_index not in replaced_kp_indexes 518 | ] 519 | else: 520 | filtered_kp_indexes = kp_indexes 521 | 522 | tok_to_predict = min( 523 | max_predictions, 524 | max( 525 | 1, 526 | int( 527 | round( 528 | len(input_tokens) 529 | * (1 - self.kp_mask_percentage) 530 | * self.mlm_probability 531 | ) 532 | ), 533 | ), 534 | ) 535 | # Probability of masking keyphrases is KP_MASK_PERCENTAGE * MLM_PROBABILITY over the total number 536 | # of tokens in the document (input_tokens) to make sure we are masking ~15% of tokens in line 537 | # with all other language modelling pre-training literature 538 | kp_to_predict = min( 539 | max_predictions, 540 | max( 541 | 1, 542 | int( 543 | round( 544 | len(input_tokens) 545 | * self.kp_mask_percentage 546 | * self.mlm_probability 547 | ) 548 | ), 549 | ), 550 | ) 551 | 552 | tok_mask_labels = 
self.get_mask_labels( 553 | cand_indexes=cand_indexes, 554 | len_input_tokens=len(input_tokens), 555 | num_to_predict=tok_to_predict, 556 | ) 557 | kp_mask_labels = self.get_mask_labels( 558 | cand_indexes=filtered_kp_indexes, 559 | len_input_tokens=len(input_tokens), 560 | num_to_predict=kp_to_predict, 561 | ) 562 | return tok_mask_labels, kp_mask_labels 563 | 564 | def get_mask_labels(self, cand_indexes, len_input_tokens, num_to_predict): 565 | random.shuffle(cand_indexes) 566 | masked_lms = [] 567 | covered_indexes = set() 568 | for index_set in cand_indexes: 569 | if len(masked_lms) >= num_to_predict: 570 | break 571 | # If adding a whole-word mask would exceed the maximum number of 572 | # predictions, then just skip this candidate. 573 | if len(masked_lms) + len(index_set) > num_to_predict: 574 | continue 575 | is_any_index_covered = False 576 | for index in index_set: 577 | if index in covered_indexes: 578 | is_any_index_covered = True 579 | break 580 | if is_any_index_covered: 581 | continue 582 | for index in index_set: 583 | covered_indexes.add(index) 584 | masked_lms.append(index) 585 | 586 | assert len(covered_indexes) == len(masked_lms) 587 | mask_labels = [ 588 | 1 if i in covered_indexes else 0 for i in range(len_input_tokens) 589 | ] 590 | return mask_labels 591 | 592 | def mask_tokens_and_kp(self, inputs, mask_labels, kp_mask_labels): 593 | """ 594 | Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Set 595 | 'mask_labels' means we use whole word mask (wwm), we directly mask idxs according to it's ref. 596 | """ 597 | 598 | if self.tokenizer.mask_token is None: 599 | raise ValueError( 600 | "This tokenizer does not have a mask token which is necessary for masked language " 601 | "modeling. Remove the --mlm flag if you want to use this tokenizer." 
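            # Added annotation, not part of the original code: the masking applied below follows
            # the usual 80/10/10 recipe, separately for randomly selected whole words and for
            # selected keyphrase tokens. For a position chosen by mask_labels / kp_mask_labels:
            #   P(replace with the tokenizer's mask token) = 0.8
            #   P(replace with a random token)             = (1 - 0.8) * 0.5 = 0.1
            #   P(keep the original token)                 = 0.1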
602 | ) 603 | labels = inputs.clone() 604 | # We sample a few tokens in each sequence for masked-LM training 605 | # (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa) 606 | 607 | probability_matrix = mask_labels 608 | kp_probability_matrix = kp_mask_labels 609 | 610 | special_tokens_mask = [ 611 | self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) 612 | for val in labels.tolist() 613 | ] 614 | # do zero for special tokens 615 | probability_matrix.masked_fill_( 616 | torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0 617 | ) 618 | kp_probability_matrix.masked_fill_( 619 | torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0 620 | ) 621 | 622 | # assert kp_probability_matrix & probability_matrix == 0 623 | # do zero for padded points 624 | if self.tokenizer._pad_token is not None: 625 | padding_mask = labels.eq(self.tokenizer.pad_token_id) 626 | probability_matrix.masked_fill_(padding_mask, value=0.0) 627 | kp_probability_matrix.masked_fill_(padding_mask, value=0.0) 628 | 629 | masked_indices = probability_matrix.bool() 630 | kp_masked_indices = kp_probability_matrix.bool() 631 | # get the gold lables 632 | labels[ 633 | ~(masked_indices | kp_masked_indices) 634 | ] = ( 635 | self.label_ignore_index 636 | ) # We only compute loss on random masked tokens and kp masked token else is set to -100 637 | 638 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) 639 | indices_replaced = ( 640 | torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices 641 | ) 642 | inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids( 643 | self.tokenizer.mask_token 644 | ) 645 | # 80 % masking for keyphrases 646 | kp_indices_replaced = ( 647 | torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & kp_masked_indices 648 | ) 649 | inputs[kp_indices_replaced] = self.tokenizer.convert_tokens_to_ids( 650 | self.tokenizer.mask_token 651 | ) 652 | # generate random tokens 653 | random_words = torch.randint( 654 | len(self.tokenizer), labels.shape, dtype=torch.long 655 | ) 656 | # 10% of the time, we replace masked input tokens with random word 657 | indices_random = ( 658 | torch.bernoulli(torch.full(labels.shape, 0.5)).bool() 659 | & masked_indices 660 | & ~indices_replaced 661 | ) 662 | inputs[indices_random] = random_words[indices_random] 663 | 664 | # replace 10 # kp tokens with random indices 665 | # TODO: If keyphrase_universe available, replace with another keyphrase and capture 666 | # indices to classify replacement 667 | kp_indices_random = ( 668 | torch.bernoulli(torch.full(labels.shape, 0.5)).bool() 669 | & kp_masked_indices 670 | & ~kp_indices_replaced 671 | ) 672 | inputs[kp_indices_random] = random_words[kp_indices_random] 673 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged 674 | # print("inside mask tok functiom \n",inputs,"\n", labels,"\n") 675 | 676 | # generation - t1, t2, t3 (actual) - [MASK], t4 [MASK], t5, t6 677 | # replacement - t1, t2, t3 (actual) - [MASK], t4 [MASK], t5, t6 (replace) t9 678 | 679 | return inputs, labels 680 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import logging 3 | import logging.config 4 | from pathlib import Path 5 | 6 | from .saving import log_path 7 | 8 | 9 | LOG_LEVEL = logging.INFO 10 | 11 | 12 | def setup_logging(run_config, log_config="logging.yml") 
-> None: 13 | """ 14 | Setup ``logging.config`` 15 | 16 | Parameters 17 | ---------- 18 | run_config : str 19 | Path to configuration file for run 20 | 21 | log_config : str 22 | Path to configuration file for logging 23 | """ 24 | log_config = Path(log_config) 25 | 26 | if not log_config.exists(): 27 | logging.basicConfig(level=LOG_LEVEL) 28 | logger = logging.getLogger("setup") 29 | logger.warning(f'"{log_config}" not found. Using basicConfig.') 30 | return 31 | 32 | with open(log_config, "rt") as f: 33 | config = yaml.safe_load(f.read()) 34 | 35 | # modify logging paths based on run config 36 | run_path = log_path(run_config) 37 | for _, handler in config["handlers"].items(): 38 | if "filename" in handler: 39 | handler["filename"] = str(run_path / handler["filename"]) 40 | 41 | logging.config.dictConfig(config) 42 | 43 | 44 | def setup_logger(name): 45 | log = logging.getLogger(f"klm.{name}") 46 | log.setLevel(LOG_LEVEL) 47 | return log 48 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | 5 | def update_tokenizer_config(dirname, filename): 6 | if os.path.exists(os.path.join(dirname, filename)): 7 | with open(os.path.join(dirname, filename), "r") as f: 8 | for line in f: 9 | data = json.loads(line) 10 | data["tokenizer_file"] = os.path.join(dirname, "tokenizer.json") 11 | break 12 | with open(os.path.join(dirname, filename), "w") as f: 13 | f.write(json.dumps(data)) 14 | --------------------------------------------------------------------------------
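The following is an illustrative wiring sketch and is not part of the repository: it shows one way the pieces above could fit together, exercising only the whole-word and keyphrase masking path of DataCollatorForKLM (no keyphrase replacement, infilling, or generation heads). It assumes a transformers version compatible with the imports in utils/data_collators.py, access to the roberta-base checkpoint, and tiny in-memory toy datasets; all texts, paths, and hyperparameters below are placeholders.

from transformers import RobertaForMaskedLM, RobertaTokenizerFast, TrainingArguments

from trainer.trainer import TrainerWithEvalCollator
from utils.data_collators import DataCollatorForKLM

tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
model = RobertaForMaskedLM.from_pretrained("roberta-base")

# Toy in-memory "datasets": each item mirrors what the collator expects,
# a dict with token ids for the document and for its keyphrases.
text = "We study keyphrase extraction with transformer language models."
example = {
    "input_ids": tokenizer(text)["input_ids"],
    "keyphrases_input_ids": [
        tokenizer(" keyphrase extraction", add_special_tokens=False)["input_ids"]
    ],
}
train_dataset = [example] * 8
eval_dataset = [example] * 2

# Separate collators for train and eval, mirroring the --task / --eval-task split
# in pretrain_runner.py.
train_collator = DataCollatorForKLM(tokenizer, mlm_probability=0.05, kp_mask_percentage=0.2)
eval_collator = DataCollatorForKLM(tokenizer, mlm_probability=0.15, kp_mask_percentage=0.4)

trainer = TrainerWithEvalCollator(
    model=model,
    args=TrainingArguments(output_dir="toy-klm-out", per_device_train_batch_size=2, max_steps=10),
    data_collator=train_collator,
    eval_data_collator=eval_collator,  # consumed by get_eval_dataloader above
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()
metrics = trainer.evaluate()  # evaluation batches are built with eval_collator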