├── .gitattributes ├── LICENSE ├── README.md ├── codes ├── fine-tuning │ ├── convert_tf_checkpoint_to_pytorch.py │ ├── extract_features.py │ ├── modeling.py │ ├── modeling_last_concat_avg.py │ ├── modeling_multitask.py │ ├── modeling_single_layer.py │ ├── optimization.py │ ├── run_classifier.py │ ├── run_classifier_discriminative.py │ ├── run_classifier_multitask.py │ ├── run_classifier_no_decay.py │ ├── run_classifier_single_layer.py │ └── tokenization.py └── further-pre-training │ ├── create_pretraining_data.py │ ├── extract_features.py │ ├── generate_corpus_agnews.py │ ├── modeling.py │ ├── optimization.py │ ├── run_pretraining.py │ └── tokenization.py └── data └── agnews_corpus_test.txt /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # How to Fine-Tune BERT for Text Classification? 2 | 3 | This is the code and source for the paper [How to Fine-Tune BERT for Text Classification?](https://arxiv.org/abs/1905.05583) 4 | 5 | In this paper, we conduct exhaustive experiments to investigate different fine-tuning methods of BERT on text classification task and provide a general solution for BERT fine-tuning. 
6 | 7 | 8 | \*********** **update at Mar 14, 2020** \************* 9 | 10 | Our checkpoint can be loaded by ``BertEmbedding`` in the latest [fastNLP](https://github.com/fastnlp/fastNLP) package. 11 | 12 | [Link to fastNLP.embeddings.BertEmbedding](https://github.com/fastnlp/fastNLP/blob/master/fastNLP/embeddings/bert_embedding.py) 13 | 14 | ## Requirements 15 | 16 | For further pre-training, we borrow some code from Google BERT. Thus, we need: 17 | 18 | + tensorflow==1.1x 19 | + spacy 20 | + pandas 21 | + numpy 22 | 23 | Note that you need Python 3.7 or earlier for compatibility with tensorflow 1.1x. 24 | 25 | For fine-tuning, we borrow some code from the pytorch-pretrained-bert package (now known as transformers). Thus, we need: 26 | 27 | + torch>=0.4.1,<=1.2.0 28 | 29 | 30 | 31 | ## Run the code 32 | 33 | ### 1) Prepare the data sets: 34 | 35 | #### Sogou News 36 | 37 | We determine the category of each news article from its URL; for example, “sports” corresponds 38 | to “http://sports.sohu.com”. We choose 6 categories: 39 | “sports”, “house”, “business”, “entertainment”, 40 | “women” and “technology”. For each class, we select 41 | 9,000 training samples 42 | and 1,000 test samples. 43 | 44 | The data is available [here](https://drive.google.com/drive/folders/1Rbi0tnvsQrsHvT_353pMdIbRwDlLhfwM). 45 | 46 | #### The remaining data sets 47 | 48 | The remaining data sets were built by [Zhang et al. (2015)](https://papers.nips.cc/paper/5782-character-level-convolutional-networks-for-text-classification.pdf). 49 | We download them from the [URL](https://drive.google.com/drive/u/0/folders/0Bz8a_Dbh9Qhbfll6bVpmNUtUcFdjYmF2SEpmZUZUcVNiMUw1TWN6RDV3a0JHT3kxLVhVR2M) created by Xiang Zhang. 50 | 51 | 52 | ### 2) Prepare Google BERT: 53 | 54 | [BERT-Base, Uncased](https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip) 55 | 56 | [BERT-Base, Chinese](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip) 57 | 58 | 59 | ### 3) Further Pre-Training: 60 | 61 | #### Generate Further Pre-Training Corpus 62 | 63 | Here we use AG's News as an example: 64 | ```shell 65 | python generate_corpus_agnews.py 66 | ``` 67 | A sample file, ``agnews_corpus_test.txt``, can be found in the ``./data`` directory.
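The key requirement for the generated corpus is its format: Google BERT's ``create_pretraining_data.py`` expects plain text with one sentence per line and a blank line between documents. Below is a minimal sketch of such a generator, assuming the AG's News CSV layout from Zhang et al. (class index, title, description) and a spaCy English model for sentence splitting; the actual ``generate_corpus_agnews.py`` may differ.

```python
# Hypothetical sketch -- the real generate_corpus_agnews.py may differ.
# Assumes ag_news_csv/{train,test}.csv with columns (class index, title, description)
# and the spaCy model "en_core_web_sm" installed for sentence splitting.
import csv
import spacy

nlp = spacy.load("en_core_web_sm")

def write_corpus(csv_paths, out_path):
    """Write one sentence per line, with a blank line between documents,
    which is the input format expected by create_pretraining_data.py."""
    with open(out_path, "w", encoding="utf-8") as out:
        for csv_path in csv_paths:
            with open(csv_path, encoding="utf-8") as f:
                for _, title, description in csv.reader(f):
                    for sent in nlp(title + ". " + description).sents:
                        text = sent.text.strip()
                        if text:
                            out.write(text + "\n")
                    out.write("\n")  # blank line marks the end of a document

if __name__ == "__main__":
    write_corpus(["ag_news_csv/train.csv", "ag_news_csv/test.csv"], "AGnews_corpus.txt")
```

The resulting text file is what the ``create_pretraining_data.py`` command below takes as ``--input_file``.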
68 | 69 | #### Run Further Pre-Training 70 | 71 | ```shell 72 | python create_pretraining_data.py \ 73 | --input_file=./AGnews_corpus.txt \ 74 | --output_file=tmp/tf_AGnews.tfrecord \ 75 | --vocab_file=./uncased_L-12_H-768_A-12/vocab.txt \ 76 | --do_lower_case=True \ 77 | --max_seq_length=128 \ 78 | --max_predictions_per_seq=20 \ 79 | --masked_lm_prob=0.15 \ 80 | --random_seed=12345 \ 81 | --dupe_factor=5 82 | 83 | python run_pretraining.py \ 84 | --input_file=./tmp/tf_AGnews.tfrecord \ 85 | --output_dir=./uncased_L-12_H-768_A-12_AGnews_pretrain \ 86 | --do_train=True \ 87 | --do_eval=True \ 88 | --bert_config_file=./uncased_L-12_H-768_A-12/bert_config.json \ 89 | --init_checkpoint=./uncased_L-12_H-768_A-12/bert_model.ckpt \ 90 | --train_batch_size=32 \ 91 | --max_seq_length=128 \ 92 | --max_predictions_per_seq=20 \ 93 | --num_train_steps=100000 \ 94 | --num_warmup_steps=10000 \ 95 | --save_checkpoints_steps=10000 \ 96 | --learning_rate=5e-5 97 | ``` 98 | 99 | 100 | ### 4) Fine-Tuning 101 | 102 | #### Convert TensorFlow checkpoint to PyTorch checkpoint 103 | 104 | ```shell 105 | python convert_tf_checkpoint_to_pytorch.py \ 106 | --tf_checkpoint_path ./uncased_L-12_H-768_A-12_AGnews_pretrain/model.ckpt-100000 \ 107 | --bert_config_file ./uncased_L-12_H-768_A-12_AGnews_pretrain/bert_config.json \ 108 | --pytorch_dump_path ./uncased_L-12_H-768_A-12_AGnews_pretrain/pytorch_model.bin 109 | ``` 110 | 111 | #### Fine-Tuning on downstream tasks 112 | 113 | While fine-tuning on downstream tasks, we notice that different GPUs (e.g., 1080Ti and Titan Xp) may cause 114 | slight differences in experimental results even though we fix the initial random seed. 115 | Here we use 4 x 1080Ti GPUs as an example. 116 | 117 | Take Exp-I (see Section 5.3) as an example: 118 | 119 | ```shell 120 | export CUDA_VISIBLE_DEVICES=0,1,2,3 121 | python run_classifier_single_layer.py \ 122 | --task_name imdb \ 123 | --do_train \ 124 | --do_eval \ 125 | --do_lower_case \ 126 | --data_dir ./IMDB_data/ \ 127 | --vocab_file ./uncased_L-12_H-768_A-12_IMDB_pretrain/vocab.txt \ 128 | --bert_config_file ./uncased_L-12_H-768_A-12_IMDB_pretrain/bert_config.json \ 129 | --init_checkpoint ./uncased_L-12_H-768_A-12_IMDB_pretrain/pytorch_model.bin \ 130 | --max_seq_length 512 \ 131 | --train_batch_size 24 \ 132 | --learning_rate 2e-5 \ 133 | --num_train_epochs 3.0 \ 134 | --output_dir ./imdb \ 135 | --seed 42 \ 136 | --layers 11 10 \ 137 | --trunc_medium -1 138 | ``` 139 | 140 | where ``num_train_epochs`` can be 3.0, 4.0, or 6.0. 141 | 142 | ``layers`` specifies the list of layers whose hidden states are used as features for classification: 143 | -2 means use the pooled output, -1 means concatenate all layers, and the command above concatenates 144 | layer-10 and layer-11 (the last two layers). 145 | 146 | ``trunc_medium`` controls how long texts are truncated: -2 means head-only, -1 means tail-only, 147 | 0 means head-half + tail-half (e.g., head-256 + tail-256), 148 | and any other natural number k means head-k + tail-rest (e.g., head-k + tail-(512-k)). 149 | 150 | There are also other arguments for fine-tuning: 151 | 152 | ``pooling_type`` indicates which feature is used for classification: `mean` means 153 | mean-pooling over the hidden states of the whole sequence, `max` means max-pooling, and the default 154 | takes the hidden state of the `[CLS]` token as the feature. 155 | 156 | ``layer_learning_rate`` and ``layer_learning_rate_decay`` in ``run_classifier_discriminative.py`` 157 | set a layer-wise decreasing learning rate (see Section 5.3.4), as sketched below.
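To make the layer-wise decreasing learning rate concrete, here is a minimal sketch that builds per-layer optimizer parameter groups for the ``BertForSequenceClassification`` model defined in ``modeling.py``. Only the attribute names (``bert.embeddings``, ``bert.encoder.layer``, ``bert.pooler``, ``classifier``) come from this repository's code; the helper function and the way ``run_classifier_discriminative.py`` actually wires ``layer_learning_rate`` and ``layer_learning_rate_decay`` in are illustrative assumptions.

```python
# Illustrative sketch of discriminative fine-tuning (layer-wise LR decay);
# run_classifier_discriminative.py may implement this differently.
import torch

def layerwise_parameter_groups(model, lr=2e-5, decay=0.95):
    """Top encoder layer trains at `lr`; each layer below is scaled by `decay`.
    Task-specific heads (pooler, classifier) keep the full learning rate."""
    groups = []
    num_layers = len(model.bert.encoder.layer)
    # Embeddings sit below the lowest encoder layer, so they get the smallest rate.
    groups.append({"params": list(model.bert.embeddings.parameters()),
                   "lr": lr * decay ** num_layers})
    for depth, layer in enumerate(model.bert.encoder.layer):
        groups.append({"params": list(layer.parameters()),
                       "lr": lr * decay ** (num_layers - 1 - depth)})
    groups.append({"params": list(model.bert.pooler.parameters()) +
                             list(model.classifier.parameters()),
                   "lr": lr})
    return groups

# Usage with a stock optimizer (the repo ships its own optimizer in optimization.py):
# optimizer = torch.optim.Adam(layerwise_parameter_groups(model, lr=2e-5, decay=0.95))
```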
158 | 159 | 160 | ## Further Pre-Trained Checkpoints 161 | 162 | We upload IMDb-based further pre-trained checkpoints at 163 | [here](https://drive.google.com/drive/folders/1Rbi0tnvsQrsHvT_353pMdIbRwDlLhfwM). 164 | 165 | For other checkpoints, please contact us by e-mail. 166 | 167 | ## How to cite our paper 168 | 169 | ```text 170 | @inproceedings{sun2019fine, 171 | title={How to fine-tune {BERT} for text classification?}, 172 | author={Sun, Chi and Qiu, Xipeng and Xu, Yige and Huang, Xuanjing}, 173 | booktitle={China National Conference on Chinese Computational Linguistics}, 174 | pages={194--206}, 175 | year={2019}, 176 | organization={Springer} 177 | } 178 | ``` 179 | -------------------------------------------------------------------------------- /codes/fine-tuning/convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import argparse 23 | import tensorflow as tf 24 | import torch 25 | import numpy as np 26 | 27 | from modeling import BertConfig, BertModel 28 | 29 | parser = argparse.ArgumentParser() 30 | 31 | ## Required parameters 32 | parser.add_argument("--tf_checkpoint_path", 33 | default = None, 34 | type = str, 35 | required = True, 36 | help = "Path the TensorFlow checkpoint path.") 37 | parser.add_argument("--bert_config_file", 38 | default = None, 39 | type = str, 40 | required = True, 41 | help = "The config json file corresponding to the pre-trained BERT model. 
\n" 42 | "This specifies the model architecture.") 43 | parser.add_argument("--pytorch_dump_path", 44 | default = None, 45 | type = str, 46 | required = True, 47 | help = "Path to the output PyTorch model.") 48 | 49 | args = parser.parse_args() 50 | 51 | def convert(): 52 | # Initialise PyTorch model 53 | config = BertConfig.from_json_file(args.bert_config_file) 54 | model = BertModel(config) 55 | 56 | # Load weights from TF model 57 | path = args.tf_checkpoint_path 58 | print("Converting TensorFlow checkpoint from {}".format(path)) 59 | 60 | init_vars = tf.train.list_variables(path) 61 | names = [] 62 | arrays = [] 63 | for name, shape in init_vars: 64 | print("Loading {} with shape {}".format(name, shape)) 65 | array = tf.train.load_variable(path, name) 66 | print("Numpy array shape {}".format(array.shape)) 67 | names.append(name) 68 | arrays.append(array) 69 | 70 | for name, array in zip(names, arrays): 71 | name = name[5:] # skip "bert/" 72 | print("Loading {}".format(name)) 73 | name = name.split('/') 74 | if any(n in ["adam_v", "adam_m","l_step"] for n in name): 75 | print("Skipping {}".format("/".join(name))) 76 | continue 77 | if name[0] in ['redictions', 'eq_relationship']: 78 | print("Skipping") 79 | continue 80 | pointer = model 81 | for m_name in name: 82 | if re.fullmatch(r'[A-Za-z]+_\d+', m_name): 83 | l = re.split(r'_(\d+)', m_name) 84 | else: 85 | l = [m_name] 86 | if l[0] == 'kernel': 87 | pointer = getattr(pointer, 'weight') 88 | else: 89 | pointer = getattr(pointer, l[0]) 90 | if len(l) >= 2: 91 | num = int(l[1]) 92 | pointer = pointer[num] 93 | if m_name[-11:] == '_embeddings': 94 | pointer = getattr(pointer, 'weight') 95 | elif m_name == 'kernel': 96 | array = np.transpose(array) 97 | try: 98 | assert pointer.shape == array.shape 99 | except AssertionError as e: 100 | e.args += (pointer.shape, array.shape) 101 | raise 102 | pointer.data = torch.from_numpy(array) 103 | 104 | # Save pytorch-model 105 | torch.save(model.state_dict(), args.pytorch_dump_path) 106 | 107 | if __name__ == "__main__": 108 | convert() 109 | -------------------------------------------------------------------------------- /codes/fine-tuning/extract_features.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Extract pre-computed feature vectors from a PyTorch BERT model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import argparse 22 | import codecs 23 | import collections 24 | import logging 25 | import json 26 | import re 27 | 28 | import torch 29 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler 30 | from torch.utils.data.distributed import DistributedSampler 31 | 32 | import tokenization 33 | from modeling import BertConfig, BertModel 34 | 35 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 36 | datefmt = '%m/%d/%Y %H:%M:%S', 37 | level = logging.INFO) 38 | logger = logging.getLogger(__name__) 39 | 40 | 41 | class InputExample(object): 42 | 43 | def __init__(self, unique_id, text_a, text_b): 44 | self.unique_id = unique_id 45 | self.text_a = text_a 46 | self.text_b = text_b 47 | 48 | 49 | class InputFeatures(object): 50 | """A single set of features of data.""" 51 | 52 | def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids): 53 | self.unique_id = unique_id 54 | self.tokens = tokens 55 | self.input_ids = input_ids 56 | self.input_mask = input_mask 57 | self.input_type_ids = input_type_ids 58 | 59 | 60 | def convert_examples_to_features(examples, seq_length, tokenizer): 61 | """Loads a data file into a list of `InputBatch`s.""" 62 | 63 | features = [] 64 | for (ex_index, example) in enumerate(examples): 65 | tokens_a = tokenizer.tokenize(example.text_a) 66 | 67 | tokens_b = None 68 | if example.text_b: 69 | tokens_b = tokenizer.tokenize(example.text_b) 70 | 71 | if tokens_b: 72 | # Modifies `tokens_a` and `tokens_b` in place so that the total 73 | # length is less than the specified length. 74 | # Account for [CLS], [SEP], [SEP] with "- 3" 75 | _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3) 76 | else: 77 | # Account for [CLS] and [SEP] with "- 2" 78 | if len(tokens_a) > seq_length - 2: 79 | tokens_a = tokens_a[0:(seq_length - 2)] 80 | 81 | # The convention in BERT is: 82 | # (a) For sequence pairs: 83 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 84 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 85 | # (b) For single sequences: 86 | # tokens: [CLS] the dog is hairy . [SEP] 87 | # type_ids: 0 0 0 0 0 0 0 88 | # 89 | # Where "type_ids" are used to indicate whether this is the first 90 | # sequence or the second sequence. The embedding vectors for `type=0` and 91 | # `type=1` were learned during pre-training and are added to the wordpiece 92 | # embedding vector (and position vector). This is not *strictly* necessary 93 | # since the [SEP] token unambigiously separates the sequences, but it makes 94 | # it easier for the model to learn the concept of sequences. 95 | # 96 | # For classification tasks, the first vector (corresponding to [CLS]) is 97 | # used as as the "sentence vector". Note that this only makes sense because 98 | # the entire model is fine-tuned. 
99 | tokens = [] 100 | input_type_ids = [] 101 | tokens.append("[CLS]") 102 | input_type_ids.append(0) 103 | for token in tokens_a: 104 | tokens.append(token) 105 | input_type_ids.append(0) 106 | tokens.append("[SEP]") 107 | input_type_ids.append(0) 108 | 109 | if tokens_b: 110 | for token in tokens_b: 111 | tokens.append(token) 112 | input_type_ids.append(1) 113 | tokens.append("[SEP]") 114 | input_type_ids.append(1) 115 | 116 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 117 | 118 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 119 | # tokens are attended to. 120 | input_mask = [1] * len(input_ids) 121 | 122 | # Zero-pad up to the sequence length. 123 | while len(input_ids) < seq_length: 124 | input_ids.append(0) 125 | input_mask.append(0) 126 | input_type_ids.append(0) 127 | 128 | assert len(input_ids) == seq_length 129 | assert len(input_mask) == seq_length 130 | assert len(input_type_ids) == seq_length 131 | 132 | if ex_index < 5: 133 | logger.info("*** Example ***") 134 | logger.info("unique_id: %s" % (example.unique_id)) 135 | logger.info("tokens: %s" % " ".join([str(x) for x in tokens])) 136 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 137 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 138 | logger.info( 139 | "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids])) 140 | 141 | features.append( 142 | InputFeatures( 143 | unique_id=example.unique_id, 144 | tokens=tokens, 145 | input_ids=input_ids, 146 | input_mask=input_mask, 147 | input_type_ids=input_type_ids)) 148 | return features 149 | 150 | 151 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 152 | """Truncates a sequence pair in place to the maximum length.""" 153 | 154 | # This is a simple heuristic which will always truncate the longer sequence 155 | # one token at a time. This makes more sense than truncating an equal percent 156 | # of tokens from each, since if one sequence is very short then each token 157 | # that's truncated likely contains more information than a longer sequence. 158 | while True: 159 | total_length = len(tokens_a) + len(tokens_b) 160 | if total_length <= max_length: 161 | break 162 | if len(tokens_a) > len(tokens_b): 163 | tokens_a.pop() 164 | else: 165 | tokens_b.pop() 166 | 167 | 168 | def read_examples(input_file): 169 | """Read a list of `InputExample`s from an input file.""" 170 | examples = [] 171 | unique_id = 0 172 | with open(input_file, "r") as reader: 173 | while True: 174 | line = tokenization.convert_to_unicode(reader.readline()) 175 | if not line: 176 | break 177 | line = line.strip() 178 | text_a = None 179 | text_b = None 180 | m = re.match(r"^(.*) \|\|\| (.*)$", line) 181 | if m is None: 182 | text_a = line 183 | else: 184 | text_a = m.group(1) 185 | text_b = m.group(2) 186 | examples.append( 187 | InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b)) 188 | unique_id += 1 189 | return examples 190 | 191 | 192 | def main(): 193 | parser = argparse.ArgumentParser() 194 | 195 | ## Required parameters 196 | parser.add_argument("--input_file", default=None, type=str, required=True) 197 | parser.add_argument("--vocab_file", default=None, type=str, required=True, 198 | help="The vocabulary file that the BERT model was trained on.") 199 | parser.add_argument("--output_file", default=None, type=str, required=True) 200 | parser.add_argument("--bert_config_file", default=None, type=str, required=True, 201 | help="The config json file corresponding to the pre-trained BERT model. 
" 202 | "This specifies the model architecture.") 203 | parser.add_argument("--init_checkpoint", default=None, type=str, required=True, 204 | help="Initial checkpoint (usually from a pre-trained BERT model).") 205 | 206 | ## Other parameters 207 | parser.add_argument("--layers", default="-1,-2,-3,-4", type=str) 208 | parser.add_argument("--max_seq_length", default=128, type=int, 209 | help="The maximum total input sequence length after WordPiece tokenization. Sequences longer " 210 | "than this will be truncated, and sequences shorter than this will be padded.") 211 | parser.add_argument("--do_lower_case", default=True, action='store_true', 212 | help="Whether to lower case the input text. Should be True for uncased " 213 | "models and False for cased models.") 214 | parser.add_argument("--batch_size", default=32, type=int, help="Batch size for predictions.") 215 | parser.add_argument("--local_rank", 216 | type=int, 217 | default=-1, 218 | help = "local_rank for distributed training on gpus") 219 | 220 | args = parser.parse_args() 221 | 222 | if args.local_rank == -1 or args.no_cuda: 223 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 224 | n_gpu = torch.cuda.device_count() 225 | else: 226 | device = torch.device("cuda", args.local_rank) 227 | n_gpu = 1 228 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 229 | torch.distributed.init_process_group(backend='nccl') 230 | logger.info("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1)) 231 | 232 | layer_indexes = [int(x) for x in args.layers.split(",")] 233 | 234 | bert_config = BertConfig.from_json_file(args.bert_config_file) 235 | 236 | tokenizer = tokenization.FullTokenizer( 237 | vocab_file=args.vocab_file, do_lower_case=args.do_lower_case) 238 | 239 | examples = read_examples(args.input_file) 240 | 241 | features = convert_examples_to_features( 242 | examples=examples, seq_length=args.max_seq_length, tokenizer=tokenizer) 243 | 244 | unique_id_to_feature = {} 245 | for feature in features: 246 | unique_id_to_feature[feature.unique_id] = feature 247 | 248 | model = BertModel(bert_config) 249 | if args.init_checkpoint is not None: 250 | model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) 251 | model.to(device) 252 | 253 | if args.local_rank != -1: 254 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 255 | output_device=args.local_rank) 256 | elif n_gpu > 1: 257 | model = torch.nn.DataParallel(model) 258 | 259 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 260 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) 261 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) 262 | 263 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_example_index) 264 | if args.local_rank == -1: 265 | eval_sampler = SequentialSampler(eval_data) 266 | else: 267 | eval_sampler = DistributedSampler(eval_data) 268 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size) 269 | 270 | model.eval() 271 | with open(args.output_file, "w", encoding='utf-8') as writer: 272 | for input_ids, input_mask, example_indices in eval_dataloader: 273 | input_ids = input_ids.to(device) 274 | input_mask = input_mask.to(device) 275 | 276 | all_encoder_layers, _ = model(input_ids, token_type_ids=None, attention_mask=input_mask) 277 | all_encoder_layers = all_encoder_layers 278 
| 279 | for b, example_index in enumerate(example_indices): 280 | feature = features[example_index.item()] 281 | unique_id = int(feature.unique_id) 282 | # feature = unique_id_to_feature[unique_id] 283 | output_json = collections.OrderedDict() 284 | output_json["linex_index"] = unique_id 285 | all_out_features = [] 286 | for (i, token) in enumerate(feature.tokens): 287 | all_layers = [] 288 | for (j, layer_index) in enumerate(layer_indexes): 289 | layer_output = all_encoder_layers[int(layer_index)].detach().cpu().numpy() 290 | layer_output = layer_output[b] 291 | layers = collections.OrderedDict() 292 | layers["index"] = layer_index 293 | layers["values"] = [ 294 | round(x.item(), 6) for x in layer_output[i] 295 | ] 296 | all_layers.append(layers) 297 | out_features = collections.OrderedDict() 298 | out_features["token"] = token 299 | out_features["layers"] = all_layers 300 | all_out_features.append(out_features) 301 | output_json["features"] = all_out_features 302 | writer.write(json.dumps(output_json) + "\n") 303 | 304 | 305 | if __name__ == "__main__": 306 | main() 307 | -------------------------------------------------------------------------------- /codes/fine-tuning/modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch BERT model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import copy 22 | import json 23 | import math 24 | import six 25 | import torch 26 | import torch.nn as nn 27 | from torch.nn import CrossEntropyLoss 28 | 29 | def gelu(x): 30 | """Implementation of the gelu activation function. 31 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 32 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 33 | """ 34 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 35 | 36 | 37 | class BertConfig(object): 38 | """Configuration class to store the configuration of a `BertModel`. 39 | """ 40 | def __init__(self, 41 | vocab_size, 42 | hidden_size=768, 43 | num_hidden_layers=12, 44 | num_attention_heads=12, 45 | intermediate_size=3072, 46 | hidden_act="gelu", 47 | hidden_dropout_prob=0.1, 48 | attention_probs_dropout_prob=0.1, 49 | max_position_embeddings=512, 50 | type_vocab_size=16, 51 | initializer_range=0.02): 52 | """Constructs BertConfig. 53 | 54 | Args: 55 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. 56 | hidden_size: Size of the encoder layers and the pooler layer. 57 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 58 | num_attention_heads: Number of attention heads for each attention layer in 59 | the Transformer encoder. 
60 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 61 | layer in the Transformer encoder. 62 | hidden_act: The non-linear activation function (function or string) in the 63 | encoder and pooler. 64 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 65 | layers in the embeddings, encoder, and pooler. 66 | attention_probs_dropout_prob: The dropout ratio for the attention 67 | probabilities. 68 | max_position_embeddings: The maximum sequence length that this model might 69 | ever be used with. Typically set this to something large just in case 70 | (e.g., 512 or 1024 or 2048). 71 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 72 | `BertModel`. 73 | initializer_range: The sttdev of the truncated_normal_initializer for 74 | initializing all weight matrices. 75 | """ 76 | self.vocab_size = vocab_size 77 | self.hidden_size = hidden_size 78 | self.num_hidden_layers = num_hidden_layers 79 | self.num_attention_heads = num_attention_heads 80 | self.hidden_act = hidden_act 81 | self.intermediate_size = intermediate_size 82 | self.hidden_dropout_prob = hidden_dropout_prob 83 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 84 | self.max_position_embeddings = max_position_embeddings 85 | self.type_vocab_size = type_vocab_size 86 | self.initializer_range = initializer_range 87 | 88 | @classmethod 89 | def from_dict(cls, json_object): 90 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 91 | config = BertConfig(vocab_size=None) 92 | for (key, value) in six.iteritems(json_object): 93 | config.__dict__[key] = value 94 | return config 95 | 96 | @classmethod 97 | def from_json_file(cls, json_file): 98 | """Constructs a `BertConfig` from a json file of parameters.""" 99 | with open(json_file, "r") as reader: 100 | text = reader.read() 101 | return cls.from_dict(json.loads(text)) 102 | 103 | def to_dict(self): 104 | """Serializes this instance to a Python dictionary.""" 105 | output = copy.deepcopy(self.__dict__) 106 | return output 107 | 108 | def to_json_string(self): 109 | """Serializes this instance to a JSON string.""" 110 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 111 | 112 | 113 | class BERTLayerNorm(nn.Module): 114 | def __init__(self, config, variance_epsilon=1e-12): 115 | """Construct a layernorm module in the TF style (epsilon inside the square root). 116 | """ 117 | super(BERTLayerNorm, self).__init__() 118 | self.gamma = nn.Parameter(torch.ones(config.hidden_size)) 119 | self.beta = nn.Parameter(torch.zeros(config.hidden_size)) 120 | self.variance_epsilon = variance_epsilon 121 | 122 | def forward(self, x): 123 | u = x.mean(-1, keepdim=True) 124 | s = (x - u).pow(2).mean(-1, keepdim=True) 125 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 126 | return self.gamma * x + self.beta 127 | 128 | class BERTEmbeddings(nn.Module): 129 | def __init__(self, config): 130 | super(BERTEmbeddings, self).__init__() 131 | """Construct the embedding module from word, position and token_type embeddings. 
132 | """ 133 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) 134 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 135 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 136 | 137 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 138 | # any TensorFlow checkpoint file 139 | self.LayerNorm = BERTLayerNorm(config) 140 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 141 | 142 | def forward(self, input_ids, token_type_ids=None): 143 | seq_length = input_ids.size(1) 144 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 145 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 146 | if token_type_ids is None: 147 | token_type_ids = torch.zeros_like(input_ids) 148 | 149 | words_embeddings = self.word_embeddings(input_ids) 150 | position_embeddings = self.position_embeddings(position_ids) 151 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 152 | 153 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 154 | embeddings = self.LayerNorm(embeddings) 155 | embeddings = self.dropout(embeddings) 156 | return embeddings 157 | 158 | 159 | class BERTSelfAttention(nn.Module): 160 | def __init__(self, config): 161 | super(BERTSelfAttention, self).__init__() 162 | if config.hidden_size % config.num_attention_heads != 0: 163 | raise ValueError( 164 | "The hidden size (%d) is not a multiple of the number of attention " 165 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 166 | self.num_attention_heads = config.num_attention_heads 167 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 168 | self.all_head_size = self.num_attention_heads * self.attention_head_size 169 | 170 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 171 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 172 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 173 | 174 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 175 | 176 | def transpose_for_scores(self, x): 177 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 178 | x = x.view(*new_x_shape) 179 | return x.permute(0, 2, 1, 3) 180 | 181 | def forward(self, hidden_states, attention_mask): 182 | mixed_query_layer = self.query(hidden_states) 183 | mixed_key_layer = self.key(hidden_states) 184 | mixed_value_layer = self.value(hidden_states) 185 | 186 | query_layer = self.transpose_for_scores(mixed_query_layer) 187 | key_layer = self.transpose_for_scores(mixed_key_layer) 188 | value_layer = self.transpose_for_scores(mixed_value_layer) 189 | 190 | # Take the dot product between "query" and "key" to get the raw attention scores. 191 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 192 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 193 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 194 | attention_scores = attention_scores + attention_mask 195 | 196 | # Normalize the attention scores to probabilities. 197 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 198 | 199 | # This is actually dropping out entire tokens to attend to, which might 200 | # seem a bit unusual, but is taken from the original Transformer paper. 
201 | attention_probs = self.dropout(attention_probs) 202 | 203 | context_layer = torch.matmul(attention_probs, value_layer) 204 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 205 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 206 | context_layer = context_layer.view(*new_context_layer_shape) 207 | return context_layer 208 | 209 | 210 | class BERTSelfOutput(nn.Module): 211 | def __init__(self, config): 212 | super(BERTSelfOutput, self).__init__() 213 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 214 | self.LayerNorm = BERTLayerNorm(config) 215 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 216 | 217 | def forward(self, hidden_states, input_tensor): 218 | hidden_states = self.dense(hidden_states) 219 | hidden_states = self.dropout(hidden_states) 220 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 221 | return hidden_states 222 | 223 | 224 | class BERTAttention(nn.Module): 225 | def __init__(self, config): 226 | super(BERTAttention, self).__init__() 227 | self.self = BERTSelfAttention(config) 228 | self.output = BERTSelfOutput(config) 229 | 230 | def forward(self, input_tensor, attention_mask): 231 | self_output = self.self(input_tensor, attention_mask) 232 | attention_output = self.output(self_output, input_tensor) 233 | return attention_output 234 | 235 | 236 | class BERTIntermediate(nn.Module): 237 | def __init__(self, config): 238 | super(BERTIntermediate, self).__init__() 239 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 240 | self.intermediate_act_fn = gelu 241 | 242 | def forward(self, hidden_states): 243 | hidden_states = self.dense(hidden_states) 244 | hidden_states = self.intermediate_act_fn(hidden_states) 245 | return hidden_states 246 | 247 | 248 | class BERTOutput(nn.Module): 249 | def __init__(self, config): 250 | super(BERTOutput, self).__init__() 251 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 252 | self.LayerNorm = BERTLayerNorm(config) 253 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 254 | 255 | def forward(self, hidden_states, input_tensor): 256 | hidden_states = self.dense(hidden_states) 257 | hidden_states = self.dropout(hidden_states) 258 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 259 | return hidden_states 260 | 261 | 262 | class BERTLayer(nn.Module): 263 | def __init__(self, config): 264 | super(BERTLayer, self).__init__() 265 | self.attention = BERTAttention(config) 266 | self.intermediate = BERTIntermediate(config) 267 | self.output = BERTOutput(config) 268 | 269 | def forward(self, hidden_states, attention_mask): 270 | attention_output = self.attention(hidden_states, attention_mask) 271 | intermediate_output = self.intermediate(attention_output) 272 | layer_output = self.output(intermediate_output, attention_output) 273 | return layer_output 274 | 275 | 276 | class BERTEncoder(nn.Module): 277 | def __init__(self, config): 278 | super(BERTEncoder, self).__init__() 279 | layer = BERTLayer(config) 280 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) 281 | 282 | def forward(self, hidden_states, attention_mask): 283 | all_encoder_layers = [] 284 | for layer_module in self.layer: 285 | hidden_states = layer_module(hidden_states, attention_mask) 286 | all_encoder_layers.append(hidden_states) 287 | return all_encoder_layers 288 | 289 | 290 | class BERTPooler(nn.Module): 291 | def __init__(self, config): 292 | super(BERTPooler, self).__init__() 293 | 
self.dense = nn.Linear(config.hidden_size, config.hidden_size) 294 | self.activation = nn.Tanh() 295 | 296 | def forward(self, hidden_states): 297 | # We "pool" the model by simply taking the hidden state corresponding 298 | # to the first token. 299 | first_token_tensor = hidden_states[:, 0] 300 | pooled_output = self.dense(first_token_tensor) 301 | pooled_output = self.activation(pooled_output) 302 | return pooled_output 303 | 304 | 305 | class BertModel(nn.Module): 306 | """BERT model ("Bidirectional Embedding Representations from a Transformer"). 307 | 308 | Example usage: 309 | ```python 310 | # Already been converted into WordPiece token ids 311 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 312 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 313 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 314 | 315 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512, 316 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 317 | 318 | model = modeling.BertModel(config=config) 319 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) 320 | ``` 321 | """ 322 | def __init__(self, config: BertConfig): 323 | """Constructor for BertModel. 324 | 325 | Args: 326 | config: `BertConfig` instance. 327 | """ 328 | super(BertModel, self).__init__() 329 | self.embeddings = BERTEmbeddings(config) 330 | self.encoder = BERTEncoder(config) 331 | self.pooler = BERTPooler(config) 332 | 333 | def forward(self, input_ids, token_type_ids=None, attention_mask=None): 334 | if attention_mask is None: 335 | attention_mask = torch.ones_like(input_ids) 336 | if token_type_ids is None: 337 | token_type_ids = torch.zeros_like(input_ids) 338 | 339 | # We create a 3D attention mask from a 2D tensor mask. 340 | # Sizes are [batch_size, 1, 1, from_seq_length] 341 | # So we can broadcast to [batch_size, num_heads, to_seq_length, from_seq_length] 342 | # this attention mask is more simple than the triangular masking of causal attention 343 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 344 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 345 | 346 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 347 | # masked positions, this operation will create a tensor which is 0.0 for 348 | # positions we want to attend and -10000.0 for masked positions. 349 | # Since we are adding it to the raw scores before the softmax, this is 350 | # effectively the same as removing these entirely. 351 | extended_attention_mask = extended_attention_mask.float() 352 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 353 | 354 | embedding_output = self.embeddings(input_ids, token_type_ids) 355 | all_encoder_layers = self.encoder(embedding_output, extended_attention_mask) 356 | sequence_output = all_encoder_layers[-1] 357 | pooled_output = self.pooler(sequence_output) 358 | return all_encoder_layers, pooled_output 359 | 360 | class BertForSequenceClassification(nn.Module): 361 | """BERT model for classification. 362 | This module is composed of the BERT model with a linear layer on top of 363 | the pooled output. 
364 | 365 | Example usage: 366 | ```python 367 | # Already been converted into WordPiece token ids 368 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 369 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 370 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 371 | 372 | config = BertConfig(vocab_size=32000, hidden_size=512, 373 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 374 | 375 | num_labels = 2 376 | 377 | model = BertForSequenceClassification(config, num_labels) 378 | logits = model(input_ids, token_type_ids, input_mask) 379 | ``` 380 | """ 381 | def __init__(self, config, num_labels): 382 | super(BertForSequenceClassification, self).__init__() 383 | self.bert = BertModel(config) 384 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 385 | self.classifier = nn.Linear(config.hidden_size, num_labels) 386 | 387 | def init_weights(module): 388 | if isinstance(module, (nn.Linear, nn.Embedding)): 389 | # Slightly different from the TF version which uses truncated_normal for initialization 390 | # cf https://github.com/pytorch/pytorch/pull/5617 391 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 392 | elif isinstance(module, BERTLayerNorm): 393 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 394 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 395 | if isinstance(module, nn.Linear): 396 | module.bias.data.zero_() 397 | self.apply(init_weights) 398 | 399 | def forward(self, input_ids, token_type_ids, attention_mask, labels=None): 400 | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask) 401 | pooled_output = self.dropout(pooled_output) 402 | logits = self.classifier(pooled_output) 403 | 404 | if labels is not None: 405 | loss_fct = CrossEntropyLoss() 406 | loss = loss_fct(logits, labels) 407 | return loss, logits 408 | else: 409 | return logits 410 | 411 | 412 | class BertForQuestionAnswering(nn.Module): 413 | """BERT model for Question Answering (span extraction). 
414 | This module is composed of the BERT model with a linear layer on top of 415 | the sequence output that computes start_logits and end_logits 416 | 417 | Example usage: 418 | ```python 419 | # Already been converted into WordPiece token ids 420 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 421 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 422 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 423 | 424 | config = BertConfig(vocab_size=32000, hidden_size=512, 425 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 426 | 427 | model = BertForQuestionAnswering(config) 428 | start_logits, end_logits = model(input_ids, token_type_ids, input_mask) 429 | ``` 430 | """ 431 | def __init__(self, config): 432 | super(BertForQuestionAnswering, self).__init__() 433 | self.bert = BertModel(config) 434 | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version 435 | # self.dropout = nn.Dropout(config.hidden_dropout_prob) 436 | self.qa_outputs = nn.Linear(config.hidden_size, 2) 437 | 438 | def init_weights(module): 439 | if isinstance(module, (nn.Linear, nn.Embedding)): 440 | # Slightly different from the TF version which uses truncated_normal for initialization 441 | # cf https://github.com/pytorch/pytorch/pull/5617 442 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 443 | elif isinstance(module, BERTLayerNorm): 444 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 445 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 446 | if isinstance(module, nn.Linear): 447 | module.bias.data.zero_() 448 | self.apply(init_weights) 449 | 450 | def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None): 451 | all_encoder_layers, _ = self.bert(input_ids, token_type_ids, attention_mask) 452 | sequence_output = all_encoder_layers[-1] 453 | logits = self.qa_outputs(sequence_output) 454 | start_logits, end_logits = logits.split(1, dim=-1) 455 | start_logits = start_logits.squeeze(-1) 456 | end_logits = end_logits.squeeze(-1) 457 | 458 | if start_positions is not None and end_positions is not None: 459 | # If we are on multi-GPU, split add a dimension - if not this is a no-op 460 | start_positions = start_positions.squeeze(-1) 461 | end_positions = end_positions.squeeze(-1) 462 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 463 | ignored_index = start_logits.size(1) 464 | start_positions.clamp_(0, ignored_index) 465 | end_positions.clamp_(0, ignored_index) 466 | 467 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 468 | start_loss = loss_fct(start_logits, start_positions) 469 | end_loss = loss_fct(end_logits, end_positions) 470 | total_loss = (start_loss + end_loss) / 2 471 | return total_loss 472 | else: 473 | return start_logits, end_logits 474 | -------------------------------------------------------------------------------- /codes/fine-tuning/modeling_last_concat_avg.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch BERT model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import copy 22 | import json 23 | import math 24 | import six 25 | import torch 26 | import torch.nn as nn 27 | from torch.nn import CrossEntropyLoss 28 | import torch.nn.functional as F 29 | 30 | def gelu(x): 31 | """Implementation of the gelu activation function. 32 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 33 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 34 | """ 35 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 36 | 37 | 38 | class BertConfig(object): 39 | """Configuration class to store the configuration of a `BertModel`. 40 | """ 41 | def __init__(self, 42 | vocab_size, 43 | hidden_size=768, 44 | num_hidden_layers=12, 45 | num_attention_heads=12, 46 | intermediate_size=3072, 47 | hidden_act="gelu", 48 | hidden_dropout_prob=0.1, 49 | attention_probs_dropout_prob=0.1, 50 | max_position_embeddings=512, 51 | type_vocab_size=16, 52 | initializer_range=0.02): 53 | """Constructs BertConfig. 54 | 55 | Args: 56 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. 57 | hidden_size: Size of the encoder layers and the pooler layer. 58 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 59 | num_attention_heads: Number of attention heads for each attention layer in 60 | the Transformer encoder. 61 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 62 | layer in the Transformer encoder. 63 | hidden_act: The non-linear activation function (function or string) in the 64 | encoder and pooler. 65 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 66 | layers in the embeddings, encoder, and pooler. 67 | attention_probs_dropout_prob: The dropout ratio for the attention 68 | probabilities. 69 | max_position_embeddings: The maximum sequence length that this model might 70 | ever be used with. Typically set this to something large just in case 71 | (e.g., 512 or 1024 or 2048). 72 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 73 | `BertModel`. 74 | initializer_range: The sttdev of the truncated_normal_initializer for 75 | initializing all weight matrices. 
76 | """ 77 | self.vocab_size = vocab_size 78 | self.hidden_size = hidden_size 79 | self.num_hidden_layers = num_hidden_layers 80 | self.num_attention_heads = num_attention_heads 81 | self.hidden_act = hidden_act 82 | self.intermediate_size = intermediate_size 83 | self.hidden_dropout_prob = hidden_dropout_prob 84 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 85 | self.max_position_embeddings = max_position_embeddings 86 | self.type_vocab_size = type_vocab_size 87 | self.initializer_range = initializer_range 88 | 89 | @classmethod 90 | def from_dict(cls, json_object): 91 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 92 | config = BertConfig(vocab_size=None) 93 | for (key, value) in six.iteritems(json_object): 94 | config.__dict__[key] = value 95 | return config 96 | 97 | @classmethod 98 | def from_json_file(cls, json_file): 99 | """Constructs a `BertConfig` from a json file of parameters.""" 100 | with open(json_file, "r") as reader: 101 | text = reader.read() 102 | return cls.from_dict(json.loads(text)) 103 | 104 | def to_dict(self): 105 | """Serializes this instance to a Python dictionary.""" 106 | output = copy.deepcopy(self.__dict__) 107 | return output 108 | 109 | def to_json_string(self): 110 | """Serializes this instance to a JSON string.""" 111 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 112 | 113 | 114 | class BERTLayerNorm(nn.Module): 115 | def __init__(self, config, variance_epsilon=1e-12): 116 | """Construct a layernorm module in the TF style (epsilon inside the square root). 117 | """ 118 | super(BERTLayerNorm, self).__init__() 119 | self.gamma = nn.Parameter(torch.ones(config.hidden_size)) 120 | self.beta = nn.Parameter(torch.zeros(config.hidden_size)) 121 | self.variance_epsilon = variance_epsilon 122 | 123 | def forward(self, x): 124 | u = x.mean(-1, keepdim=True) 125 | s = (x - u).pow(2).mean(-1, keepdim=True) 126 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 127 | return self.gamma * x + self.beta 128 | 129 | class BERTEmbeddings(nn.Module): 130 | def __init__(self, config): 131 | super(BERTEmbeddings, self).__init__() 132 | """Construct the embedding module from word, position and token_type embeddings. 
133 | """ 134 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) 135 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 136 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 137 | 138 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 139 | # any TensorFlow checkpoint file 140 | self.LayerNorm = BERTLayerNorm(config) 141 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 142 | 143 | def forward(self, input_ids, token_type_ids=None): 144 | seq_length = input_ids.size(1) 145 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 146 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 147 | if token_type_ids is None: 148 | token_type_ids = torch.zeros_like(input_ids) 149 | 150 | words_embeddings = self.word_embeddings(input_ids) 151 | position_embeddings = self.position_embeddings(position_ids) 152 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 153 | 154 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 155 | embeddings = self.LayerNorm(embeddings) 156 | embeddings = self.dropout(embeddings) 157 | return embeddings 158 | 159 | 160 | class BERTSelfAttention(nn.Module): 161 | def __init__(self, config): 162 | super(BERTSelfAttention, self).__init__() 163 | if config.hidden_size % config.num_attention_heads != 0: 164 | raise ValueError( 165 | "The hidden size (%d) is not a multiple of the number of attention " 166 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 167 | self.num_attention_heads = config.num_attention_heads 168 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 169 | self.all_head_size = self.num_attention_heads * self.attention_head_size 170 | 171 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 172 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 173 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 174 | 175 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 176 | 177 | def transpose_for_scores(self, x): 178 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 179 | x = x.view(*new_x_shape) 180 | return x.permute(0, 2, 1, 3) 181 | 182 | def forward(self, hidden_states, attention_mask): 183 | mixed_query_layer = self.query(hidden_states) 184 | mixed_key_layer = self.key(hidden_states) 185 | mixed_value_layer = self.value(hidden_states) 186 | 187 | query_layer = self.transpose_for_scores(mixed_query_layer) 188 | key_layer = self.transpose_for_scores(mixed_key_layer) 189 | value_layer = self.transpose_for_scores(mixed_value_layer) 190 | 191 | # Take the dot product between "query" and "key" to get the raw attention scores. 192 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 193 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 194 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 195 | attention_scores = attention_scores + attention_mask 196 | 197 | # Normalize the attention scores to probabilities. 198 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 199 | 200 | # This is actually dropping out entire tokens to attend to, which might 201 | # seem a bit unusual, but is taken from the original Transformer paper. 
202 | attention_probs = self.dropout(attention_probs) 203 | 204 | context_layer = torch.matmul(attention_probs, value_layer) 205 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 206 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 207 | context_layer = context_layer.view(*new_context_layer_shape) 208 | return context_layer 209 | 210 | 211 | class BERTSelfOutput(nn.Module): 212 | def __init__(self, config): 213 | super(BERTSelfOutput, self).__init__() 214 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 215 | self.LayerNorm = BERTLayerNorm(config) 216 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 217 | 218 | def forward(self, hidden_states, input_tensor): 219 | hidden_states = self.dense(hidden_states) 220 | hidden_states = self.dropout(hidden_states) 221 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 222 | return hidden_states 223 | 224 | 225 | class BERTAttention(nn.Module): 226 | def __init__(self, config): 227 | super(BERTAttention, self).__init__() 228 | self.self = BERTSelfAttention(config) 229 | self.output = BERTSelfOutput(config) 230 | 231 | def forward(self, input_tensor, attention_mask): 232 | self_output = self.self(input_tensor, attention_mask) 233 | attention_output = self.output(self_output, input_tensor) 234 | return attention_output 235 | 236 | 237 | class BERTIntermediate(nn.Module): 238 | def __init__(self, config): 239 | super(BERTIntermediate, self).__init__() 240 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 241 | self.intermediate_act_fn = gelu 242 | 243 | def forward(self, hidden_states): 244 | hidden_states = self.dense(hidden_states) 245 | hidden_states = self.intermediate_act_fn(hidden_states) 246 | return hidden_states 247 | 248 | 249 | class BERTOutput(nn.Module): 250 | def __init__(self, config): 251 | super(BERTOutput, self).__init__() 252 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 253 | self.LayerNorm = BERTLayerNorm(config) 254 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 255 | 256 | def forward(self, hidden_states, input_tensor): 257 | hidden_states = self.dense(hidden_states) 258 | hidden_states = self.dropout(hidden_states) 259 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 260 | return hidden_states 261 | 262 | 263 | class BERTLayer(nn.Module): 264 | def __init__(self, config): 265 | super(BERTLayer, self).__init__() 266 | self.attention = BERTAttention(config) 267 | self.intermediate = BERTIntermediate(config) 268 | self.output = BERTOutput(config) 269 | 270 | def forward(self, hidden_states, attention_mask): 271 | attention_output = self.attention(hidden_states, attention_mask) 272 | intermediate_output = self.intermediate(attention_output) 273 | layer_output = self.output(intermediate_output, attention_output) 274 | return layer_output 275 | 276 | 277 | class BERTEncoder(nn.Module): 278 | def __init__(self, config): 279 | super(BERTEncoder, self).__init__() 280 | layer = BERTLayer(config) 281 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) 282 | 283 | def forward(self, hidden_states, attention_mask): 284 | all_encoder_layers = [] 285 | for layer_module in self.layer: 286 | hidden_states = layer_module(hidden_states, attention_mask) 287 | all_encoder_layers.append(hidden_states) 288 | return all_encoder_layers 289 | 290 | 291 | class BERTPooler(nn.Module): 292 | def __init__(self, config): 293 | super(BERTPooler, self).__init__() 294 | 
self.dense = nn.Linear(config.hidden_size, config.hidden_size) 295 | self.activation = nn.Tanh() 296 | 297 | def forward(self, hidden_states): 298 | # We "pool" the model by simply taking the hidden state corresponding 299 | # to the first token. 300 | first_token_tensor = hidden_states[:, 0] 301 | pooled_output = self.dense(first_token_tensor) 302 | pooled_output = self.activation(pooled_output) 303 | return pooled_output 304 | 305 | 306 | class BertModel(nn.Module): 307 | """BERT model ("Bidirectional Embedding Representations from a Transformer"). 308 | 309 | Example usage: 310 | ```python 311 | # Already been converted into WordPiece token ids 312 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 313 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 314 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 315 | 316 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512, 317 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 318 | 319 | model = modeling.BertModel(config=config) 320 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) 321 | ``` 322 | """ 323 | def __init__(self, config: BertConfig): 324 | """Constructor for BertModel. 325 | 326 | Args: 327 | config: `BertConfig` instance. 328 | """ 329 | super(BertModel, self).__init__() 330 | self.embeddings = BERTEmbeddings(config) 331 | self.encoder = BERTEncoder(config) 332 | self.pooler = BERTPooler(config) 333 | 334 | def forward(self, input_ids, token_type_ids=None, attention_mask=None): 335 | if attention_mask is None: 336 | attention_mask = torch.ones_like(input_ids) 337 | if token_type_ids is None: 338 | token_type_ids = torch.zeros_like(input_ids) 339 | 340 | # We create a 3D attention mask from a 2D tensor mask. 341 | # Sizes are [batch_size, 1, 1, from_seq_length] 342 | # So we can broadcast to [batch_size, num_heads, to_seq_length, from_seq_length] 343 | # this attention mask is more simple than the triangular masking of causal attention 344 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 345 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 346 | 347 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 348 | # masked positions, this operation will create a tensor which is 0.0 for 349 | # positions we want to attend and -10000.0 for masked positions. 350 | # Since we are adding it to the raw scores before the softmax, this is 351 | # effectively the same as removing these entirely. 
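        # For example, an attention_mask row of [1, 1, 0] becomes
        # (1.0 - [1., 1., 0.]) * -10000.0 = [0., 0., -10000.], leaving real tokens
        # untouched while giving the padded position a negligible attention weight
        # after the softmax.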
352 | extended_attention_mask = extended_attention_mask.float() 353 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 354 | 355 | embedding_output = self.embeddings(input_ids, token_type_ids) 356 | all_encoder_layers = self.encoder(embedding_output, extended_attention_mask) 357 | sequence_output = all_encoder_layers[-1] 358 | pooled_output = self.pooler(sequence_output) 359 | sequence_output2 = all_encoder_layers[-2] 360 | pooled_output2 = self.pooler(sequence_output2) 361 | sequence_output3 = all_encoder_layers[-3] 362 | pooled_output3 = self.pooler(sequence_output3) 363 | sequence_output4 = all_encoder_layers[-4] 364 | pooled_output4 = self.pooler(sequence_output4) 365 | sequence_output5 = all_encoder_layers[-5] 366 | pooled_output5 = self.pooler(sequence_output5) 367 | sequence_output6 = all_encoder_layers[-6] 368 | pooled_output6 = self.pooler(sequence_output6) 369 | sequence_output7 = all_encoder_layers[-7] 370 | pooled_output7 = self.pooler(sequence_output7) 371 | sequence_output8 = all_encoder_layers[-8] 372 | pooled_output8 = self.pooler(sequence_output8) 373 | sequence_output9 = all_encoder_layers[-9] 374 | pooled_output9 = self.pooler(sequence_output9) 375 | sequence_output10 = all_encoder_layers[-10] 376 | pooled_output10 = self.pooler(sequence_output10) 377 | sequence_output11 = all_encoder_layers[-11] 378 | pooled_output11 = self.pooler(sequence_output11) 379 | sequence_output12 = all_encoder_layers[-12] 380 | pooled_output12 = self.pooler(sequence_output12) 381 | return all_encoder_layers, pooled_output, pooled_output2, pooled_output3, pooled_output4, pooled_output5, pooled_output6, pooled_output7, pooled_output8, pooled_output9, pooled_output10, pooled_output11, pooled_output12 382 | 383 | class BertForSequenceClassification(nn.Module): 384 | """BERT model for classification. 385 | This module is composed of the BERT model with a linear layer on top of 386 | the pooled output. 
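    In this variant the pooled [CLS] outputs of all 12 encoder layers are stacked
    and re-weighted by a small attention module (linear -> tanh -> linear ->
    softmax over the layer dimension); the per-head weighted sums are then
    concatenated and fed to the final linear classifier.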
387 | 388 | Example usage: 389 | ```python 390 | # Already been converted into WordPiece token ids 391 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 392 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 393 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 394 | 395 | config = BertConfig(vocab_size=32000, hidden_size=512, 396 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 397 | 398 | num_labels = 2 399 | 400 | model = BertForSequenceClassification(config, num_labels) 401 | logits = model(input_ids, token_type_ids, input_mask) 402 | ``` 403 | """ 404 | def __init__(self, config, num_labels): 405 | super(BertForSequenceClassification, self).__init__() 406 | self.d_a=128 407 | self.attn_heads=1 408 | 409 | self.bert = BertModel(config) 410 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 411 | self.classifier = nn.Linear(self.attn_heads*config.hidden_size, num_labels) 412 | 413 | self.linear_first = torch.nn.Linear(config.hidden_size, self.d_a) 414 | self.linear_first.bias.data.fill_(0) 415 | self.linear_second = torch.nn.Linear(self.d_a, self.attn_heads) 416 | self.linear_second.bias.data.fill_(0) 417 | 418 | def init_weights(module): 419 | if isinstance(module, (nn.Linear, nn.Embedding)): 420 | # Slightly different from the TF version which uses truncated_normal for initialization 421 | # cf https://github.com/pytorch/pytorch/pull/5617 422 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 423 | elif isinstance(module, BERTLayerNorm): 424 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 425 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 426 | if isinstance(module, nn.Linear): 427 | module.bias.data.zero_() 428 | self.apply(init_weights) 429 | 430 | def forward(self, input_ids, token_type_ids, attention_mask, labels=None): 431 | _, pooled_output, pooled_output2, pooled_output3, pooled_output4, pooled_output5, pooled_output6, pooled_output7, pooled_output8, pooled_output9, pooled_output10, pooled_output11, pooled_output12 = self.bert(input_ids, token_type_ids, attention_mask) 432 | pooled_output = self.dropout(pooled_output).unsqueeze(0) 433 | pooled_output2 = self.dropout(pooled_output2).unsqueeze(0) 434 | pooled_output3 = self.dropout(pooled_output3).unsqueeze(0) 435 | pooled_output4 = self.dropout(pooled_output4).unsqueeze(0) 436 | pooled_output5 = self.dropout(pooled_output5).unsqueeze(0) 437 | pooled_output6 = self.dropout(pooled_output6).unsqueeze(0) 438 | pooled_output7 = self.dropout(pooled_output7).unsqueeze(0) 439 | pooled_output8 = self.dropout(pooled_output8).unsqueeze(0) 440 | pooled_output9 = self.dropout(pooled_output9).unsqueeze(0) 441 | pooled_output10 = self.dropout(pooled_output10).unsqueeze(0) 442 | pooled_output11 = self.dropout(pooled_output11).unsqueeze(0) 443 | pooled_output12 = self.dropout(pooled_output12).unsqueeze(0) # 12, batchsize, hidden_dim 444 | pooled_output = torch.cat((pooled_output, pooled_output2, pooled_output3, pooled_output4, pooled_output5, pooled_output6, pooled_output7, pooled_output8, pooled_output9, pooled_output10, pooled_output11, pooled_output12),0) 445 | 446 | seq_len, batch_size, hidden_dim = pooled_output.size() 447 | a=self.linear_first(pooled_output) #seq_len. batchsize. 
d_a 448 | a=F.tanh(a) 449 | a=self.linear_second(a) 450 | a=F.softmax(a,dim=0) # seq_len,batchsize,heads 451 | 452 | b = [] 453 | y = [] 454 | for i in range(self.attn_heads): 455 | b.append(a[:, :, i]) 456 | b[i] = b[i].unsqueeze(2).expand(seq_len, batch_size, hidden_dim) 457 | y.append((b[i] * pooled_output).sum(dim=0)) # batchsize, hidden_dim 458 | y_all = torch.cat(y,1) # batchsize, hidden_dim*heads 459 | 460 | 461 | logits = self.classifier(y_all) 462 | 463 | if labels is not None: 464 | loss_fct = CrossEntropyLoss() 465 | loss = loss_fct(logits, labels) 466 | return loss, logits 467 | else: 468 | return logits 469 | 470 | 471 | class BertForQuestionAnswering(nn.Module): 472 | """BERT model for Question Answering (span extraction). 473 | This module is composed of the BERT model with a linear layer on top of 474 | the sequence output that computes start_logits and end_logits 475 | 476 | Example usage: 477 | ```python 478 | # Already been converted into WordPiece token ids 479 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 480 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 481 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 482 | 483 | config = BertConfig(vocab_size=32000, hidden_size=512, 484 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 485 | 486 | model = BertForQuestionAnswering(config) 487 | start_logits, end_logits = model(input_ids, token_type_ids, input_mask) 488 | ``` 489 | """ 490 | def __init__(self, config): 491 | super(BertForQuestionAnswering, self).__init__() 492 | self.bert = BertModel(config) 493 | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version 494 | # self.dropout = nn.Dropout(config.hidden_dropout_prob) 495 | self.qa_outputs = nn.Linear(config.hidden_size, 2) 496 | 497 | def init_weights(module): 498 | if isinstance(module, (nn.Linear, nn.Embedding)): 499 | # Slightly different from the TF version which uses truncated_normal for initialization 500 | # cf https://github.com/pytorch/pytorch/pull/5617 501 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 502 | elif isinstance(module, BERTLayerNorm): 503 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 504 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 505 | if isinstance(module, nn.Linear): 506 | module.bias.data.zero_() 507 | self.apply(init_weights) 508 | 509 | def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None): 510 | all_encoder_layers, _ = self.bert(input_ids, token_type_ids, attention_mask) 511 | sequence_output = all_encoder_layers[-1] 512 | logits = self.qa_outputs(sequence_output) 513 | start_logits, end_logits = logits.split(1, dim=-1) 514 | start_logits = start_logits.squeeze(-1) 515 | end_logits = end_logits.squeeze(-1) 516 | 517 | if start_positions is not None and end_positions is not None: 518 | # If we are on multi-GPU, split add a dimension - if not this is a no-op 519 | start_positions = start_positions.squeeze(-1) 520 | end_positions = end_positions.squeeze(-1) 521 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 522 | ignored_index = start_logits.size(1) 523 | start_positions.clamp_(0, ignored_index) 524 | end_positions.clamp_(0, ignored_index) 525 | 526 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 527 | start_loss = loss_fct(start_logits, start_positions) 528 | end_loss = loss_fct(end_logits, end_positions) 529 | 
total_loss = (start_loss + end_loss) / 2 530 | return total_loss 531 | else: 532 | return start_logits, end_logits 533 | -------------------------------------------------------------------------------- /codes/fine-tuning/modeling_multitask.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch BERT model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import copy 22 | import json 23 | import math 24 | import six 25 | import torch 26 | import torch.nn as nn 27 | from torch.nn import CrossEntropyLoss 28 | 29 | def gelu(x): 30 | """Implementation of the gelu activation function. 31 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 32 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 33 | """ 34 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 35 | 36 | 37 | class BertConfig(object): 38 | """Configuration class to store the configuration of a `BertModel`. 39 | """ 40 | def __init__(self, 41 | vocab_size, 42 | hidden_size=768, 43 | num_hidden_layers=12, 44 | num_attention_heads=12, 45 | intermediate_size=3072, 46 | hidden_act="gelu", 47 | hidden_dropout_prob=0.1, 48 | attention_probs_dropout_prob=0.1, 49 | max_position_embeddings=512, 50 | type_vocab_size=16, 51 | initializer_range=0.02): 52 | """Constructs BertConfig. 53 | 54 | Args: 55 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. 56 | hidden_size: Size of the encoder layers and the pooler layer. 57 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 58 | num_attention_heads: Number of attention heads for each attention layer in 59 | the Transformer encoder. 60 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 61 | layer in the Transformer encoder. 62 | hidden_act: The non-linear activation function (function or string) in the 63 | encoder and pooler. 64 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 65 | layers in the embeddings, encoder, and pooler. 66 | attention_probs_dropout_prob: The dropout ratio for the attention 67 | probabilities. 68 | max_position_embeddings: The maximum sequence length that this model might 69 | ever be used with. Typically set this to something large just in case 70 | (e.g., 512 or 1024 or 2048). 71 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 72 | `BertModel`. 73 | initializer_range: The sttdev of the truncated_normal_initializer for 74 | initializing all weight matrices. 
75 | """ 76 | self.vocab_size = vocab_size 77 | self.hidden_size = hidden_size 78 | self.num_hidden_layers = num_hidden_layers 79 | self.num_attention_heads = num_attention_heads 80 | self.hidden_act = hidden_act 81 | self.intermediate_size = intermediate_size 82 | self.hidden_dropout_prob = hidden_dropout_prob 83 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 84 | self.max_position_embeddings = max_position_embeddings 85 | self.type_vocab_size = type_vocab_size 86 | self.initializer_range = initializer_range 87 | 88 | @classmethod 89 | def from_dict(cls, json_object): 90 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 91 | config = BertConfig(vocab_size=None) 92 | for (key, value) in six.iteritems(json_object): 93 | config.__dict__[key] = value 94 | return config 95 | 96 | @classmethod 97 | def from_json_file(cls, json_file): 98 | """Constructs a `BertConfig` from a json file of parameters.""" 99 | with open(json_file, "r") as reader: 100 | text = reader.read() 101 | return cls.from_dict(json.loads(text)) 102 | 103 | def to_dict(self): 104 | """Serializes this instance to a Python dictionary.""" 105 | output = copy.deepcopy(self.__dict__) 106 | return output 107 | 108 | def to_json_string(self): 109 | """Serializes this instance to a JSON string.""" 110 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 111 | 112 | 113 | class BERTLayerNorm(nn.Module): 114 | def __init__(self, config, variance_epsilon=1e-12): 115 | """Construct a layernorm module in the TF style (epsilon inside the square root). 116 | """ 117 | super(BERTLayerNorm, self).__init__() 118 | self.gamma = nn.Parameter(torch.ones(config.hidden_size)) 119 | self.beta = nn.Parameter(torch.zeros(config.hidden_size)) 120 | self.variance_epsilon = variance_epsilon 121 | 122 | def forward(self, x): 123 | u = x.mean(-1, keepdim=True) 124 | s = (x - u).pow(2).mean(-1, keepdim=True) 125 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 126 | return self.gamma * x + self.beta 127 | 128 | class BERTEmbeddings(nn.Module): 129 | def __init__(self, config): 130 | super(BERTEmbeddings, self).__init__() 131 | """Construct the embedding module from word, position and token_type embeddings. 
132 | """ 133 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) 134 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 135 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 136 | 137 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 138 | # any TensorFlow checkpoint file 139 | self.LayerNorm = BERTLayerNorm(config) 140 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 141 | 142 | def forward(self, input_ids, token_type_ids=None): 143 | seq_length = input_ids.size(1) 144 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 145 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 146 | if token_type_ids is None: 147 | token_type_ids = torch.zeros_like(input_ids) 148 | 149 | words_embeddings = self.word_embeddings(input_ids) 150 | position_embeddings = self.position_embeddings(position_ids) 151 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 152 | 153 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 154 | embeddings = self.LayerNorm(embeddings) 155 | embeddings = self.dropout(embeddings) 156 | return embeddings 157 | 158 | 159 | class BERTSelfAttention(nn.Module): 160 | def __init__(self, config): 161 | super(BERTSelfAttention, self).__init__() 162 | if config.hidden_size % config.num_attention_heads != 0: 163 | raise ValueError( 164 | "The hidden size (%d) is not a multiple of the number of attention " 165 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 166 | self.num_attention_heads = config.num_attention_heads 167 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 168 | self.all_head_size = self.num_attention_heads * self.attention_head_size 169 | 170 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 171 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 172 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 173 | 174 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 175 | 176 | def transpose_for_scores(self, x): 177 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 178 | x = x.view(*new_x_shape) 179 | return x.permute(0, 2, 1, 3) 180 | 181 | def forward(self, hidden_states, attention_mask): 182 | mixed_query_layer = self.query(hidden_states) 183 | mixed_key_layer = self.key(hidden_states) 184 | mixed_value_layer = self.value(hidden_states) 185 | 186 | query_layer = self.transpose_for_scores(mixed_query_layer) 187 | key_layer = self.transpose_for_scores(mixed_key_layer) 188 | value_layer = self.transpose_for_scores(mixed_value_layer) 189 | 190 | # Take the dot product between "query" and "key" to get the raw attention scores. 191 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 192 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 193 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 194 | attention_scores = attention_scores + attention_mask 195 | 196 | # Normalize the attention scores to probabilities. 197 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 198 | 199 | # This is actually dropping out entire tokens to attend to, which might 200 | # seem a bit unusual, but is taken from the original Transformer paper. 
201 | attention_probs = self.dropout(attention_probs) 202 | 203 | context_layer = torch.matmul(attention_probs, value_layer) 204 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 205 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 206 | context_layer = context_layer.view(*new_context_layer_shape) 207 | return context_layer 208 | 209 | 210 | class BERTSelfOutput(nn.Module): 211 | def __init__(self, config): 212 | super(BERTSelfOutput, self).__init__() 213 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 214 | self.LayerNorm = BERTLayerNorm(config) 215 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 216 | 217 | def forward(self, hidden_states, input_tensor): 218 | hidden_states = self.dense(hidden_states) 219 | hidden_states = self.dropout(hidden_states) 220 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 221 | return hidden_states 222 | 223 | 224 | class BERTAttention(nn.Module): 225 | def __init__(self, config): 226 | super(BERTAttention, self).__init__() 227 | self.self = BERTSelfAttention(config) 228 | self.output = BERTSelfOutput(config) 229 | 230 | def forward(self, input_tensor, attention_mask): 231 | self_output = self.self(input_tensor, attention_mask) 232 | attention_output = self.output(self_output, input_tensor) 233 | return attention_output 234 | 235 | 236 | class BERTIntermediate(nn.Module): 237 | def __init__(self, config): 238 | super(BERTIntermediate, self).__init__() 239 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 240 | self.intermediate_act_fn = gelu 241 | 242 | def forward(self, hidden_states): 243 | hidden_states = self.dense(hidden_states) 244 | hidden_states = self.intermediate_act_fn(hidden_states) 245 | return hidden_states 246 | 247 | 248 | class BERTOutput(nn.Module): 249 | def __init__(self, config): 250 | super(BERTOutput, self).__init__() 251 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 252 | self.LayerNorm = BERTLayerNorm(config) 253 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 254 | 255 | def forward(self, hidden_states, input_tensor): 256 | hidden_states = self.dense(hidden_states) 257 | hidden_states = self.dropout(hidden_states) 258 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 259 | return hidden_states 260 | 261 | 262 | class BERTLayer(nn.Module): 263 | def __init__(self, config): 264 | super(BERTLayer, self).__init__() 265 | self.attention = BERTAttention(config) 266 | self.intermediate = BERTIntermediate(config) 267 | self.output = BERTOutput(config) 268 | 269 | def forward(self, hidden_states, attention_mask): 270 | attention_output = self.attention(hidden_states, attention_mask) 271 | intermediate_output = self.intermediate(attention_output) 272 | layer_output = self.output(intermediate_output, attention_output) 273 | return layer_output 274 | 275 | 276 | class BERTEncoder(nn.Module): 277 | def __init__(self, config): 278 | super(BERTEncoder, self).__init__() 279 | layer = BERTLayer(config) 280 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) 281 | 282 | def forward(self, hidden_states, attention_mask): 283 | all_encoder_layers = [] 284 | for layer_module in self.layer: 285 | hidden_states = layer_module(hidden_states, attention_mask) 286 | all_encoder_layers.append(hidden_states) 287 | return all_encoder_layers 288 | 289 | 290 | class BERTPooler(nn.Module): 291 | def __init__(self, config): 292 | super(BERTPooler, self).__init__() 293 | 
self.dense = nn.Linear(config.hidden_size, config.hidden_size) 294 | self.activation = nn.Tanh() 295 | 296 | def forward(self, hidden_states): 297 | # We "pool" the model by simply taking the hidden state corresponding 298 | # to the first token. 299 | first_token_tensor = hidden_states[:, 0] 300 | pooled_output = self.dense(first_token_tensor) 301 | pooled_output = self.activation(pooled_output) 302 | return pooled_output 303 | 304 | 305 | class BertModel(nn.Module): 306 | """BERT model ("Bidirectional Embedding Representations from a Transformer"). 307 | 308 | Example usage: 309 | ```python 310 | # Already been converted into WordPiece token ids 311 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 312 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 313 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 314 | 315 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512, 316 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 317 | 318 | model = modeling.BertModel(config=config) 319 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) 320 | ``` 321 | """ 322 | def __init__(self, config: BertConfig): 323 | """Constructor for BertModel. 324 | 325 | Args: 326 | config: `BertConfig` instance. 327 | """ 328 | super(BertModel, self).__init__() 329 | self.embeddings = BERTEmbeddings(config) 330 | self.encoder = BERTEncoder(config) 331 | self.pooler = BERTPooler(config) 332 | 333 | def forward(self, input_ids, token_type_ids=None, attention_mask=None): 334 | if attention_mask is None: 335 | attention_mask = torch.ones_like(input_ids) 336 | if token_type_ids is None: 337 | token_type_ids = torch.zeros_like(input_ids) 338 | 339 | # We create a 3D attention mask from a 2D tensor mask. 340 | # Sizes are [batch_size, 1, 1, from_seq_length] 341 | # So we can broadcast to [batch_size, num_heads, to_seq_length, from_seq_length] 342 | # this attention mask is more simple than the triangular masking of causal attention 343 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 344 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 345 | 346 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 347 | # masked positions, this operation will create a tensor which is 0.0 for 348 | # positions we want to attend and -10000.0 for masked positions. 349 | # Since we are adding it to the raw scores before the softmax, this is 350 | # effectively the same as removing these entirely. 351 | extended_attention_mask = extended_attention_mask.float() 352 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 353 | 354 | embedding_output = self.embeddings(input_ids, token_type_ids) 355 | all_encoder_layers = self.encoder(embedding_output, extended_attention_mask) 356 | sequence_output = all_encoder_layers[-1] 357 | pooled_output = self.pooler(sequence_output) 358 | return all_encoder_layers, pooled_output 359 | 360 | class BertForSequenceClassification(nn.Module): 361 | """BERT model for classification. 362 | This module is composed of the BERT model with a linear layer on top of 363 | the pooled output. 
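    This multitask variant keeps four task-specific classification heads (two
    binary heads, one 4-way head and one 14-way head); the head used for a batch
    is selected at run time from the `dataset_labels` argument of `forward`, so
    every example in a batch is assumed to come from the same task. Note that,
    unlike the generic example below, `forward` therefore also requires
    `dataset_labels`.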
364 | 365 | Example usage: 366 | ```python 367 | # Already been converted into WordPiece token ids 368 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 369 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 370 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 371 | 372 | config = BertConfig(vocab_size=32000, hidden_size=512, 373 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 374 | 375 | num_labels = 2 376 | 377 | model = BertForSequenceClassification(config, num_labels) 378 | logits = model(input_ids, token_type_ids, input_mask) 379 | ``` 380 | """ 381 | def __init__(self, config, num_labels): 382 | super(BertForSequenceClassification, self).__init__() 383 | self.bert = BertModel(config) 384 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 385 | self.classifier_1 = nn.Linear(config.hidden_size, 2) 386 | self.classifier_2 = nn.Linear(config.hidden_size, 2) 387 | self.classifier_3 = nn.Linear(config.hidden_size, 4) 388 | self.classifier_4 = nn.Linear(config.hidden_size, 14) 389 | 390 | def init_weights(module): 391 | if isinstance(module, (nn.Linear, nn.Embedding)): 392 | # Slightly different from the TF version which uses truncated_normal for initialization 393 | # cf https://github.com/pytorch/pytorch/pull/5617 394 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 395 | elif isinstance(module, BERTLayerNorm): 396 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 397 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 398 | if isinstance(module, nn.Linear): 399 | module.bias.data.zero_() 400 | self.apply(init_weights) 401 | 402 | def forward(self, input_ids, token_type_ids, attention_mask, labels=None, dataset_labels=None): 403 | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask) 404 | pooled_output = self.dropout(pooled_output) 405 | #print("dataset_label=",dataset_labels[0].item()) 406 | if dataset_labels[0].item()==1:logits = self.classifier_1(pooled_output) 407 | if dataset_labels[0].item()==2:logits = self.classifier_2(pooled_output) 408 | if dataset_labels[0].item()==3:logits = self.classifier_3(pooled_output) 409 | if dataset_labels[0].item()==4:logits = self.classifier_4(pooled_output) 410 | if labels is not None: 411 | loss_fct = CrossEntropyLoss() 412 | loss = loss_fct(logits, labels) 413 | return loss, logits 414 | else: 415 | return logits 416 | 417 | 418 | class BertForQuestionAnswering(nn.Module): 419 | """BERT model for Question Answering (span extraction). 
420 | This module is composed of the BERT model with a linear layer on top of 421 | the sequence output that computes start_logits and end_logits 422 | 423 | Example usage: 424 | ```python 425 | # Already been converted into WordPiece token ids 426 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 427 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 428 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 429 | 430 | config = BertConfig(vocab_size=32000, hidden_size=512, 431 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 432 | 433 | model = BertForQuestionAnswering(config) 434 | start_logits, end_logits = model(input_ids, token_type_ids, input_mask) 435 | ``` 436 | """ 437 | def __init__(self, config): 438 | super(BertForQuestionAnswering, self).__init__() 439 | self.bert = BertModel(config) 440 | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version 441 | # self.dropout = nn.Dropout(config.hidden_dropout_prob) 442 | self.qa_outputs = nn.Linear(config.hidden_size, 2) 443 | 444 | def init_weights(module): 445 | if isinstance(module, (nn.Linear, nn.Embedding)): 446 | # Slightly different from the TF version which uses truncated_normal for initialization 447 | # cf https://github.com/pytorch/pytorch/pull/5617 448 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 449 | elif isinstance(module, BERTLayerNorm): 450 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 451 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 452 | if isinstance(module, nn.Linear): 453 | module.bias.data.zero_() 454 | self.apply(init_weights) 455 | 456 | def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None): 457 | all_encoder_layers, _ = self.bert(input_ids, token_type_ids, attention_mask) 458 | sequence_output = all_encoder_layers[-1] 459 | logits = self.qa_outputs(sequence_output) 460 | start_logits, end_logits = logits.split(1, dim=-1) 461 | start_logits = start_logits.squeeze(-1) 462 | end_logits = end_logits.squeeze(-1) 463 | 464 | if start_positions is not None and end_positions is not None: 465 | # If we are on multi-GPU, split add a dimension - if not this is a no-op 466 | start_positions = start_positions.squeeze(-1) 467 | end_positions = end_positions.squeeze(-1) 468 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 469 | ignored_index = start_logits.size(1) 470 | start_positions.clamp_(0, ignored_index) 471 | end_positions.clamp_(0, ignored_index) 472 | 473 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 474 | start_loss = loss_fct(start_logits, start_positions) 475 | end_loss = loss_fct(end_logits, end_positions) 476 | total_loss = (start_loss + end_loss) / 2 477 | return total_loss 478 | else: 479 | return start_logits, end_logits 480 | -------------------------------------------------------------------------------- /codes/fine-tuning/modeling_single_layer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch BERT model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import copy 22 | import json 23 | import math 24 | import six 25 | import torch 26 | import torch.nn as nn 27 | from torch.nn import CrossEntropyLoss 28 | 29 | def gelu(x): 30 | """Implementation of the gelu activation function. 31 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 32 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 33 | """ 34 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 35 | 36 | 37 | class BertConfig(object): 38 | """Configuration class to store the configuration of a `BertModel`. 39 | """ 40 | def __init__(self, 41 | vocab_size, 42 | hidden_size=768, 43 | num_hidden_layers=12, 44 | num_attention_heads=12, 45 | intermediate_size=3072, 46 | hidden_act="gelu", 47 | hidden_dropout_prob=0.1, 48 | attention_probs_dropout_prob=0.1, 49 | max_position_embeddings=512, 50 | type_vocab_size=16, 51 | initializer_range=0.02): 52 | """Constructs BertConfig. 53 | 54 | Args: 55 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. 56 | hidden_size: Size of the encoder layers and the pooler layer. 57 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 58 | num_attention_heads: Number of attention heads for each attention layer in 59 | the Transformer encoder. 60 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 61 | layer in the Transformer encoder. 62 | hidden_act: The non-linear activation function (function or string) in the 63 | encoder and pooler. 64 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 65 | layers in the embeddings, encoder, and pooler. 66 | attention_probs_dropout_prob: The dropout ratio for the attention 67 | probabilities. 68 | max_position_embeddings: The maximum sequence length that this model might 69 | ever be used with. Typically set this to something large just in case 70 | (e.g., 512 or 1024 or 2048). 71 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 72 | `BertModel`. 73 | initializer_range: The sttdev of the truncated_normal_initializer for 74 | initializing all weight matrices. 
75 | """ 76 | self.vocab_size = vocab_size 77 | self.hidden_size = hidden_size 78 | self.num_hidden_layers = num_hidden_layers 79 | self.num_attention_heads = num_attention_heads 80 | self.hidden_act = hidden_act 81 | self.intermediate_size = intermediate_size 82 | self.hidden_dropout_prob = hidden_dropout_prob 83 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 84 | self.max_position_embeddings = max_position_embeddings 85 | self.type_vocab_size = type_vocab_size 86 | self.initializer_range = initializer_range 87 | 88 | @classmethod 89 | def from_dict(cls, json_object): 90 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 91 | config = BertConfig(vocab_size=None) 92 | for (key, value) in six.iteritems(json_object): 93 | config.__dict__[key] = value 94 | return config 95 | 96 | @classmethod 97 | def from_json_file(cls, json_file): 98 | """Constructs a `BertConfig` from a json file of parameters.""" 99 | with open(json_file, "r") as reader: 100 | text = reader.read() 101 | return cls.from_dict(json.loads(text)) 102 | 103 | def to_dict(self): 104 | """Serializes this instance to a Python dictionary.""" 105 | output = copy.deepcopy(self.__dict__) 106 | return output 107 | 108 | def to_json_string(self): 109 | """Serializes this instance to a JSON string.""" 110 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 111 | 112 | 113 | class BERTLayerNorm(nn.Module): 114 | def __init__(self, config, variance_epsilon=1e-12): 115 | """Construct a layernorm module in the TF style (epsilon inside the square root). 116 | """ 117 | super(BERTLayerNorm, self).__init__() 118 | self.gamma = nn.Parameter(torch.ones(config.hidden_size)) 119 | self.beta = nn.Parameter(torch.zeros(config.hidden_size)) 120 | self.variance_epsilon = variance_epsilon 121 | 122 | def forward(self, x): 123 | u = x.mean(-1, keepdim=True) 124 | s = (x - u).pow(2).mean(-1, keepdim=True) 125 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 126 | return self.gamma * x + self.beta 127 | 128 | class BERTEmbeddings(nn.Module): 129 | def __init__(self, config): 130 | super(BERTEmbeddings, self).__init__() 131 | """Construct the embedding module from word, position and token_type embeddings. 
132 | """ 133 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) 134 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 135 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 136 | 137 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 138 | # any TensorFlow checkpoint file 139 | self.LayerNorm = BERTLayerNorm(config) 140 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 141 | 142 | def forward(self, input_ids, token_type_ids=None): 143 | seq_length = input_ids.size(1) 144 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 145 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 146 | if token_type_ids is None: 147 | token_type_ids = torch.zeros_like(input_ids) 148 | 149 | words_embeddings = self.word_embeddings(input_ids) 150 | position_embeddings = self.position_embeddings(position_ids) 151 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 152 | 153 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 154 | embeddings = self.LayerNorm(embeddings) 155 | embeddings = self.dropout(embeddings) 156 | return embeddings 157 | 158 | 159 | class BERTSelfAttention(nn.Module): 160 | def __init__(self, config): 161 | super(BERTSelfAttention, self).__init__() 162 | if config.hidden_size % config.num_attention_heads != 0: 163 | raise ValueError( 164 | "The hidden size (%d) is not a multiple of the number of attention " 165 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 166 | self.num_attention_heads = config.num_attention_heads 167 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 168 | self.all_head_size = self.num_attention_heads * self.attention_head_size 169 | 170 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 171 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 172 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 173 | 174 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 175 | 176 | def transpose_for_scores(self, x): 177 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 178 | x = x.view(*new_x_shape) 179 | return x.permute(0, 2, 1, 3) 180 | 181 | def forward(self, hidden_states, attention_mask): 182 | mixed_query_layer = self.query(hidden_states) 183 | mixed_key_layer = self.key(hidden_states) 184 | mixed_value_layer = self.value(hidden_states) 185 | 186 | query_layer = self.transpose_for_scores(mixed_query_layer) 187 | key_layer = self.transpose_for_scores(mixed_key_layer) 188 | value_layer = self.transpose_for_scores(mixed_value_layer) 189 | 190 | # Take the dot product between "query" and "key" to get the raw attention scores. 191 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 192 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 193 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 194 | attention_scores = attention_scores + attention_mask 195 | 196 | # Normalize the attention scores to probabilities. 197 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 198 | 199 | # This is actually dropping out entire tokens to attend to, which might 200 | # seem a bit unusual, but is taken from the original Transformer paper. 
201 | attention_probs = self.dropout(attention_probs) 202 | 203 | context_layer = torch.matmul(attention_probs, value_layer) 204 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 205 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 206 | context_layer = context_layer.view(*new_context_layer_shape) 207 | return context_layer 208 | 209 | 210 | class BERTSelfOutput(nn.Module): 211 | def __init__(self, config): 212 | super(BERTSelfOutput, self).__init__() 213 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 214 | self.LayerNorm = BERTLayerNorm(config) 215 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 216 | 217 | def forward(self, hidden_states, input_tensor): 218 | hidden_states = self.dense(hidden_states) 219 | hidden_states = self.dropout(hidden_states) 220 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 221 | return hidden_states 222 | 223 | 224 | class BERTAttention(nn.Module): 225 | def __init__(self, config): 226 | super(BERTAttention, self).__init__() 227 | self.self = BERTSelfAttention(config) 228 | self.output = BERTSelfOutput(config) 229 | 230 | def forward(self, input_tensor, attention_mask): 231 | self_output = self.self(input_tensor, attention_mask) 232 | attention_output = self.output(self_output, input_tensor) 233 | return attention_output 234 | 235 | 236 | class BERTIntermediate(nn.Module): 237 | def __init__(self, config): 238 | super(BERTIntermediate, self).__init__() 239 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 240 | self.intermediate_act_fn = gelu 241 | 242 | def forward(self, hidden_states): 243 | hidden_states = self.dense(hidden_states) 244 | hidden_states = self.intermediate_act_fn(hidden_states) 245 | return hidden_states 246 | 247 | 248 | class BERTOutput(nn.Module): 249 | def __init__(self, config): 250 | super(BERTOutput, self).__init__() 251 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 252 | self.LayerNorm = BERTLayerNorm(config) 253 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 254 | 255 | def forward(self, hidden_states, input_tensor): 256 | hidden_states = self.dense(hidden_states) 257 | hidden_states = self.dropout(hidden_states) 258 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 259 | return hidden_states 260 | 261 | 262 | class BERTLayer(nn.Module): 263 | def __init__(self, config): 264 | super(BERTLayer, self).__init__() 265 | self.attention = BERTAttention(config) 266 | self.intermediate = BERTIntermediate(config) 267 | self.output = BERTOutput(config) 268 | 269 | def forward(self, hidden_states, attention_mask): 270 | attention_output = self.attention(hidden_states, attention_mask) 271 | intermediate_output = self.intermediate(attention_output) 272 | layer_output = self.output(intermediate_output, attention_output) 273 | return layer_output 274 | 275 | 276 | class BERTEncoder(nn.Module): 277 | def __init__(self, config): 278 | super(BERTEncoder, self).__init__() 279 | layer = BERTLayer(config) 280 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) 281 | 282 | def forward(self, hidden_states, attention_mask): 283 | all_encoder_layers = [] 284 | for layer_module in self.layer: 285 | hidden_states = layer_module(hidden_states, attention_mask) 286 | all_encoder_layers.append(hidden_states) 287 | return all_encoder_layers 288 | 289 | 290 | class BERTPooler(nn.Module): 291 | def __init__(self, config): 292 | super(BERTPooler, self).__init__() 293 | 
self.dense = nn.Linear(config.hidden_size, config.hidden_size) 294 | self.activation = nn.Tanh() 295 | 296 | def forward(self, hidden_states): 297 | # We "pool" the model by simply taking the hidden state corresponding 298 | # to the first token. 299 | first_token_tensor = hidden_states[:, 0] 300 | pooled_output = self.dense(first_token_tensor) 301 | pooled_output = self.activation(pooled_output) 302 | return pooled_output 303 | 304 | 305 | class BertModel(nn.Module): 306 | """BERT model ("Bidirectional Embedding Representations from a Transformer"). 307 | 308 | Example usage: 309 | ```python 310 | # Already been converted into WordPiece token ids 311 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 312 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 313 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 314 | 315 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512, 316 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 317 | 318 | model = modeling.BertModel(config=config) 319 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) 320 | ``` 321 | """ 322 | def __init__(self, config: BertConfig): 323 | """Constructor for BertModel. 324 | 325 | Args: 326 | config: `BertConfig` instance. 327 | """ 328 | super(BertModel, self).__init__() 329 | self.embeddings = BERTEmbeddings(config) 330 | self.encoder = BERTEncoder(config) 331 | self.pooler = BERTPooler(config) 332 | 333 | def forward(self, input_ids, token_type_ids=None, attention_mask=None): 334 | if attention_mask is None: 335 | attention_mask = torch.ones_like(input_ids) 336 | if token_type_ids is None: 337 | token_type_ids = torch.zeros_like(input_ids) 338 | 339 | # We create a 3D attention mask from a 2D tensor mask. 340 | # Sizes are [batch_size, 1, 1, from_seq_length] 341 | # So we can broadcast to [batch_size, num_heads, to_seq_length, from_seq_length] 342 | # this attention mask is more simple than the triangular masking of causal attention 343 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 344 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 345 | 346 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 347 | # masked positions, this operation will create a tensor which is 0.0 for 348 | # positions we want to attend and -10000.0 for masked positions. 349 | # Since we are adding it to the raw scores before the softmax, this is 350 | # effectively the same as removing these entirely. 351 | extended_attention_mask = extended_attention_mask.float() 352 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 353 | 354 | embedding_output = self.embeddings(input_ids, token_type_ids) 355 | all_encoder_layers = self.encoder(embedding_output, extended_attention_mask) 356 | sequence_output = all_encoder_layers[-1] 357 | pooled_output = self.pooler(sequence_output) 358 | return all_encoder_layers, pooled_output 359 | 360 | from fastNLP.modules.decoder import MLP 361 | 362 | class BertForSequenceClassification(nn.Module): 363 | """BERT model for classification. 364 | This module is composed of the BERT model with a linear layer on top of 365 | the pooled output. 
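    Instead of the default pooled output, this variant can classify from the
    [CLS] hidden states of selected encoder layers: `layers=None` (or a first
    entry of -2) keeps the standard pooled output, a first entry of -1 selects
    every layer, and any other list selects those layer indices; the chosen
    states are combined by concatenation or, when `pooling` is 'max' or 'mean',
    by max- or mean-pooling before the classifier.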
366 | 367 | Example usage: 368 | ```python 369 | # Already been converted into WordPiece token ids 370 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 371 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 372 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 373 | 374 | config = BertConfig(vocab_size=32000, hidden_size=512, 375 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 376 | 377 | num_labels = 2 378 | 379 | model = BertForSequenceClassification(config, num_labels) 380 | logits = model(input_ids, token_type_ids, input_mask) 381 | ``` 382 | """ 383 | def __init__(self, config, num_labels, layers=None, pooling=None): 384 | super(BertForSequenceClassification, self).__init__() 385 | self.bert = BertModel(config) 386 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 387 | self.pooling = pooling 388 | if layers is None or layers[0] == -2: 389 | self.layers = [] 390 | elif layers[0] == -1: 391 | self.layers = [i for i in range(config.num_hidden_layers)] 392 | else: 393 | self.layers = layers 394 | self.classifier = nn.Linear( 395 | config.hidden_size * max(1, len(self.layers) if self.pooling is None else 1), num_labels) 396 | # self.classifier = MLP( 397 | # [config.hidden_size * max(1, len(self.layers) if self.pooling is None else 1), config.hidden_size, num_labels]) 398 | 399 | def init_weights(module): 400 | if isinstance(module, (nn.Linear, nn.Embedding)): 401 | # Slightly different from the TF version which uses truncated_normal for initialization 402 | # cf https://github.com/pytorch/pytorch/pull/5617 403 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 404 | elif isinstance(module, BERTLayerNorm): 405 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 406 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 407 | if isinstance(module, nn.Linear): 408 | module.bias.data.zero_() 409 | self.apply(init_weights) 410 | 411 | def forward(self, input_ids, token_type_ids, attention_mask, labels=None): 412 | encoded_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask) 413 | if len(self.layers) > 0: 414 | hidden_state = [] 415 | for l in self.layers: 416 | hidden_state.append(encoded_layers[l][:, 0].unsqueeze(1)) 417 | hidden_state = torch.cat(hidden_state, dim=1) 418 | 419 | if self.pooling == 'max': 420 | hidden_state, _ = torch.max(hidden_state, dim=1) 421 | elif self.pooling == 'mean': 422 | hidden_state = torch.mean(hidden_state, dim=1) 423 | else: 424 | hidden_state = hidden_state.view(hidden_state.size(0), -1) 425 | 426 | hidden_state = self.dropout(hidden_state) 427 | logits = self.classifier(hidden_state) 428 | else: 429 | pooled_output = self.dropout(pooled_output) 430 | logits = self.classifier(pooled_output) 431 | 432 | if labels is not None: 433 | loss_fct = CrossEntropyLoss() 434 | loss = loss_fct(logits, labels) 435 | return loss, logits 436 | else: 437 | return logits 438 | 439 | 440 | class BertForQuestionAnswering(nn.Module): 441 | """BERT model for Question Answering (span extraction). 
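`BertForSequenceClassification` above either concatenates the [CLS] hidden state of the selected encoder layers or reduces them with max/mean pooling, and sizes the classifier accordingly. A small sketch of that logic with random tensors standing in for real BERT outputs (shapes and the `pooling` value are illustrative):

```python
import torch

# Pretend we kept the [CLS] vector of 4 selected layers for a batch of 2
# examples; the real code gathers encoded_layers[l][:, 0] for l in self.layers.
batch_size, num_layers, hidden_size = 2, 4, 8
cls_states = torch.randn(batch_size, num_layers, hidden_size)

pooling = "mean"   # None -> concatenate layers, "max"/"mean" -> reduce over them
if pooling == "max":
    features, _ = torch.max(cls_states, dim=1)     # [2, 8]
elif pooling == "mean":
    features = torch.mean(cls_states, dim=1)       # [2, 8]
else:
    features = cls_states.view(batch_size, -1)     # [2, 32]

# Mirrors the constructor: input size is hidden_size * len(layers) when
# concatenating, hidden_size when pooling (or when no layers are selected).
classifier = torch.nn.Linear(features.size(-1), 2)
print(classifier(features).shape)                  # torch.Size([2, 2])
```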
442 | This module is composed of the BERT model with a linear layer on top of 443 | the sequence output that computes start_logits and end_logits 444 | 445 | Example usage: 446 | ```python 447 | # Already been converted into WordPiece token ids 448 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 449 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 450 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 451 | 452 | config = BertConfig(vocab_size=32000, hidden_size=512, 453 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 454 | 455 | model = BertForQuestionAnswering(config) 456 | start_logits, end_logits = model(input_ids, token_type_ids, input_mask) 457 | ``` 458 | """ 459 | def __init__(self, config): 460 | super(BertForQuestionAnswering, self).__init__() 461 | self.bert = BertModel(config) 462 | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version 463 | # self.dropout = nn.Dropout(config.hidden_dropout_prob) 464 | self.qa_outputs = nn.Linear(config.hidden_size, 2) 465 | 466 | def init_weights(module): 467 | if isinstance(module, (nn.Linear, nn.Embedding)): 468 | # Slightly different from the TF version which uses truncated_normal for initialization 469 | # cf https://github.com/pytorch/pytorch/pull/5617 470 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 471 | elif isinstance(module, BERTLayerNorm): 472 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 473 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 474 | if isinstance(module, nn.Linear): 475 | module.bias.data.zero_() 476 | self.apply(init_weights) 477 | 478 | def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None): 479 | all_encoder_layers, _ = self.bert(input_ids, token_type_ids, attention_mask) 480 | sequence_output = all_encoder_layers[-1] 481 | logits = self.qa_outputs(sequence_output) 482 | start_logits, end_logits = logits.split(1, dim=-1) 483 | start_logits = start_logits.squeeze(-1) 484 | end_logits = end_logits.squeeze(-1) 485 | 486 | if start_positions is not None and end_positions is not None: 487 | # If we are on multi-GPU, split add a dimension - if not this is a no-op 488 | start_positions = start_positions.squeeze(-1) 489 | end_positions = end_positions.squeeze(-1) 490 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 491 | ignored_index = start_logits.size(1) 492 | start_positions.clamp_(0, ignored_index) 493 | end_positions.clamp_(0, ignored_index) 494 | 495 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 496 | start_loss = loss_fct(start_logits, start_positions) 497 | end_loss = loss_fct(end_logits, end_positions) 498 | total_loss = (start_loss + end_loss) / 2 499 | return total_loss 500 | else: 501 | return start_logits, end_logits 502 | -------------------------------------------------------------------------------- /codes/fine-tuning/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.nn.utils import clip_grad_norm_ 21 | 22 | def warmup_cosine(x, warmup=0.002): 23 | if x < warmup: 24 | return x/warmup 25 | return 0.5 * (1.0 + torch.cos(math.pi * x)) 26 | 27 | def warmup_constant(x, warmup=0.002): 28 | if x < warmup: 29 | return x/warmup 30 | return 1.0 31 | 32 | def warmup_linear(x, warmup=0.002): 33 | if x < warmup: 34 | return x/warmup 35 | return 1.0 - x 36 | 37 | SCHEDULES = { 38 | 'warmup_cosine':warmup_cosine, 39 | 'warmup_constant':warmup_constant, 40 | 'warmup_linear':warmup_linear, 41 | } 42 | 43 | 44 | class BERTAdam(Optimizer): 45 | """Implements BERT version of Adam algorithm with weight decay fix (and no ). 46 | Params: 47 | lr: learning rate 48 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 49 | t_total: total number of training steps for the learning 50 | rate schedule, -1 means constant learning rate. Default: -1 51 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 52 | b1: Adams b1. Default: 0.9 53 | b2: Adams b2. Default: 0.999 54 | e: Adams epsilon. Default: 1e-6 55 | weight_decay_rate: Weight decay. Default: 0.01 56 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). 
Default: 1.0 57 | """ 58 | def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear', 59 | b1=0.9, b2=0.999, e=1e-6, weight_decay_rate=0.01, 60 | max_grad_norm=1.0): 61 | if not lr >= 0.0: 62 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 63 | if schedule not in SCHEDULES: 64 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 65 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 66 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 67 | if not 0.0 <= b1 < 1.0: 68 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 69 | if not 0.0 <= b2 < 1.0: 70 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 71 | if not e >= 0.0: 72 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 73 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 74 | b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate, 75 | max_grad_norm=max_grad_norm) 76 | super(BERTAdam, self).__init__(params, defaults) 77 | 78 | def get_lr(self): 79 | lr = [] 80 | print("l_total=",len(self.param_groups)) 81 | for group in self.param_groups: 82 | print("l_p=",len(group['params'])) 83 | for p in group['params']: 84 | state = self.state[p] 85 | if len(state) == 0: 86 | return [0] 87 | if group['t_total'] != -1: 88 | schedule_fct = SCHEDULES[group['schedule']] 89 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 90 | else: 91 | lr_scheduled = group['lr'] 92 | lr.append(lr_scheduled) 93 | return lr 94 | 95 | def to(self, device): 96 | """ Move the optimizer state to a specified device""" 97 | for state in self.state.values(): 98 | state['exp_avg'].to(device) 99 | state['exp_avg_sq'].to(device) 100 | 101 | def initialize_step(self, initial_step): 102 | """Initialize state with a defined step (but we don't have stored averaged). 103 | Arguments: 104 | initial_step (int): Initial step number. 105 | """ 106 | for group in self.param_groups: 107 | for p in group['params']: 108 | state = self.state[p] 109 | # State initialization 110 | state['step'] = initial_step 111 | # Exponential moving average of gradient values 112 | state['exp_avg'] = torch.zeros_like(p.data) 113 | # Exponential moving average of squared gradient values 114 | state['exp_avg_sq'] = torch.zeros_like(p.data) 115 | 116 | def step(self, closure=None): 117 | """Performs a single optimization step. 118 | 119 | Arguments: 120 | closure (callable, optional): A closure that reevaluates the model 121 | and returns the loss. 
122 | """ 123 | loss = None 124 | if closure is not None: 125 | loss = closure() 126 | 127 | for group in self.param_groups: 128 | for p in group['params']: 129 | if p.grad is None: 130 | continue 131 | grad = p.grad.data 132 | if grad.is_sparse: 133 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 134 | 135 | state = self.state[p] 136 | 137 | # State initialization 138 | if len(state) == 0: 139 | state['step'] = 0 140 | # Exponential moving average of gradient values 141 | state['next_m'] = torch.zeros_like(p.data) 142 | # Exponential moving average of squared gradient values 143 | state['next_v'] = torch.zeros_like(p.data) 144 | 145 | next_m, next_v = state['next_m'], state['next_v'] 146 | beta1, beta2 = group['b1'], group['b2'] 147 | 148 | # Add grad clipping 149 | if group['max_grad_norm'] > 0: 150 | clip_grad_norm_(p, group['max_grad_norm']) 151 | 152 | # Decay the first and second moment running average coefficient 153 | # In-place operations to update the averages at the same time 154 | next_m.mul_(beta1).add_(1 - beta1, grad) 155 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 156 | update = next_m / (next_v.sqrt() + group['e']) 157 | 158 | # Just adding the square of the weights to the loss function is *not* 159 | # the correct way of using L2 regularization/weight decay with Adam, 160 | # since that will interact with the m and v parameters in strange ways. 161 | # 162 | # Instead we want ot decay the weights in a manner that doesn't interact 163 | # with the m/v parameters. This is equivalent to adding the square 164 | # of the weights to the loss with plain (non-momentum) SGD. 165 | if group['weight_decay_rate'] > 0.0: 166 | update += group['weight_decay_rate'] * p.data 167 | 168 | if group['t_total'] != -1: 169 | schedule_fct = SCHEDULES[group['schedule']] 170 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 171 | else: 172 | lr_scheduled = group['lr'] 173 | 174 | update_with_lr = lr_scheduled * update 175 | p.data.add_(-update_with_lr) 176 | 177 | state['step'] += 1 178 | 179 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 180 | # bias_correction1 = 1 - beta1 ** state['step'] 181 | # bias_correction2 = 1 - beta2 ** state['step'] 182 | 183 | return loss 184 | -------------------------------------------------------------------------------- /codes/fine-tuning/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
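Before the tokenization classes, a short usage sketch for the `BERTAdam` optimizer defined in `fine-tuning/optimization.py` above. The parameter grouping (no weight decay on bias and LayerNorm parameters) follows the usual BERT fine-tuning recipe; the tiny linear model, the import path, and the hyperparameter values are illustrative assumptions rather than settings taken from this repository's training scripts.

```python
import torch
from optimization import BERTAdam   # assumes running next to optimization.py above

# Stand-in for the fine-tuned model; a real run would pass the named
# parameters of BertForSequenceClassification instead.
model = torch.nn.Linear(768, 2)

# Skip weight decay for bias and LayerNorm (gamma/beta) parameters.
no_decay = ("bias", "gamma", "beta")
grouped_parameters = [
    {"params": [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)],
     "weight_decay_rate": 0.01},
    {"params": [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)],
     "weight_decay_rate": 0.0},
]

num_train_steps = 1000
optimizer = BERTAdam(grouped_parameters,
                     lr=5e-5,
                     warmup=0.1,               # linear ramp over the first 10% of steps
                     t_total=num_train_steps,  # enables the warmup_linear schedule
                     schedule="warmup_linear")

# Training then uses the usual loop (loss.backward(); optimizer.step();
# optimizer.zero_grad()). Note that step() relies on the older in-place
# add_/addcmul_ signatures above, so it may warn or break on recent PyTorch.
```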
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import six 24 | 25 | 26 | def convert_to_unicode(text): 27 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 28 | if six.PY3: 29 | if isinstance(text, str): 30 | return text 31 | elif isinstance(text, bytes): 32 | return text.decode("utf-8", "ignore") 33 | else: 34 | raise ValueError("Unsupported string type: %s" % (type(text))) 35 | elif six.PY2: 36 | if isinstance(text, str): 37 | return text.decode("utf-8", "ignore") 38 | elif isinstance(text, unicode): 39 | return text 40 | else: 41 | raise ValueError("Unsupported string type: %s" % (type(text))) 42 | else: 43 | raise ValueError("Not running on Python2 or Python 3?") 44 | 45 | 46 | def printable_text(text): 47 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 48 | 49 | # These functions want `str` for both Python2 and Python3, but in one case 50 | # it's a Unicode string and in the other it's a byte string. 51 | if six.PY3: 52 | if isinstance(text, str): 53 | return text 54 | elif isinstance(text, bytes): 55 | return text.decode("utf-8", "ignore") 56 | else: 57 | raise ValueError("Unsupported string type: %s" % (type(text))) 58 | elif six.PY2: 59 | if isinstance(text, str): 60 | return text 61 | elif isinstance(text, unicode): 62 | return text.encode("utf-8") 63 | else: 64 | raise ValueError("Unsupported string type: %s" % (type(text))) 65 | else: 66 | raise ValueError("Not running on Python2 or Python 3?") 67 | 68 | 69 | def load_vocab(vocab_file): 70 | """Loads a vocabulary file into a dictionary.""" 71 | vocab = collections.OrderedDict() 72 | index = 0 73 | with open(vocab_file, "r", encoding="utf-8") as reader: 74 | while True: 75 | token = convert_to_unicode(reader.readline()) 76 | if not token: 77 | break 78 | token = token.strip() 79 | vocab[token] = index 80 | index += 1 81 | return vocab 82 | 83 | 84 | def convert_tokens_to_ids(vocab, tokens): 85 | """Converts a sequence of tokens into ids using the vocab.""" 86 | ids = [] 87 | for token in tokens: 88 | ids.append(vocab[token]) 89 | return ids 90 | 91 | 92 | def whitespace_tokenize(text): 93 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 94 | text = text.strip() 95 | if not text: 96 | return [] 97 | tokens = text.split() 98 | return tokens 99 | 100 | 101 | class FullTokenizer(object): 102 | """Runs end-to-end tokenziation.""" 103 | 104 | def __init__(self, vocab_file, do_lower_case=True): 105 | self.vocab = load_vocab(vocab_file) 106 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 107 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 108 | 109 | def tokenize(self, text): 110 | split_tokens = [] 111 | for token in self.basic_tokenizer.tokenize(text): 112 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 113 | split_tokens.append(sub_token) 114 | 115 | return split_tokens 116 | 117 | def convert_tokens_to_ids(self, tokens): 118 | return convert_tokens_to_ids(self.vocab, tokens) 119 | 120 | 121 | class BasicTokenizer(object): 122 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 123 | 124 | def __init__(self, do_lower_case=True): 125 | """Constructs a BasicTokenizer. 126 | 127 | Args: 128 | do_lower_case: Whether to lower case the input. 
129 | """ 130 | self.do_lower_case = do_lower_case 131 | 132 | def tokenize(self, text): 133 | """Tokenizes a piece of text.""" 134 | text = convert_to_unicode(text) 135 | text = self._clean_text(text) 136 | orig_tokens = whitespace_tokenize(text) 137 | split_tokens = [] 138 | for token in orig_tokens: 139 | if self.do_lower_case: 140 | token = token.lower() 141 | token = self._run_strip_accents(token) 142 | split_tokens.extend(self._run_split_on_punc(token)) 143 | 144 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 145 | return output_tokens 146 | 147 | def _run_strip_accents(self, text): 148 | """Strips accents from a piece of text.""" 149 | text = unicodedata.normalize("NFD", text) 150 | output = [] 151 | for char in text: 152 | cat = unicodedata.category(char) 153 | if cat == "Mn": 154 | continue 155 | output.append(char) 156 | return "".join(output) 157 | 158 | def _run_split_on_punc(self, text): 159 | """Splits punctuation on a piece of text.""" 160 | chars = list(text) 161 | i = 0 162 | start_new_word = True 163 | output = [] 164 | while i < len(chars): 165 | char = chars[i] 166 | if _is_punctuation(char): 167 | output.append([char]) 168 | start_new_word = True 169 | else: 170 | if start_new_word: 171 | output.append([]) 172 | start_new_word = False 173 | output[-1].append(char) 174 | i += 1 175 | 176 | return ["".join(x) for x in output] 177 | 178 | def _clean_text(self, text): 179 | """Performs invalid character removal and whitespace cleanup on text.""" 180 | output = [] 181 | for char in text: 182 | cp = ord(char) 183 | if cp == 0 or cp == 0xfffd or _is_control(char): 184 | continue 185 | if _is_whitespace(char): 186 | output.append(" ") 187 | else: 188 | output.append(char) 189 | return "".join(output) 190 | 191 | 192 | class WordpieceTokenizer(object): 193 | """Runs WordPiece tokenization.""" 194 | 195 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 196 | self.vocab = vocab 197 | self.unk_token = unk_token 198 | self.max_input_chars_per_word = max_input_chars_per_word 199 | 200 | def tokenize(self, text): 201 | """Tokenizes a piece of text into its word pieces. 202 | 203 | This uses a greedy longest-match-first algorithm to perform tokenization 204 | using the given vocabulary. 205 | 206 | For example: 207 | input = "unaffable" 208 | output = ["un", "##aff", "##able"] 209 | 210 | Args: 211 | text: A single token or whitespace separated tokens. This should have 212 | already been passed through `BasicTokenizer. 213 | 214 | Returns: 215 | A list of wordpiece tokens. 
216 | """ 217 | 218 | text = convert_to_unicode(text) 219 | 220 | output_tokens = [] 221 | for token in whitespace_tokenize(text): 222 | chars = list(token) 223 | if len(chars) > self.max_input_chars_per_word: 224 | output_tokens.append(self.unk_token) 225 | continue 226 | 227 | is_bad = False 228 | start = 0 229 | sub_tokens = [] 230 | while start < len(chars): 231 | end = len(chars) 232 | cur_substr = None 233 | while start < end: 234 | substr = "".join(chars[start:end]) 235 | if start > 0: 236 | substr = "##" + substr 237 | if substr in self.vocab: 238 | cur_substr = substr 239 | break 240 | end -= 1 241 | if cur_substr is None: 242 | is_bad = True 243 | break 244 | sub_tokens.append(cur_substr) 245 | start = end 246 | 247 | if is_bad: 248 | output_tokens.append(self.unk_token) 249 | else: 250 | output_tokens.extend(sub_tokens) 251 | return output_tokens 252 | 253 | 254 | def _is_whitespace(char): 255 | """Checks whether `chars` is a whitespace character.""" 256 | # \t, \n, and \r are technically contorl characters but we treat them 257 | # as whitespace since they are generally considered as such. 258 | if char == " " or char == "\t" or char == "\n" or char == "\r": 259 | return True 260 | cat = unicodedata.category(char) 261 | if cat == "Zs": 262 | return True 263 | return False 264 | 265 | 266 | def _is_control(char): 267 | """Checks whether `chars` is a control character.""" 268 | # These are technically control characters but we count them as whitespace 269 | # characters. 270 | if char == "\t" or char == "\n" or char == "\r": 271 | return False 272 | cat = unicodedata.category(char) 273 | if cat.startswith("C"): 274 | return True 275 | return False 276 | 277 | 278 | def _is_punctuation(char): 279 | """Checks whether `chars` is a punctuation character.""" 280 | cp = ord(char) 281 | # We treat all non-letter/number ASCII as punctuation. 282 | # Characters such as "^", "$", and "`" are not in the Unicode 283 | # Punctuation class but we treat them as punctuation anyways, for 284 | # consistency. 285 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 286 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 287 | return True 288 | cat = unicodedata.category(char) 289 | if cat.startswith("P"): 290 | return True 291 | return False 292 | -------------------------------------------------------------------------------- /codes/further-pre-training/create_pretraining_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
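Recapping `tokenization.py` above: `FullTokenizer` chains `BasicTokenizer` (cleanup, lower casing, punctuation splitting) with `WordpieceTokenizer` (greedy longest-match-first against the vocabulary). A toy run with a hand-built vocabulary, assuming the script sits next to `tokenization.py`; the vocabulary and sentence are made up for illustration, and a real run would load the released `vocab.txt` via `load_vocab` instead:

```python
import collections

from tokenization import BasicTokenizer, WordpieceTokenizer

# Tiny illustrative vocabulary; real BERT vocabularies have ~30k entries.
vocab = collections.OrderedDict(
    (tok, i) for i, tok in enumerate(
        ["[UNK]", "[CLS]", "[SEP]", "un", "##aff", "##able", "runs", "it", "!"]))

basic = BasicTokenizer(do_lower_case=True)
wordpiece = WordpieceTokenizer(vocab=vocab)

tokens = []
for token in basic.tokenize("It runs unaffable!"):    # cleanup + punctuation split
    tokens.extend(wordpiece.tokenize(token))          # greedy longest-match-first

print(tokens)   # ['it', 'runs', 'un', '##aff', '##able', '!']
```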
15 | """Create masked LM/next sentence masked_lm TF examples for BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import random 23 | import tensorflow as tf 24 | import tokenization 25 | 26 | flags = tf.flags 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | flags.DEFINE_string("input_file", None, 31 | "Input raw text file (or comma-separated list of files).") 32 | 33 | flags.DEFINE_string( 34 | "output_file", None, 35 | "Output TF example file (or comma-separated list of files).") 36 | 37 | flags.DEFINE_string("vocab_file", None, 38 | "The vocabulary file that the BERT model was trained on.") 39 | 40 | flags.DEFINE_bool( 41 | "do_lower_case", True, 42 | "Whether to lower case the input text. Should be True for uncased " 43 | "models and False for cased models.") 44 | 45 | flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.") 46 | 47 | flags.DEFINE_integer("max_predictions_per_seq", 20, 48 | "Maximum number of masked LM predictions per sequence.") 49 | 50 | flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.") 51 | 52 | flags.DEFINE_integer( 53 | "dupe_factor", 10, 54 | "Number of times to duplicate the input data (with different masks).") 55 | 56 | flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.") 57 | 58 | flags.DEFINE_float( 59 | "short_seq_prob", 0.1, 60 | "Probability of creating sequences which are shorter than the " 61 | "maximum length.") 62 | 63 | 64 | class TrainingInstance(object): 65 | """A single training instance (sentence pair).""" 66 | 67 | def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels, 68 | is_random_next): 69 | self.tokens = tokens 70 | self.segment_ids = segment_ids 71 | self.is_random_next = is_random_next 72 | self.masked_lm_positions = masked_lm_positions 73 | self.masked_lm_labels = masked_lm_labels 74 | 75 | def __str__(self): 76 | s = "" 77 | s += "tokens: %s\n" % (" ".join( 78 | [tokenization.printable_text(x) for x in self.tokens])) 79 | s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids])) 80 | s += "is_random_next: %s\n" % self.is_random_next 81 | s += "masked_lm_positions: %s\n" % (" ".join( 82 | [str(x) for x in self.masked_lm_positions])) 83 | s += "masked_lm_labels: %s\n" % (" ".join( 84 | [tokenization.printable_text(x) for x in self.masked_lm_labels])) 85 | s += "\n" 86 | return s 87 | 88 | def __repr__(self): 89 | return self.__str__() 90 | 91 | 92 | def write_instance_to_example_files(instances, tokenizer, max_seq_length, 93 | max_predictions_per_seq, output_files): 94 | """Create TF example files from `TrainingInstance`s.""" 95 | writers = [] 96 | for output_file in output_files: 97 | writers.append(tf.python_io.TFRecordWriter(output_file)) 98 | 99 | writer_index = 0 100 | 101 | total_written = 0 102 | for (inst_index, instance) in enumerate(instances): 103 | input_ids = tokenizer.convert_tokens_to_ids(instance.tokens) 104 | input_mask = [1] * len(input_ids) 105 | segment_ids = list(instance.segment_ids) 106 | assert len(input_ids) <= max_seq_length 107 | 108 | while len(input_ids) < max_seq_length: 109 | input_ids.append(0) 110 | input_mask.append(0) 111 | segment_ids.append(0) 112 | 113 | assert len(input_ids) == max_seq_length 114 | assert len(input_mask) == max_seq_length 115 | assert len(segment_ids) == max_seq_length 116 | 117 | masked_lm_positions = list(instance.masked_lm_positions) 118 | masked_lm_ids = 
tokenizer.convert_tokens_to_ids(instance.masked_lm_labels) 119 | masked_lm_weights = [1.0] * len(masked_lm_ids) 120 | 121 | while len(masked_lm_positions) < max_predictions_per_seq: 122 | masked_lm_positions.append(0) 123 | masked_lm_ids.append(0) 124 | masked_lm_weights.append(0.0) 125 | 126 | next_sentence_label = 1 if instance.is_random_next else 0 127 | 128 | features = collections.OrderedDict() 129 | features["input_ids"] = create_int_feature(input_ids) 130 | features["input_mask"] = create_int_feature(input_mask) 131 | features["segment_ids"] = create_int_feature(segment_ids) 132 | features["masked_lm_positions"] = create_int_feature(masked_lm_positions) 133 | features["masked_lm_ids"] = create_int_feature(masked_lm_ids) 134 | features["masked_lm_weights"] = create_float_feature(masked_lm_weights) 135 | features["next_sentence_labels"] = create_int_feature([next_sentence_label]) 136 | 137 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 138 | 139 | writers[writer_index].write(tf_example.SerializeToString()) 140 | writer_index = (writer_index + 1) % len(writers) 141 | 142 | total_written += 1 143 | 144 | if inst_index < 20: 145 | tf.logging.info("*** Example ***") 146 | tf.logging.info("tokens: %s" % " ".join( 147 | [tokenization.printable_text(x) for x in instance.tokens])) 148 | 149 | for feature_name in features.keys(): 150 | feature = features[feature_name] 151 | values = [] 152 | if feature.int64_list.value: 153 | values = feature.int64_list.value 154 | elif feature.float_list.value: 155 | values = feature.float_list.value 156 | tf.logging.info( 157 | "%s: %s" % (feature_name, " ".join([str(x) for x in values]))) 158 | 159 | for writer in writers: 160 | writer.close() 161 | 162 | tf.logging.info("Wrote %d total instances", total_written) 163 | 164 | 165 | def create_int_feature(values): 166 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 167 | return feature 168 | 169 | 170 | def create_float_feature(values): 171 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) 172 | return feature 173 | 174 | 175 | def create_training_instances(input_files, tokenizer, max_seq_length, 176 | dupe_factor, short_seq_prob, masked_lm_prob, 177 | max_predictions_per_seq, rng): 178 | """Create `TrainingInstance`s from raw text.""" 179 | all_documents = [[]] 180 | 181 | # Input file format: 182 | # (1) One sentence per line. These should ideally be actual sentences, not 183 | # entire paragraphs or arbitrary spans of text. (Because we use the 184 | # sentence boundaries for the "next sentence prediction" task). 185 | # (2) Blank lines between documents. Document boundaries are needed so 186 | # that the "next sentence prediction" task doesn't span between documents. 
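The comment above fixes the expected corpus layout: one sentence per line, with blank lines separating documents (the same layout `generate_corpus_agnews.py` further below produces). A hedged sketch of such a file, with an illustrative invocation in the trailing comment; the sentences, file names, and vocabulary path are made up, while the flag names come from the definitions earlier in this file:

```python
# One sentence per line, documents separated by a blank line.
toy_corpus = (
    "Wall St. Bears Claw Back Into the Black.\n"
    "Short-sellers are seeing green again.\n"
    "\n"
    "Oil prices soar to an all-time record.\n"
    "Markets brace for a turbulent week.\n"
)

with open("toy_corpus.txt", "w", encoding="utf-8") as f:
    f.write(toy_corpus)

# Illustrative invocation of this script on that file:
#   python create_pretraining_data.py \
#     --input_file=toy_corpus.txt --output_file=toy_corpus.tfrecord \
#     --vocab_file=uncased_L-12_H-768_A-12/vocab.txt \
#     --max_seq_length=128 --max_predictions_per_seq=20 \
#     --masked_lm_prob=0.15 --dupe_factor=10
```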
187 | for input_file in input_files: 188 | with tf.gfile.GFile(input_file, "r") as reader: 189 | while True: 190 | line = tokenization.convert_to_unicode(reader.readline()) 191 | if not line: 192 | break 193 | line = line.strip() 194 | 195 | # Empty lines are used as document delimiters 196 | if not line: 197 | all_documents.append([]) 198 | tokens = tokenizer.tokenize(line) 199 | if tokens: 200 | all_documents[-1].append(tokens) 201 | 202 | # Remove empty documents 203 | all_documents = [x for x in all_documents if x] 204 | rng.shuffle(all_documents) 205 | 206 | vocab_words = list(tokenizer.vocab.keys()) 207 | instances = [] 208 | for _ in range(dupe_factor): 209 | for document_index in range(len(all_documents)): 210 | instances.extend( 211 | create_instances_from_document( 212 | all_documents, document_index, max_seq_length, short_seq_prob, 213 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng)) 214 | 215 | rng.shuffle(instances) 216 | return instances 217 | 218 | 219 | def create_instances_from_document( 220 | all_documents, document_index, max_seq_length, short_seq_prob, 221 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng): 222 | """Creates `TrainingInstance`s for a single document.""" 223 | document = all_documents[document_index] 224 | 225 | # Account for [CLS], [SEP], [SEP] 226 | max_num_tokens = max_seq_length - 3 227 | 228 | # We *usually* want to fill up the entire sequence since we are padding 229 | # to `max_seq_length` anyways, so short sequences are generally wasted 230 | # computation. However, we *sometimes* 231 | # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter 232 | # sequences to minimize the mismatch between pre-training and fine-tuning. 233 | # The `target_seq_length` is just a rough target however, whereas 234 | # `max_seq_length` is a hard limit. 235 | target_seq_length = max_num_tokens 236 | if rng.random() < short_seq_prob: 237 | target_seq_length = rng.randint(2, max_num_tokens) 238 | 239 | # We DON'T just concatenate all of the tokens from a document into a long 240 | # sequence and choose an arbitrary split point because this would make the 241 | # next sentence prediction task too easy. Instead, we split the input into 242 | # segments "A" and "B" based on the actual "sentences" provided by the user 243 | # input. 244 | instances = [] 245 | current_chunk = [] 246 | current_length = 0 247 | i = 0 248 | while i < len(document): 249 | segment = document[i] 250 | current_chunk.append(segment) 251 | current_length += len(segment) 252 | if i == len(document) - 1 or current_length >= target_seq_length: 253 | if current_chunk: 254 | # `a_end` is how many segments from `current_chunk` go into the `A` 255 | # (first) sentence. 256 | a_end = 1 257 | if len(current_chunk) >= 2: 258 | a_end = rng.randint(1, len(current_chunk) - 1) 259 | 260 | tokens_a = [] 261 | for j in range(a_end): 262 | tokens_a.extend(current_chunk[j]) 263 | 264 | tokens_b = [] 265 | # Random next 266 | is_random_next = False 267 | if len(current_chunk) == 1 or rng.random() < 0.5: 268 | is_random_next = True 269 | target_b_length = target_seq_length - len(tokens_a) 270 | 271 | # This should rarely go for more than one iteration for large 272 | # corpora. However, just to be careful, we try to make sure that 273 | # the random document is not the same as the document 274 | # we're processing. 
275 | for _ in range(10): 276 | random_document_index = rng.randint(0, len(all_documents) - 1) 277 | if random_document_index != document_index: 278 | break 279 | 280 | random_document = all_documents[random_document_index] 281 | random_start = rng.randint(0, len(random_document) - 1) 282 | for j in range(random_start, len(random_document)): 283 | tokens_b.extend(random_document[j]) 284 | if len(tokens_b) >= target_b_length: 285 | break 286 | # We didn't actually use these segments so we "put them back" so 287 | # they don't go to waste. 288 | num_unused_segments = len(current_chunk) - a_end 289 | i -= num_unused_segments 290 | # Actual next 291 | else: 292 | is_random_next = False 293 | for j in range(a_end, len(current_chunk)): 294 | tokens_b.extend(current_chunk[j]) 295 | truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng) 296 | 297 | assert len(tokens_a) >= 1 298 | assert len(tokens_b) >= 1 299 | 300 | tokens = [] 301 | segment_ids = [] 302 | tokens.append("[CLS]") 303 | segment_ids.append(0) 304 | for token in tokens_a: 305 | tokens.append(token) 306 | segment_ids.append(0) 307 | 308 | tokens.append("[SEP]") 309 | segment_ids.append(0) 310 | 311 | for token in tokens_b: 312 | tokens.append(token) 313 | segment_ids.append(1) 314 | tokens.append("[SEP]") 315 | segment_ids.append(1) 316 | 317 | (tokens, masked_lm_positions, 318 | masked_lm_labels) = create_masked_lm_predictions( 319 | tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng) 320 | instance = TrainingInstance( 321 | tokens=tokens, 322 | segment_ids=segment_ids, 323 | is_random_next=is_random_next, 324 | masked_lm_positions=masked_lm_positions, 325 | masked_lm_labels=masked_lm_labels) 326 | instances.append(instance) 327 | current_chunk = [] 328 | current_length = 0 329 | i += 1 330 | 331 | return instances 332 | 333 | 334 | MaskedLmInstance = collections.namedtuple("MaskedLmInstance", 335 | ["index", "label"]) 336 | 337 | 338 | def create_masked_lm_predictions(tokens, masked_lm_prob, 339 | max_predictions_per_seq, vocab_words, rng): 340 | """Creates the predictions for the masked LM objective.""" 341 | 342 | cand_indexes = [] 343 | for (i, token) in enumerate(tokens): 344 | if token == "[CLS]" or token == "[SEP]": 345 | continue 346 | cand_indexes.append(i) 347 | 348 | rng.shuffle(cand_indexes) 349 | 350 | output_tokens = list(tokens) 351 | 352 | num_to_predict = min(max_predictions_per_seq, 353 | max(1, int(round(len(tokens) * masked_lm_prob)))) 354 | 355 | masked_lms = [] 356 | covered_indexes = set() 357 | for index in cand_indexes: 358 | if len(masked_lms) >= num_to_predict: 359 | break 360 | if index in covered_indexes: 361 | continue 362 | covered_indexes.add(index) 363 | 364 | masked_token = None 365 | # 80% of the time, replace with [MASK] 366 | if rng.random() < 0.8: 367 | masked_token = "[MASK]" 368 | else: 369 | # 10% of the time, keep original 370 | if rng.random() < 0.5: 371 | masked_token = tokens[index] 372 | # 10% of the time, replace with random word 373 | else: 374 | masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)] 375 | 376 | output_tokens[index] = masked_token 377 | 378 | masked_lms.append(MaskedLmInstance(index=index, label=tokens[index])) 379 | 380 | masked_lms = sorted(masked_lms, key=lambda x: x.index) 381 | 382 | masked_lm_positions = [] 383 | masked_lm_labels = [] 384 | for p in masked_lms: 385 | masked_lm_positions.append(p.index) 386 | masked_lm_labels.append(p.label) 387 | 388 | return (output_tokens, masked_lm_positions, masked_lm_labels) 389 | 390 | 391 | def 
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng): 392 | """Truncates a pair of sequences to a maximum sequence length.""" 393 | while True: 394 | total_length = len(tokens_a) + len(tokens_b) 395 | if total_length <= max_num_tokens: 396 | break 397 | 398 | trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b 399 | assert len(trunc_tokens) >= 1 400 | 401 | # We want to sometimes truncate from the front and sometimes from the 402 | # back to add more randomness and avoid biases. 403 | if rng.random() < 0.5: 404 | del trunc_tokens[0] 405 | else: 406 | trunc_tokens.pop() 407 | 408 | 409 | def main(_): 410 | tf.logging.set_verbosity(tf.logging.INFO) 411 | 412 | tokenizer = tokenization.FullTokenizer( 413 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 414 | 415 | input_files = [] 416 | for input_pattern in FLAGS.input_file.split(","): 417 | input_files.extend(tf.gfile.Glob(input_pattern)) 418 | 419 | tf.logging.info("*** Reading from input files ***") 420 | for input_file in input_files: 421 | tf.logging.info(" %s", input_file) 422 | 423 | rng = random.Random(FLAGS.random_seed) 424 | instances = create_training_instances( 425 | input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor, 426 | FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq, 427 | rng) 428 | 429 | output_files = FLAGS.output_file.split(",") 430 | tf.logging.info("*** Writing to output files ***") 431 | for output_file in output_files: 432 | tf.logging.info(" %s", output_file) 433 | 434 | write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length, 435 | FLAGS.max_predictions_per_seq, output_files) 436 | 437 | 438 | if __name__ == "__main__": 439 | flags.mark_flag_as_required("input_file") 440 | flags.mark_flag_as_required("output_file") 441 | flags.mark_flag_as_required("vocab_file") 442 | tf.app.run() 443 | -------------------------------------------------------------------------------- /codes/further-pre-training/extract_features.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Extract pre-computed feature vectors from BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import codecs 22 | import collections 23 | import json 24 | import re 25 | 26 | import modeling 27 | import tokenization 28 | import tensorflow as tf 29 | 30 | flags = tf.flags 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | flags.DEFINE_string("input_file", None, "") 35 | 36 | flags.DEFINE_string("output_file", None, "") 37 | 38 | flags.DEFINE_string("layers", "-1,-2,-3,-4", "") 39 | 40 | flags.DEFINE_string( 41 | "bert_config_file", None, 42 | "The config json file corresponding to the pre-trained BERT model. 
" 43 | "This specifies the model architecture.") 44 | 45 | flags.DEFINE_integer( 46 | "max_seq_length", 128, 47 | "The maximum total input sequence length after WordPiece tokenization. " 48 | "Sequences longer than this will be truncated, and sequences shorter " 49 | "than this will be padded.") 50 | 51 | flags.DEFINE_string( 52 | "init_checkpoint", None, 53 | "Initial checkpoint (usually from a pre-trained BERT model).") 54 | 55 | flags.DEFINE_string("vocab_file", None, 56 | "The vocabulary file that the BERT model was trained on.") 57 | 58 | flags.DEFINE_bool( 59 | "do_lower_case", True, 60 | "Whether to lower case the input text. Should be True for uncased " 61 | "models and False for cased models.") 62 | 63 | flags.DEFINE_integer("batch_size", 32, "Batch size for predictions.") 64 | 65 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 66 | 67 | flags.DEFINE_string("master", None, 68 | "If using a TPU, the address of the master.") 69 | 70 | flags.DEFINE_integer( 71 | "num_tpu_cores", 8, 72 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 73 | 74 | flags.DEFINE_bool( 75 | "use_one_hot_embeddings", False, 76 | "If True, tf.one_hot will be used for embedding lookups, otherwise " 77 | "tf.nn.embedding_lookup will be used. On TPUs, this should be True " 78 | "since it is much faster.") 79 | 80 | 81 | class InputExample(object): 82 | 83 | def __init__(self, unique_id, text_a, text_b): 84 | self.unique_id = unique_id 85 | self.text_a = text_a 86 | self.text_b = text_b 87 | 88 | 89 | class InputFeatures(object): 90 | """A single set of features of data.""" 91 | 92 | def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids): 93 | self.unique_id = unique_id 94 | self.tokens = tokens 95 | self.input_ids = input_ids 96 | self.input_mask = input_mask 97 | self.input_type_ids = input_type_ids 98 | 99 | 100 | def input_fn_builder(features, seq_length): 101 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 102 | 103 | all_unique_ids = [] 104 | all_input_ids = [] 105 | all_input_mask = [] 106 | all_input_type_ids = [] 107 | 108 | for feature in features: 109 | all_unique_ids.append(feature.unique_id) 110 | all_input_ids.append(feature.input_ids) 111 | all_input_mask.append(feature.input_mask) 112 | all_input_type_ids.append(feature.input_type_ids) 113 | 114 | def input_fn(params): 115 | """The actual input function.""" 116 | batch_size = params["batch_size"] 117 | 118 | num_examples = len(features) 119 | 120 | # This is for demo purposes and does NOT scale to large data sets. We do 121 | # not use Dataset.from_generator() because that uses tf.py_func which is 122 | # not TPU compatible. The right way to load data is with TFRecordReader. 
123 | d = tf.data.Dataset.from_tensor_slices({ 124 | "unique_ids": 125 | tf.constant(all_unique_ids, shape=[num_examples], dtype=tf.int32), 126 | "input_ids": 127 | tf.constant( 128 | all_input_ids, shape=[num_examples, seq_length], 129 | dtype=tf.int32), 130 | "input_mask": 131 | tf.constant( 132 | all_input_mask, 133 | shape=[num_examples, seq_length], 134 | dtype=tf.int32), 135 | "input_type_ids": 136 | tf.constant( 137 | all_input_type_ids, 138 | shape=[num_examples, seq_length], 139 | dtype=tf.int32), 140 | }) 141 | 142 | d = d.batch(batch_size=batch_size, drop_remainder=False) 143 | return d 144 | 145 | return input_fn 146 | 147 | 148 | def model_fn_builder(bert_config, init_checkpoint, layer_indexes, use_tpu, 149 | use_one_hot_embeddings): 150 | """Returns `model_fn` closure for TPUEstimator.""" 151 | 152 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 153 | """The `model_fn` for TPUEstimator.""" 154 | 155 | unique_ids = features["unique_ids"] 156 | input_ids = features["input_ids"] 157 | input_mask = features["input_mask"] 158 | input_type_ids = features["input_type_ids"] 159 | 160 | model = modeling.BertModel( 161 | config=bert_config, 162 | is_training=False, 163 | input_ids=input_ids, 164 | input_mask=input_mask, 165 | token_type_ids=input_type_ids, 166 | use_one_hot_embeddings=use_one_hot_embeddings) 167 | 168 | if mode != tf.estimator.ModeKeys.PREDICT: 169 | raise ValueError("Only PREDICT modes are supported: %s" % (mode)) 170 | 171 | tvars = tf.trainable_variables() 172 | scaffold_fn = None 173 | (assignment_map, 174 | initialized_variable_names) = modeling.get_assignment_map_from_checkpoint( 175 | tvars, init_checkpoint) 176 | if use_tpu: 177 | 178 | def tpu_scaffold(): 179 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 180 | return tf.train.Scaffold() 181 | 182 | scaffold_fn = tpu_scaffold 183 | else: 184 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 185 | 186 | tf.logging.info("**** Trainable Variables ****") 187 | for var in tvars: 188 | init_string = "" 189 | if var.name in initialized_variable_names: 190 | init_string = ", *INIT_FROM_CKPT*" 191 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 192 | init_string) 193 | 194 | all_layers = model.get_all_encoder_layers() 195 | 196 | predictions = { 197 | "unique_id": unique_ids, 198 | } 199 | 200 | for (i, layer_index) in enumerate(layer_indexes): 201 | predictions["layer_output_%d" % i] = all_layers[layer_index] 202 | 203 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 204 | mode=mode, predictions=predictions, scaffold_fn=scaffold_fn) 205 | return output_spec 206 | 207 | return model_fn 208 | 209 | 210 | def convert_examples_to_features(examples, seq_length, tokenizer): 211 | """Loads a data file into a list of `InputBatch`s.""" 212 | 213 | features = [] 214 | for (ex_index, example) in enumerate(examples): 215 | tokens_a = tokenizer.tokenize(example.text_a) 216 | 217 | tokens_b = None 218 | if example.text_b: 219 | tokens_b = tokenizer.tokenize(example.text_b) 220 | 221 | if tokens_b: 222 | # Modifies `tokens_a` and `tokens_b` in place so that the total 223 | # length is less than the specified length. 
224 | # Account for [CLS], [SEP], [SEP] with "- 3" 225 | _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3) 226 | else: 227 | # Account for [CLS] and [SEP] with "- 2" 228 | if len(tokens_a) > seq_length - 2: 229 | tokens_a = tokens_a[0:(seq_length - 2)] 230 | 231 | # The convention in BERT is: 232 | # (a) For sequence pairs: 233 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 234 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 235 | # (b) For single sequences: 236 | # tokens: [CLS] the dog is hairy . [SEP] 237 | # type_ids: 0 0 0 0 0 0 0 238 | # 239 | # Where "type_ids" are used to indicate whether this is the first 240 | # sequence or the second sequence. The embedding vectors for `type=0` and 241 | # `type=1` were learned during pre-training and are added to the wordpiece 242 | # embedding vector (and position vector). This is not *strictly* necessary 243 | # since the [SEP] token unambiguously separates the sequences, but it makes 244 | # it easier for the model to learn the concept of sequences. 245 | # 246 | # For classification tasks, the first vector (corresponding to [CLS]) is 247 | # used as as the "sentence vector". Note that this only makes sense because 248 | # the entire model is fine-tuned. 249 | tokens = [] 250 | input_type_ids = [] 251 | tokens.append("[CLS]") 252 | input_type_ids.append(0) 253 | for token in tokens_a: 254 | tokens.append(token) 255 | input_type_ids.append(0) 256 | tokens.append("[SEP]") 257 | input_type_ids.append(0) 258 | 259 | if tokens_b: 260 | for token in tokens_b: 261 | tokens.append(token) 262 | input_type_ids.append(1) 263 | tokens.append("[SEP]") 264 | input_type_ids.append(1) 265 | 266 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 267 | 268 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 269 | # tokens are attended to. 270 | input_mask = [1] * len(input_ids) 271 | 272 | # Zero-pad up to the sequence length. 273 | while len(input_ids) < seq_length: 274 | input_ids.append(0) 275 | input_mask.append(0) 276 | input_type_ids.append(0) 277 | 278 | assert len(input_ids) == seq_length 279 | assert len(input_mask) == seq_length 280 | assert len(input_type_ids) == seq_length 281 | 282 | if ex_index < 5: 283 | tf.logging.info("*** Example ***") 284 | tf.logging.info("unique_id: %s" % (example.unique_id)) 285 | tf.logging.info("tokens: %s" % " ".join( 286 | [tokenization.printable_text(x) for x in tokens])) 287 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 288 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 289 | tf.logging.info( 290 | "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids])) 291 | 292 | features.append( 293 | InputFeatures( 294 | unique_id=example.unique_id, 295 | tokens=tokens, 296 | input_ids=input_ids, 297 | input_mask=input_mask, 298 | input_type_ids=input_type_ids)) 299 | return features 300 | 301 | 302 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 303 | """Truncates a sequence pair in place to the maximum length.""" 304 | 305 | # This is a simple heuristic which will always truncate the longer sequence 306 | # one token at a time. This makes more sense than truncating an equal percent 307 | # of tokens from each, since if one sequence is very short then each token 308 | # that's truncated likely contains more information than a longer sequence. 
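A small worked trace of the truncation heuristic implemented just below (the lengths are made up): the longer of the two token lists is trimmed one token at a time until the pair fits.

```python
tokens_a = list("abcdef")   # length 6
tokens_b = list("vwxyz")    # length 5
max_length = 8

while len(tokens_a) + len(tokens_b) > max_length:
    # Always trim the currently longer sequence, one token at a time:
    # (6,5) -> (5,5) -> (5,4) -> (4,4), which fits within 8.
    (tokens_a if len(tokens_a) > len(tokens_b) else tokens_b).pop()

print(len(tokens_a), len(tokens_b))   # 4 4
```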
309 | while True: 310 | total_length = len(tokens_a) + len(tokens_b) 311 | if total_length <= max_length: 312 | break 313 | if len(tokens_a) > len(tokens_b): 314 | tokens_a.pop() 315 | else: 316 | tokens_b.pop() 317 | 318 | 319 | def read_examples(input_file): 320 | """Read a list of `InputExample`s from an input file.""" 321 | examples = [] 322 | unique_id = 0 323 | with tf.gfile.GFile(input_file, "r") as reader: 324 | while True: 325 | line = tokenization.convert_to_unicode(reader.readline()) 326 | if not line: 327 | break 328 | line = line.strip() 329 | text_a = None 330 | text_b = None 331 | m = re.match(r"^(.*) \|\|\| (.*)$", line) 332 | if m is None: 333 | text_a = line 334 | else: 335 | text_a = m.group(1) 336 | text_b = m.group(2) 337 | examples.append( 338 | InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b)) 339 | unique_id += 1 340 | return examples 341 | 342 | 343 | def main(_): 344 | tf.logging.set_verbosity(tf.logging.INFO) 345 | 346 | layer_indexes = [int(x) for x in FLAGS.layers.split(",")] 347 | 348 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 349 | 350 | tokenizer = tokenization.FullTokenizer( 351 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 352 | 353 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 354 | run_config = tf.contrib.tpu.RunConfig( 355 | master=FLAGS.master, 356 | tpu_config=tf.contrib.tpu.TPUConfig( 357 | num_shards=FLAGS.num_tpu_cores, 358 | per_host_input_for_training=is_per_host)) 359 | 360 | examples = read_examples(FLAGS.input_file) 361 | 362 | features = convert_examples_to_features( 363 | examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer) 364 | 365 | unique_id_to_feature = {} 366 | for feature in features: 367 | unique_id_to_feature[feature.unique_id] = feature 368 | 369 | model_fn = model_fn_builder( 370 | bert_config=bert_config, 371 | init_checkpoint=FLAGS.init_checkpoint, 372 | layer_indexes=layer_indexes, 373 | use_tpu=FLAGS.use_tpu, 374 | use_one_hot_embeddings=FLAGS.use_one_hot_embeddings) 375 | 376 | # If TPU is not available, this will fall back to normal Estimator on CPU 377 | # or GPU. 
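  # Typical invocation of this script (paths are illustrative; the flags are the
  # ones defined at the top of this file, and input lines may optionally contain
  # "text_a ||| text_b" pairs as parsed in read_examples above):
  #   python extract_features.py \
  #     --input_file=sentences.txt \
  #     --output_file=features.jsonl \
  #     --vocab_file=uncased_L-12_H-768_A-12/vocab.txt \
  #     --bert_config_file=uncased_L-12_H-768_A-12/bert_config.json \
  #     --init_checkpoint=uncased_L-12_H-768_A-12/bert_model.ckpt \
  #     --layers=-1,-2,-3,-4 --max_seq_length=128 --batch_size=32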
378 | estimator = tf.contrib.tpu.TPUEstimator( 379 | use_tpu=FLAGS.use_tpu, 380 | model_fn=model_fn, 381 | config=run_config, 382 | predict_batch_size=FLAGS.batch_size) 383 | 384 | input_fn = input_fn_builder( 385 | features=features, seq_length=FLAGS.max_seq_length) 386 | 387 | with codecs.getwriter("utf-8")(tf.gfile.Open(FLAGS.output_file, 388 | "w")) as writer: 389 | for result in estimator.predict(input_fn, yield_single_examples=True): 390 | unique_id = int(result["unique_id"]) 391 | feature = unique_id_to_feature[unique_id] 392 | output_json = collections.OrderedDict() 393 | output_json["linex_index"] = unique_id 394 | all_features = [] 395 | for (i, token) in enumerate(feature.tokens): 396 | all_layers = [] 397 | for (j, layer_index) in enumerate(layer_indexes): 398 | layer_output = result["layer_output_%d" % j] 399 | layers = collections.OrderedDict() 400 | layers["index"] = layer_index 401 | layers["values"] = [ 402 | round(float(x), 6) for x in layer_output[i:(i + 1)].flat 403 | ] 404 | all_layers.append(layers) 405 | features = collections.OrderedDict() 406 | features["token"] = token 407 | features["layers"] = all_layers 408 | all_features.append(features) 409 | output_json["features"] = all_features 410 | writer.write(json.dumps(output_json) + "\n") 411 | 412 | 413 | if __name__ == "__main__": 414 | flags.mark_flag_as_required("input_file") 415 | flags.mark_flag_as_required("vocab_file") 416 | flags.mark_flag_as_required("bert_config_file") 417 | flags.mark_flag_as_required("init_checkpoint") 418 | flags.mark_flag_as_required("output_file") 419 | tf.app.run() 420 | -------------------------------------------------------------------------------- /codes/further-pre-training/generate_corpus_agnews.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import spacy 3 | import numpy as np 4 | 5 | nlp=spacy.load("en_core_web_md") 6 | 7 | test_data=pd.read_csv("test.csv",header=None,sep=",").values 8 | train_data=pd.read_csv("train.csv",header=None,sep=",").values 9 | 10 | test=[] 11 | train=[] 12 | with open("AGnews_corpus_test.txt","w",encoding="utf-8") as f_test: 13 | with open("AGnews_corpus_train.txt", "w", encoding="utf-8") as f_train: 14 | with open("AGnews_corpus.txt","w",encoding="utf-8") as f: 15 | for i in range(len(test_data)): 16 | if i%1000==0:print(i) 17 | f.write(str(test_data[i][1])+"\n") 18 | f_test.write(str(test_data[i][1])+"\n") 19 | document=nlp(str(test_data[i][2])) 20 | number=0 21 | for sent in document.sents: 22 | number+=1 23 | f.write(str(sent)+"\n") 24 | f_test.write(str(sent)+"\n") 25 | test.append(number) 26 | f.write("\n") 27 | f_test.write("\n") 28 | 29 | for i in range(len(train_data)): 30 | if i%1000==0:print(i) 31 | f.write(str(train_data[i][1])+"\n") 32 | f_train.write(str(train_data[i][1])+"\n") 33 | document=nlp(str(train_data[i][2])) 34 | number=0 35 | for sent in document.sents: 36 | number+=1 37 | f.write(str(sent)+"\n") 38 | f_train.write(str(sent) + "\n") 39 | train.append(number) 40 | f.write("\n") 41 | f_train.write("\n") 42 | 43 | 44 | print("test_max=",np.max(test)) 45 | print("test_avg=",np.average(test)) 46 | print() 47 | print("train_max=",np.max(train)) 48 | print("train_avg=",np.average(train)) 49 | 50 | 51 | -------------------------------------------------------------------------------- /codes/further-pre-training/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language 
Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | # Normally the global step update is done inside of `apply_gradients`. 80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use 81 | # a different optimizer, you should probably take this line out. 
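`create_optimizer` above combines a linear warmup with a linear (power 1.0 polynomial) decay to zero. A plain-Python numeric sketch of the resulting schedule, using the default values from `run_pretraining.py`'s flags below (learning_rate 5e-5, num_train_steps 100000, num_warmup_steps 10000); the helper name is illustrative:

```python
def lr_at(step, init_lr=5e-5, num_train_steps=100000, num_warmup_steps=10000):
    # Linear polynomial decay to 0 over num_train_steps ...
    decayed = init_lr * (1.0 - min(step, num_train_steps) / num_train_steps)
    # ... overridden by a linear warmup while global_step < num_warmup_steps.
    if step < num_warmup_steps:
        return init_lr * step / num_warmup_steps
    return decayed

for step in (0, 5000, 10000, 55000, 100000):
    print(step, round(lr_at(step), 8))
# 0 0.0 | 5000 2.5e-05 | 10000 4.5e-05 | 55000 2.25e-05 | 100000 0.0
```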
82 | new_global_step = global_step + 1 83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 84 | return train_op 85 | 86 | 87 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 88 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 89 | 90 | def __init__(self, 91 | learning_rate, 92 | weight_decay_rate=0.0, 93 | beta_1=0.9, 94 | beta_2=0.999, 95 | epsilon=1e-6, 96 | exclude_from_weight_decay=None, 97 | name="AdamWeightDecayOptimizer"): 98 | """Constructs a AdamWeightDecayOptimizer.""" 99 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 100 | 101 | self.learning_rate = learning_rate 102 | self.weight_decay_rate = weight_decay_rate 103 | self.beta_1 = beta_1 104 | self.beta_2 = beta_2 105 | self.epsilon = epsilon 106 | self.exclude_from_weight_decay = exclude_from_weight_decay 107 | 108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 109 | """See base class.""" 110 | assignments = [] 111 | for (grad, param) in grads_and_vars: 112 | if grad is None or param is None: 113 | continue 114 | 115 | param_name = self._get_variable_name(param.name) 116 | 117 | m = tf.get_variable( 118 | name=param_name + "/adam_m", 119 | shape=param.shape.as_list(), 120 | dtype=tf.float32, 121 | trainable=False, 122 | initializer=tf.zeros_initializer()) 123 | v = tf.get_variable( 124 | name=param_name + "/adam_v", 125 | shape=param.shape.as_list(), 126 | dtype=tf.float32, 127 | trainable=False, 128 | initializer=tf.zeros_initializer()) 129 | 130 | # Standard Adam update. 131 | next_m = ( 132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 133 | next_v = ( 134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 135 | tf.square(grad))) 136 | 137 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 138 | 139 | # Just adding the square of the weights to the loss function is *not* 140 | # the correct way of using L2 regularization/weight decay with Adam, 141 | # since that will interact with the m and v parameters in strange ways. 142 | # 143 | # Instead we want ot decay the weights in a manner that doesn't interact 144 | # with the m/v parameters. This is equivalent to adding the square 145 | # of the weights to the loss with plain (non-momentum) SGD. 146 | if self._do_use_weight_decay(param_name): 147 | update += self.weight_decay_rate * param 148 | 149 | update_with_lr = self.learning_rate * update 150 | 151 | next_param = param - update_with_lr 152 | 153 | assignments.extend( 154 | [param.assign(next_param), 155 | m.assign(next_m), 156 | v.assign(next_v)]) 157 | return tf.group(*assignments, name=name) 158 | 159 | def _do_use_weight_decay(self, param_name): 160 | """Whether to use L2 weight decay for `param_name`.""" 161 | if not self.weight_decay_rate: 162 | return False 163 | if self.exclude_from_weight_decay: 164 | for r in self.exclude_from_weight_decay: 165 | if re.search(r, param_name) is not None: 166 | return False 167 | return True 168 | 169 | def _get_variable_name(self, param_name): 170 | """Get the variable name from the tensor name.""" 171 | m = re.match("^(.*):\\d+$", param_name) 172 | if m is not None: 173 | param_name = m.group(1) 174 | return param_name 175 | -------------------------------------------------------------------------------- /codes/further-pre-training/run_pretraining.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Run masked LM/next sentence masked_lm pre-training for BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import modeling 23 | import optimization 24 | import tensorflow as tf 25 | 26 | flags = tf.flags 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | ## Required parameters 31 | flags.DEFINE_string( 32 | "bert_config_file", None, 33 | "The config json file corresponding to the pre-trained BERT model. " 34 | "This specifies the model architecture.") 35 | 36 | flags.DEFINE_string( 37 | "input_file", None, 38 | "Input TF example files (can be a glob or comma separated).") 39 | 40 | flags.DEFINE_string( 41 | "output_dir", None, 42 | "The output directory where the model checkpoints will be written.") 43 | 44 | ## Other parameters 45 | flags.DEFINE_string( 46 | "init_checkpoint", None, 47 | "Initial checkpoint (usually from a pre-trained BERT model).") 48 | 49 | flags.DEFINE_integer( 50 | "max_seq_length", 128, 51 | "The maximum total input sequence length after WordPiece tokenization. " 52 | "Sequences longer than this will be truncated, and sequences shorter " 53 | "than this will be padded. Must match data generation.") 54 | 55 | flags.DEFINE_integer( 56 | "max_predictions_per_seq", 20, 57 | "Maximum number of masked LM predictions per sequence. " 58 | "Must match data generation.") 59 | 60 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 61 | 62 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 63 | 64 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 65 | 66 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") 67 | 68 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 69 | 70 | flags.DEFINE_integer("num_train_steps", 100000, "Number of training steps.") 71 | 72 | flags.DEFINE_integer("num_warmup_steps", 10000, "Number of warmup steps.") 73 | 74 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 75 | "How often to save the model checkpoint.") 76 | 77 | flags.DEFINE_integer("iterations_per_loop", 1000, 78 | "How many steps to make in each estimator call.") 79 | 80 | flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.") 81 | 82 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 83 | 84 | tf.flags.DEFINE_string( 85 | "tpu_name", None, 86 | "The Cloud TPU to use for training. This should be either the name " 87 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 88 | "url.") 89 | 90 | tf.flags.DEFINE_string( 91 | "tpu_zone", None, 92 | "[Optional] GCE zone where the Cloud TPU is located in. 
If not " 93 | "specified, we will attempt to automatically detect the GCE project from " 94 | "metadata.") 95 | 96 | tf.flags.DEFINE_string( 97 | "gcp_project", None, 98 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 99 | "specified, we will attempt to automatically detect the GCE project from " 100 | "metadata.") 101 | 102 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 103 | 104 | flags.DEFINE_integer( 105 | "num_tpu_cores", 8, 106 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 107 | 108 | 109 | def model_fn_builder(bert_config, init_checkpoint, learning_rate, 110 | num_train_steps, num_warmup_steps, use_tpu, 111 | use_one_hot_embeddings): 112 | """Returns `model_fn` closure for TPUEstimator.""" 113 | 114 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 115 | """The `model_fn` for TPUEstimator.""" 116 | 117 | tf.logging.info("*** Features ***") 118 | for name in sorted(features.keys()): 119 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 120 | 121 | input_ids = features["input_ids"] 122 | input_mask = features["input_mask"] 123 | segment_ids = features["segment_ids"] 124 | masked_lm_positions = features["masked_lm_positions"] 125 | masked_lm_ids = features["masked_lm_ids"] 126 | masked_lm_weights = features["masked_lm_weights"] 127 | next_sentence_labels = features["next_sentence_labels"] 128 | 129 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 130 | 131 | model = modeling.BertModel( 132 | config=bert_config, 133 | is_training=is_training, 134 | input_ids=input_ids, 135 | input_mask=input_mask, 136 | token_type_ids=segment_ids, 137 | use_one_hot_embeddings=use_one_hot_embeddings) 138 | 139 | (masked_lm_loss, 140 | masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output( 141 | bert_config, model.get_sequence_output(), model.get_embedding_table(), 142 | masked_lm_positions, masked_lm_ids, masked_lm_weights) 143 | 144 | (next_sentence_loss, next_sentence_example_loss, 145 | next_sentence_log_probs) = get_next_sentence_output( 146 | bert_config, model.get_pooled_output(), next_sentence_labels) 147 | 148 | total_loss = masked_lm_loss + next_sentence_loss 149 | 150 | tvars = tf.trainable_variables() 151 | 152 | initialized_variable_names = {} 153 | scaffold_fn = None 154 | if init_checkpoint: 155 | (assignment_map, initialized_variable_names 156 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 157 | if use_tpu: 158 | 159 | def tpu_scaffold(): 160 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 161 | return tf.train.Scaffold() 162 | 163 | scaffold_fn = tpu_scaffold 164 | else: 165 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 166 | 167 | tf.logging.info("**** Trainable Variables ****") 168 | for var in tvars: 169 | init_string = "" 170 | if var.name in initialized_variable_names: 171 | init_string = ", *INIT_FROM_CKPT*" 172 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 173 | init_string) 174 | 175 | output_spec = None 176 | if mode == tf.estimator.ModeKeys.TRAIN: 177 | train_op = optimization.create_optimizer( 178 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 179 | 180 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 181 | mode=mode, 182 | loss=total_loss, 183 | train_op=train_op, 184 | scaffold_fn=scaffold_fn) 185 | elif mode == tf.estimator.ModeKeys.EVAL: 186 | 187 | def metric_fn(masked_lm_example_loss, masked_lm_log_probs, 
masked_lm_ids, 188 | masked_lm_weights, next_sentence_example_loss, 189 | next_sentence_log_probs, next_sentence_labels): 190 | """Computes the loss and accuracy of the model.""" 191 | masked_lm_log_probs = tf.reshape(masked_lm_log_probs, 192 | [-1, masked_lm_log_probs.shape[-1]]) 193 | masked_lm_predictions = tf.argmax( 194 | masked_lm_log_probs, axis=-1, output_type=tf.int32) 195 | masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1]) 196 | masked_lm_ids = tf.reshape(masked_lm_ids, [-1]) 197 | masked_lm_weights = tf.reshape(masked_lm_weights, [-1]) 198 | masked_lm_accuracy = tf.metrics.accuracy( 199 | labels=masked_lm_ids, 200 | predictions=masked_lm_predictions, 201 | weights=masked_lm_weights) 202 | masked_lm_mean_loss = tf.metrics.mean( 203 | values=masked_lm_example_loss, weights=masked_lm_weights) 204 | 205 | next_sentence_log_probs = tf.reshape( 206 | next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]]) 207 | next_sentence_predictions = tf.argmax( 208 | next_sentence_log_probs, axis=-1, output_type=tf.int32) 209 | next_sentence_labels = tf.reshape(next_sentence_labels, [-1]) 210 | next_sentence_accuracy = tf.metrics.accuracy( 211 | labels=next_sentence_labels, predictions=next_sentence_predictions) 212 | next_sentence_mean_loss = tf.metrics.mean( 213 | values=next_sentence_example_loss) 214 | 215 | return { 216 | "masked_lm_accuracy": masked_lm_accuracy, 217 | "masked_lm_loss": masked_lm_mean_loss, 218 | "next_sentence_accuracy": next_sentence_accuracy, 219 | "next_sentence_loss": next_sentence_mean_loss, 220 | } 221 | 222 | eval_metrics = (metric_fn, [ 223 | masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, 224 | masked_lm_weights, next_sentence_example_loss, 225 | next_sentence_log_probs, next_sentence_labels 226 | ]) 227 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 228 | mode=mode, 229 | loss=total_loss, 230 | eval_metrics=eval_metrics, 231 | scaffold_fn=scaffold_fn) 232 | else: 233 | raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode)) 234 | 235 | return output_spec 236 | 237 | return model_fn 238 | 239 | 240 | def get_masked_lm_output(bert_config, input_tensor, output_weights, positions, 241 | label_ids, label_weights): 242 | """Get loss and log probs for the masked LM.""" 243 | input_tensor = gather_indexes(input_tensor, positions) 244 | 245 | with tf.variable_scope("cls/predictions"): 246 | # We apply one more non-linear transformation before the output layer. 247 | # This matrix is not used after pre-training. 248 | with tf.variable_scope("transform"): 249 | input_tensor = tf.layers.dense( 250 | input_tensor, 251 | units=bert_config.hidden_size, 252 | activation=modeling.get_activation(bert_config.hidden_act), 253 | kernel_initializer=modeling.create_initializer( 254 | bert_config.initializer_range)) 255 | input_tensor = modeling.layer_norm(input_tensor) 256 | 257 | # The output weights are the same as the input embeddings, but there is 258 | # an output-only bias for each token. 
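
In plain NumPy terms, the masked-LM head amounts to a weighted cross-entropy against the tied embedding table. The sketch below is illustrative only (hypothetical helper name; it assumes the transform and layer norm above have already been applied to the gathered positions):

import numpy as np

def masked_lm_loss_np(hidden, embedding_table, output_bias, label_ids, label_weights):
  # hidden: [num_masked, hidden_size], embedding_table: [vocab_size, hidden_size].
  logits = hidden.dot(embedding_table.T) + output_bias   # tied output weights
  z = logits - logits.max(axis=-1, keepdims=True)
  log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
  per_example_loss = -log_probs[np.arange(len(label_ids)), label_ids]
  # Padded prediction slots carry label_weights == 0.0 and drop out of the mean.
  return (label_weights * per_example_loss).sum() / (label_weights.sum() + 1e-5)
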
259 | output_bias = tf.get_variable( 260 | "output_bias", 261 | shape=[bert_config.vocab_size], 262 | initializer=tf.zeros_initializer()) 263 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 264 | logits = tf.nn.bias_add(logits, output_bias) 265 | log_probs = tf.nn.log_softmax(logits, axis=-1) 266 | 267 | label_ids = tf.reshape(label_ids, [-1]) 268 | label_weights = tf.reshape(label_weights, [-1]) 269 | 270 | one_hot_labels = tf.one_hot( 271 | label_ids, depth=bert_config.vocab_size, dtype=tf.float32) 272 | 273 | # The `positions` tensor might be zero-padded (if the sequence is too 274 | # short to have the maximum number of predictions). The `label_weights` 275 | # tensor has a value of 1.0 for every real prediction and 0.0 for the 276 | # padding predictions. 277 | per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) 278 | numerator = tf.reduce_sum(label_weights * per_example_loss) 279 | denominator = tf.reduce_sum(label_weights) + 1e-5 280 | loss = numerator / denominator 281 | 282 | return (loss, per_example_loss, log_probs) 283 | 284 | 285 | def get_next_sentence_output(bert_config, input_tensor, labels): 286 | """Get loss and log probs for the next sentence prediction.""" 287 | 288 | # Simple binary classification. Note that 0 is "next sentence" and 1 is 289 | # "random sentence". This weight matrix is not used after pre-training. 290 | with tf.variable_scope("cls/seq_relationship"): 291 | output_weights = tf.get_variable( 292 | "output_weights", 293 | shape=[2, bert_config.hidden_size], 294 | initializer=modeling.create_initializer(bert_config.initializer_range)) 295 | output_bias = tf.get_variable( 296 | "output_bias", shape=[2], initializer=tf.zeros_initializer()) 297 | 298 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 299 | logits = tf.nn.bias_add(logits, output_bias) 300 | log_probs = tf.nn.log_softmax(logits, axis=-1) 301 | labels = tf.reshape(labels, [-1]) 302 | one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32) 303 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 304 | loss = tf.reduce_mean(per_example_loss) 305 | return (loss, per_example_loss, log_probs) 306 | 307 | 308 | def gather_indexes(sequence_tensor, positions): 309 | """Gathers the vectors at the specific positions over a minibatch.""" 310 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3) 311 | batch_size = sequence_shape[0] 312 | seq_length = sequence_shape[1] 313 | width = sequence_shape[2] 314 | 315 | flat_offsets = tf.reshape( 316 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) 317 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) 318 | flat_sequence_tensor = tf.reshape(sequence_tensor, 319 | [batch_size * seq_length, width]) 320 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) 321 | return output_tensor 322 | 323 | 324 | def input_fn_builder(input_files, 325 | max_seq_length, 326 | max_predictions_per_seq, 327 | is_training, 328 | num_cpu_threads=4): 329 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 330 | 331 | def input_fn(params): 332 | """The actual input function.""" 333 | batch_size = params["batch_size"] 334 | 335 | name_to_features = { 336 | "input_ids": 337 | tf.FixedLenFeature([max_seq_length], tf.int64), 338 | "input_mask": 339 | tf.FixedLenFeature([max_seq_length], tf.int64), 340 | "segment_ids": 341 | tf.FixedLenFeature([max_seq_length], tf.int64), 342 | "masked_lm_positions": 343 | 
tf.FixedLenFeature([max_predictions_per_seq], tf.int64), 344 | "masked_lm_ids": 345 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64), 346 | "masked_lm_weights": 347 | tf.FixedLenFeature([max_predictions_per_seq], tf.float32), 348 | "next_sentence_labels": 349 | tf.FixedLenFeature([1], tf.int64), 350 | } 351 | 352 | # For training, we want a lot of parallel reading and shuffling. 353 | # For eval, we want no shuffling and parallel reading doesn't matter. 354 | if is_training: 355 | d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files)) 356 | d = d.repeat() 357 | d = d.shuffle(buffer_size=len(input_files)) 358 | 359 | # `cycle_length` is the number of parallel files that get read. 360 | cycle_length = min(num_cpu_threads, len(input_files)) 361 | 362 | # `sloppy` mode means that the interleaving is not exact. This adds 363 | # even more randomness to the training pipeline. 364 | d = d.apply( 365 | tf.contrib.data.parallel_interleave( 366 | tf.data.TFRecordDataset, 367 | sloppy=is_training, 368 | cycle_length=cycle_length)) 369 | d = d.shuffle(buffer_size=100) 370 | else: 371 | d = tf.data.TFRecordDataset(input_files) 372 | # Since we evaluate for a fixed number of steps we don't want to encounter 373 | # out-of-range exceptions. 374 | d = d.repeat() 375 | 376 | # We must `drop_remainder` on training because the TPU requires fixed 377 | # size dimensions. For eval, we assume we are evaluating on the CPU or GPU 378 | # and we *don't* want to drop the remainder, otherwise we won't cover 379 | # every sample. 380 | d = d.apply( 381 | tf.contrib.data.map_and_batch( 382 | lambda record: _decode_record(record, name_to_features), 383 | batch_size=batch_size, 384 | num_parallel_batches=num_cpu_threads, 385 | drop_remainder=True)) 386 | return d 387 | 388 | return input_fn 389 | 390 | 391 | def _decode_record(record, name_to_features): 392 | """Decodes a record to a TensorFlow example.""" 393 | example = tf.parse_single_example(record, name_to_features) 394 | 395 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 396 | # So cast all int64 to int32.
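
For reference, a record satisfying this name_to_features spec can be hand-built as a tf.train.Example. This is only a sketch of the expected on-disk format; the values and the output path are placeholders, not the repository's data-generation code:

import collections
import tensorflow as tf

def _int_feature(values):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

def _float_feature(values):
  return tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))

max_seq_length, max_predictions_per_seq = 128, 20
feature = collections.OrderedDict([
    # Placeholder token ids; real records come out of create_pretraining_data.py.
    ("input_ids", _int_feature([101, 2023, 2003, 102] + [0] * (max_seq_length - 4))),
    ("input_mask", _int_feature([1] * 4 + [0] * (max_seq_length - 4))),
    ("segment_ids", _int_feature([0] * max_seq_length)),
    ("masked_lm_positions", _int_feature([2] + [0] * (max_predictions_per_seq - 1))),
    ("masked_lm_ids", _int_feature([2003] + [0] * (max_predictions_per_seq - 1))),
    ("masked_lm_weights", _float_feature([1.0] + [0.0] * (max_predictions_per_seq - 1))),
    ("next_sentence_labels", _int_feature([0])),
])
example = tf.train.Example(features=tf.train.Features(feature=feature))
with tf.python_io.TFRecordWriter("example.tfrecord") as writer:  # hypothetical path
  writer.write(example.SerializeToString())
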
397 | for name in list(example.keys()): 398 | t = example[name] 399 | if t.dtype == tf.int64: 400 | t = tf.to_int32(t) 401 | example[name] = t 402 | 403 | return example 404 | 405 | 406 | def main(_): 407 | tf.logging.set_verbosity(tf.logging.INFO) 408 | 409 | if not FLAGS.do_train and not FLAGS.do_eval: 410 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 411 | 412 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 413 | 414 | tf.gfile.MakeDirs(FLAGS.output_dir) 415 | 416 | input_files = [] 417 | for input_pattern in FLAGS.input_file.split(","): 418 | input_files.extend(tf.gfile.Glob(input_pattern)) 419 | 420 | tf.logging.info("*** Input Files ***") 421 | for input_file in input_files: 422 | tf.logging.info(" %s" % input_file) 423 | 424 | tpu_cluster_resolver = None 425 | if FLAGS.use_tpu and FLAGS.tpu_name: 426 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 427 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 428 | 429 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 430 | run_config = tf.contrib.tpu.RunConfig( 431 | cluster=tpu_cluster_resolver, 432 | master=FLAGS.master, 433 | model_dir=FLAGS.output_dir, 434 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 435 | tpu_config=tf.contrib.tpu.TPUConfig( 436 | iterations_per_loop=FLAGS.iterations_per_loop, 437 | num_shards=FLAGS.num_tpu_cores, 438 | per_host_input_for_training=is_per_host)) 439 | 440 | model_fn = model_fn_builder( 441 | bert_config=bert_config, 442 | init_checkpoint=FLAGS.init_checkpoint, 443 | learning_rate=FLAGS.learning_rate, 444 | num_train_steps=FLAGS.num_train_steps, 445 | num_warmup_steps=FLAGS.num_warmup_steps, 446 | use_tpu=FLAGS.use_tpu, 447 | use_one_hot_embeddings=FLAGS.use_tpu) 448 | 449 | # If TPU is not available, this will fall back to normal Estimator on CPU 450 | # or GPU. 
451 | estimator = tf.contrib.tpu.TPUEstimator( 452 | use_tpu=FLAGS.use_tpu, 453 | model_fn=model_fn, 454 | config=run_config, 455 | train_batch_size=FLAGS.train_batch_size, 456 | eval_batch_size=FLAGS.eval_batch_size) 457 | 458 | if FLAGS.do_train: 459 | tf.logging.info("***** Running training *****") 460 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 461 | train_input_fn = input_fn_builder( 462 | input_files=input_files, 463 | max_seq_length=FLAGS.max_seq_length, 464 | max_predictions_per_seq=FLAGS.max_predictions_per_seq, 465 | is_training=True) 466 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps) 467 | 468 | if FLAGS.do_eval: 469 | tf.logging.info("***** Running evaluation *****") 470 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 471 | 472 | eval_input_fn = input_fn_builder( 473 | input_files=input_files, 474 | max_seq_length=FLAGS.max_seq_length, 475 | max_predictions_per_seq=FLAGS.max_predictions_per_seq, 476 | is_training=False) 477 | 478 | result = estimator.evaluate( 479 | input_fn=eval_input_fn, steps=FLAGS.max_eval_steps) 480 | 481 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 482 | with tf.gfile.GFile(output_eval_file, "w") as writer: 483 | tf.logging.info("***** Eval results *****") 484 | for key in sorted(result.keys()): 485 | tf.logging.info(" %s = %s", key, str(result[key])) 486 | writer.write("%s = %s\n" % (key, str(result[key]))) 487 | 488 | 489 | if __name__ == "__main__": 490 | flags.mark_flag_as_required("input_file") 491 | flags.mark_flag_as_required("bert_config_file") 492 | flags.mark_flag_as_required("output_dir") 493 | tf.app.run() 494 | -------------------------------------------------------------------------------- /codes/further-pre-training/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 29 | """Checks whether the casing config is consistent with the checkpoint name.""" 30 | 31 | # The casing has to be passed in by the user and there is no explicit check 32 | # as to whether it matches the checkpoint. The casing information probably 33 | # should have been stored in the bert_config.json file, but it's not, so 34 | # we have to heuristically detect it to validate. 
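
As a concrete illustration of this heuristic (the checkpoint path below is hypothetical):

import re

ckpt = "gs://my-bucket/uncased_L-12_H-768_A-12/bert_model.ckpt"
m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", ckpt)
print(m.group(1))  # uncased_L-12_H-768_A-12 -> must be run with --do_lower_case=True
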
35 | 36 | if not init_checkpoint: 37 | return 38 | 39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 40 | if m is None: 41 | return 42 | 43 | model_name = m.group(1) 44 | 45 | lower_models = [ 46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 48 | ] 49 | 50 | cased_models = [ 51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 52 | "multi_cased_L-12_H-768_A-12" 53 | ] 54 | 55 | is_bad_config = False 56 | if model_name in lower_models and not do_lower_case: 57 | is_bad_config = True 58 | actual_flag = "False" 59 | case_name = "lowercased" 60 | opposite_flag = "True" 61 | 62 | if model_name in cased_models and do_lower_case: 63 | is_bad_config = True 64 | actual_flag = "True" 65 | case_name = "cased" 66 | opposite_flag = "False" 67 | 68 | if is_bad_config: 69 | raise ValueError( 70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 71 | "However, `%s` seems to be a %s model, so you " 72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 73 | "how the model was pre-training. If this error is wrong, please " 74 | "just comment out this check." % (actual_flag, init_checkpoint, 75 | model_name, case_name, opposite_flag)) 76 | 77 | 78 | def convert_to_unicode(text): 79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 80 | if six.PY3: 81 | if isinstance(text, str): 82 | return text 83 | elif isinstance(text, bytes): 84 | return text.decode("utf-8", "ignore") 85 | else: 86 | raise ValueError("Unsupported string type: %s" % (type(text))) 87 | elif six.PY2: 88 | if isinstance(text, str): 89 | return text.decode("utf-8", "ignore") 90 | elif isinstance(text, unicode): 91 | return text 92 | else: 93 | raise ValueError("Unsupported string type: %s" % (type(text))) 94 | else: 95 | raise ValueError("Not running on Python2 or Python 3?") 96 | 97 | 98 | def printable_text(text): 99 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 100 | 101 | # These functions want `str` for both Python2 and Python3, but in one case 102 | # it's a Unicode string and in the other it's a byte string. 
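
For example, on Python 3 both helpers behave as follows (arbitrary sample value, assuming this module's functions are in scope):

assert convert_to_unicode(b"caf\xc3\xa9") == u"caf\u00e9"  # UTF-8 bytes are decoded
assert printable_text(u"caf\u00e9") == u"caf\u00e9"        # native str passes through
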
103 | if six.PY3: 104 | if isinstance(text, str): 105 | return text 106 | elif isinstance(text, bytes): 107 | return text.decode("utf-8", "ignore") 108 | else: 109 | raise ValueError("Unsupported string type: %s" % (type(text))) 110 | elif six.PY2: 111 | if isinstance(text, str): 112 | return text 113 | elif isinstance(text, unicode): 114 | return text.encode("utf-8") 115 | else: 116 | raise ValueError("Unsupported string type: %s" % (type(text))) 117 | else: 118 | raise ValueError("Not running on Python2 or Python 3?") 119 | 120 | 121 | def load_vocab(vocab_file): 122 | """Loads a vocabulary file into a dictionary.""" 123 | vocab = collections.OrderedDict() 124 | index = 0 125 | with tf.gfile.GFile(vocab_file, "r") as reader: 126 | while True: 127 | token = convert_to_unicode(reader.readline()) 128 | if not token: 129 | break 130 | token = token.strip() 131 | vocab[token] = index 132 | index += 1 133 | return vocab 134 | 135 | 136 | def convert_by_vocab(vocab, items): 137 | """Converts a sequence of [tokens|ids] using the vocab.""" 138 | output = [] 139 | for item in items: 140 | output.append(vocab[item]) 141 | return output 142 | 143 | 144 | def convert_tokens_to_ids(vocab, tokens): 145 | return convert_by_vocab(vocab, tokens) 146 | 147 | 148 | def convert_ids_to_tokens(inv_vocab, ids): 149 | return convert_by_vocab(inv_vocab, ids) 150 | 151 | 152 | def whitespace_tokenize(text): 153 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 154 | text = text.strip() 155 | if not text: 156 | return [] 157 | tokens = text.split() 158 | return tokens 159 | 160 | 161 | class FullTokenizer(object): 162 | """Runs end-to-end tokenziation.""" 163 | 164 | def __init__(self, vocab_file, do_lower_case=True): 165 | self.vocab = load_vocab(vocab_file) 166 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 169 | 170 | def tokenize(self, text): 171 | split_tokens = [] 172 | for token in self.basic_tokenizer.tokenize(text): 173 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 174 | split_tokens.append(sub_token) 175 | 176 | return split_tokens 177 | 178 | def convert_tokens_to_ids(self, tokens): 179 | return convert_by_vocab(self.vocab, tokens) 180 | 181 | def convert_ids_to_tokens(self, ids): 182 | return convert_by_vocab(self.inv_vocab, ids) 183 | 184 | 185 | class BasicTokenizer(object): 186 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 187 | 188 | def __init__(self, do_lower_case=True): 189 | """Constructs a BasicTokenizer. 190 | 191 | Args: 192 | do_lower_case: Whether to lower case the input. 193 | """ 194 | self.do_lower_case = do_lower_case 195 | 196 | def tokenize(self, text): 197 | """Tokenizes a piece of text.""" 198 | text = convert_to_unicode(text) 199 | text = self._clean_text(text) 200 | 201 | # This was added on November 1st, 2018 for the multilingual and Chinese 202 | # models. This is also applied to the English models now, but it doesn't 203 | # matter since the English models were not trained on any Chinese data 204 | # and generally don't have any Chinese data in them (there are Chinese 205 | # characters in the vocabulary because Wikipedia does have some Chinese 206 | # words in the English Wikipedia.). 
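
To make the steps of tokenize() concrete, an illustrative lower-cased run (assuming this module is importable) looks like:

basic = BasicTokenizer(do_lower_case=True)
print(basic.tokenize(u"Héllo, World!"))
# -> ["hello", ",", "world", "!"]  (cleaned, lower-cased, accents stripped, punctuation split off)
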
207 | text = self._tokenize_chinese_chars(text) 208 | 209 | orig_tokens = whitespace_tokenize(text) 210 | split_tokens = [] 211 | for token in orig_tokens: 212 | if self.do_lower_case: 213 | token = token.lower() 214 | token = self._run_strip_accents(token) 215 | split_tokens.extend(self._run_split_on_punc(token)) 216 | 217 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 218 | return output_tokens 219 | 220 | def _run_strip_accents(self, text): 221 | """Strips accents from a piece of text.""" 222 | text = unicodedata.normalize("NFD", text) 223 | output = [] 224 | for char in text: 225 | cat = unicodedata.category(char) 226 | if cat == "Mn": 227 | continue 228 | output.append(char) 229 | return "".join(output) 230 | 231 | def _run_split_on_punc(self, text): 232 | """Splits punctuation on a piece of text.""" 233 | chars = list(text) 234 | i = 0 235 | start_new_word = True 236 | output = [] 237 | while i < len(chars): 238 | char = chars[i] 239 | if _is_punctuation(char): 240 | output.append([char]) 241 | start_new_word = True 242 | else: 243 | if start_new_word: 244 | output.append([]) 245 | start_new_word = False 246 | output[-1].append(char) 247 | i += 1 248 | 249 | return ["".join(x) for x in output] 250 | 251 | def _tokenize_chinese_chars(self, text): 252 | """Adds whitespace around any CJK character.""" 253 | output = [] 254 | for char in text: 255 | cp = ord(char) 256 | if self._is_chinese_char(cp): 257 | output.append(" ") 258 | output.append(char) 259 | output.append(" ") 260 | else: 261 | output.append(char) 262 | return "".join(output) 263 | 264 | def _is_chinese_char(self, cp): 265 | """Checks whether CP is the codepoint of a CJK character.""" 266 | # This defines a "chinese character" as anything in the CJK Unicode block: 267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 268 | # 269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 270 | # despite its name. The modern Korean Hangul alphabet is a different block, 271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 272 | # space-separated words, so they are not treated specially and handled 273 | # like the all of the other languages. 274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 275 | (cp >= 0x3400 and cp <= 0x4DBF) or # 276 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 277 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 278 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 280 | (cp >= 0xF900 and cp <= 0xFAFF) or # 281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 282 | return True 283 | 284 | return False 285 | 286 | def _clean_text(self, text): 287 | """Performs invalid character removal and whitespace cleanup on text.""" 288 | output = [] 289 | for char in text: 290 | cp = ord(char) 291 | if cp == 0 or cp == 0xfffd or _is_control(char): 292 | continue 293 | if _is_whitespace(char): 294 | output.append(" ") 295 | else: 296 | output.append(char) 297 | return "".join(output) 298 | 299 | 300 | class WordpieceTokenizer(object): 301 | """Runs WordPiece tokenziation.""" 302 | 303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 304 | self.vocab = vocab 305 | self.unk_token = unk_token 306 | self.max_input_chars_per_word = max_input_chars_per_word 307 | 308 | def tokenize(self, text): 309 | """Tokenizes a piece of text into its word pieces. 310 | 311 | This uses a greedy longest-match-first algorithm to perform tokenization 312 | using the given vocabulary. 
313 | 314 | For example: 315 | input = "unaffable" 316 | output = ["un", "##aff", "##able"] 317 | 318 | Args: 319 | text: A single token or whitespace separated tokens. This should have 320 | already been passed through `BasicTokenizer. 321 | 322 | Returns: 323 | A list of wordpiece tokens. 324 | """ 325 | 326 | text = convert_to_unicode(text) 327 | 328 | output_tokens = [] 329 | for token in whitespace_tokenize(text): 330 | chars = list(token) 331 | if len(chars) > self.max_input_chars_per_word: 332 | output_tokens.append(self.unk_token) 333 | continue 334 | 335 | is_bad = False 336 | start = 0 337 | sub_tokens = [] 338 | while start < len(chars): 339 | end = len(chars) 340 | cur_substr = None 341 | while start < end: 342 | substr = "".join(chars[start:end]) 343 | if start > 0: 344 | substr = "##" + substr 345 | if substr in self.vocab: 346 | cur_substr = substr 347 | break 348 | end -= 1 349 | if cur_substr is None: 350 | is_bad = True 351 | break 352 | sub_tokens.append(cur_substr) 353 | start = end 354 | 355 | if is_bad: 356 | output_tokens.append(self.unk_token) 357 | else: 358 | output_tokens.extend(sub_tokens) 359 | return output_tokens 360 | 361 | 362 | def _is_whitespace(char): 363 | """Checks whether `chars` is a whitespace character.""" 364 | # \t, \n, and \r are technically contorl characters but we treat them 365 | # as whitespace since they are generally considered as such. 366 | if char == " " or char == "\t" or char == "\n" or char == "\r": 367 | return True 368 | cat = unicodedata.category(char) 369 | if cat == "Zs": 370 | return True 371 | return False 372 | 373 | 374 | def _is_control(char): 375 | """Checks whether `chars` is a control character.""" 376 | # These are technically control characters but we count them as whitespace 377 | # characters. 378 | if char == "\t" or char == "\n" or char == "\r": 379 | return False 380 | cat = unicodedata.category(char) 381 | if cat.startswith("C"): 382 | return True 383 | return False 384 | 385 | 386 | def _is_punctuation(char): 387 | """Checks whether `chars` is a punctuation character.""" 388 | cp = ord(char) 389 | # We treat all non-letter/number ASCII as punctuation. 390 | # Characters such as "^", "$", and "`" are not in the Unicode 391 | # Punctuation class but we treat them as punctuation anyways, for 392 | # consistency. 393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 395 | return True 396 | cat = unicodedata.category(char) 397 | if cat.startswith("P"): 398 | return True 399 | return False 400 | --------------------------------------------------------------------------------
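
To round off the tokenization module, here is a minimal illustration of the greedy longest-match-first WordPiece behavior documented above. The toy vocabulary is made up and stands in for the full BERT vocab that load_vocab() would normally supply (assumes codes/further-pre-training is on the import path):

from tokenization import WordpieceTokenizer

toy_vocab = {"[UNK]": 0, "un": 1, "##aff": 2, "##able": 3}
wordpiece = WordpieceTokenizer(vocab=toy_vocab)
print(wordpiece.tokenize("unaffable"))  # ['un', '##aff', '##able']
print(wordpiece.tokenize("xyz"))        # ['[UNK]']  (no in-vocab decomposition exists)
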