├── .gitignore ├── .travis.yml ├── LICENSE ├── LICENSE-APACHE ├── README.md ├── requirements.txt ├── setup.py └── tbert ├── __init__.py ├── attention.py ├── bert.py ├── cli ├── __init__.py ├── cmp_jsonl.py ├── convert.py ├── extract_features.py └── run_classifier.py ├── data.py ├── embedding.py ├── gelu.py ├── optimization.py ├── test ├── __init__.py └── test_bert.py ├── tf_util.py └── transformer.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pyc 3 | *.log 4 | .venv 5 | build 6 | dist 7 | *.egg-info 8 | err 9 | .vscode 10 | .pytest_cache/ 11 | data/ 12 | tf/ 13 | to/ 14 | obsolete/ 15 | download_glue_data.py 16 | glue_data/ 17 | *.pickle 18 | *.tsv 19 | build/ 20 | dist/ 21 | .ipynb_checkpoints/ 22 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.6" 4 | install: 5 | - pip install -q -r requirements.txt 6 | - (mkdir tf; cd tf; git clone https://github.com/google-research/bert) 7 | script: 8 | - PYTHONPATH=.:tf/bert pytest tbert/test 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Innodata Labs 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tBERT 2 | [![PyPI version](https://badge.fury.io/py/tbert.svg)](https://badge.fury.io/py/tbert) 3 | [![Build Status](https://travis-ci.com/innodatalabs/tbert.svg?branch=master)](https://travis-ci.com/innodatalabs/tbert) 4 | 5 | BERT model converted to PyTorch. 6 | 7 | Please, **do NOT use this repo**, instead use the (better) library from 8 | HuggingFace: https://github.com/huggingface/pytorch-pretrained-BERT.git 9 | 10 | This repo is kept as an example of converting a TF model to PyTorch (the utils may be handy in case I need 11 | to do something like this again). 12 | 13 | This is a literal port of BERT code from TensorFlow to PyTorch. 14 | See the [original TF BERT repo here](https://github.com/google-research/bert). 15 | 16 | We provide a script to convert a TF BERT pre-trained checkpoint to tBERT: `tbert.cli.convert` 17 | 18 | Testing is done to ensure that the tBERT code behaves exactly like TF BERT. 19 | 20 | ## License 21 | This work is released under the MIT license. 22 | 23 | The original code is covered by the Apache 2.0 License. 24 | 25 | ## Installation 26 | 27 | Python 3.6 or newer is required. 28 | 29 | The easiest way to install is with `pip`: 30 | ``` 31 | pip install tbert 32 | ``` 33 | Now you can start using tBERT models in your code! 34 | 35 | ## Pre-trained models 36 | Google-trained models, converted to tBERT format. For a description of the models, see 37 | the [original TF BERT repo here](https://github.com/google-research/bert#pre-trained-models): 38 | 39 | * [Base, Uncased](https://storage.googleapis.com/public.innodatalabs.com/tbert-uncased_L-12_H-768_A-12.zip) 40 | * [Large, Uncased](https://storage.googleapis.com/public.innodatalabs.com/tbert-uncased_L-24_H-1024_A-16.zip) 41 | * [Base, Cased](https://storage.googleapis.com/public.innodatalabs.com/tbert-cased_L-12_H-768_A-12.zip) 42 | * [Large, Cased](https://storage.googleapis.com/public.innodatalabs.com/tbert-cased_L-24_H-1024_A-16.zip) 43 | * [Base, Multilingual Cased (New, recommended)](https://storage.googleapis.com/public.innodatalabs.com/tbert-multi_cased_L-12_H-768_A-12.zip) 44 | * [Base, Multilingual Uncased (Not recommended)](https://storage.googleapis.com/public.innodatalabs.com/tbert-multilingual_L-12_H-768_A-12.zip) 45 | * [Base, Chinese](https://storage.googleapis.com/public.innodatalabs.com/tbert-chinese_L-12_H-768_A-12.zip) 46 | 47 | ## Using the tBERT model in your PyTorch code 48 | 49 | ### tbert.bert.Bert 50 | This is the main juice - the Bert transformer. It is a normal PyTorch module. 51 | You can use it stand-alone or in combination with other PyTorch modules.
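For instance, here is a minimal, untested sketch of combining `Bert` with a small task-specific head. The `MeanPoolClassifier` module and its `num_classes` argument are illustrative only (not part of tBERT); the `config` dict is the same kind of configuration shown in the full example below:

```python
import torch
from tbert.bert import Bert


class MeanPoolClassifier(torch.nn.Module):
    '''Illustrative only: mean-pools the top Bert layer and applies a linear classifier.'''

    def __init__(self, config, num_classes):
        torch.nn.Module.__init__(self)
        self.bert = Bert(config)
        self.proj = torch.nn.Linear(config['hidden_size'], num_classes)

    def forward(self, input_ids, input_type_ids=None, input_mask=None):
        B = input_ids.size(0)
        # Bert returns one activation tensor per layer, each of shape [B*S, H]
        top = self.bert(input_ids, input_type_ids, input_mask)[-1]
        top = top.view(B, -1, top.size(-1))   # reshape back to [B, S, H]
        return self.proj(top.mean(dim=1))     # [B, num_classes]
```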
52 | 53 | ```python 54 | import torch 55 | from tbert.bert import Bert 56 | config = dict( 57 | attention_probs_dropout_prob=0.1, 58 | directionality="bidi", 59 | hidden_act="gelu", 60 | hidden_dropout_prob=0.1, 61 | hidden_size=768, 62 | initializer_range=0.02, 63 | intermediate_size=3072, 64 | max_position_embeddings=512, 65 | num_attention_heads=12, 66 | num_hidden_layers=12, 67 | type_vocab_size=2, 68 | vocab_size=105879 69 | ) 70 | 71 | bert = Bert(config) 72 | # ... should load trained parameters (see below) 73 | 74 | input_ids = torch.LongTensor([[1, 2, 3, 4, 5, 0]]) 75 | input_type_ids = torch.LongTensor([[0, 0, 1, 1, 1, 0]]) 76 | input_mask = torch.LongTensor([[1, 1, 1, 1, 1, 0]]) 77 | 78 | activations = bert(input_ids, input_type_ids, input_mask) 79 | ``` 80 | Returns a list of activations (one for each hidden layer). 81 | Typically only the topmost layer, or a few top layers, are used. 82 | Each element in the list is a Tensor of shape [B*S, H] 83 | where B is the batch size, S is the sequence length, and H is the 84 | size of the hidden layer. 85 | 86 | ### tbert.bert.BertPooler 87 | This is the Bert transformer with a pooling layer on top. 88 | Convenient for sequence classification tasks. Usage is very similar to 89 | that of the `tbert.bert.Bert` module: 90 | ```python 91 | import torch 92 | from tbert.bert import BertPooler 93 | config = dict( 94 | attention_probs_dropout_prob=0.1, 95 | directionality="bidi", 96 | hidden_act="gelu", 97 | hidden_dropout_prob=0.1, 98 | hidden_size=768, 99 | initializer_range=0.02, 100 | intermediate_size=3072, 101 | max_position_embeddings=512, 102 | num_attention_heads=12, 103 | num_hidden_layers=12, 104 | type_vocab_size=2, 105 | vocab_size=105879 106 | ) 107 | 108 | bert_pooler = BertPooler(config) 109 | # ... should load trained parameters (see below) 110 | 111 | input_ids = torch.LongTensor([[1, 2, 3, 4, 5, 0]]) 112 | input_type_ids = torch.LongTensor([[0, 0, 1, 1, 1, 0]]) 113 | input_mask = torch.LongTensor([[1, 1, 1, 1, 1, 0]]) 114 | 115 | activation = bert_pooler(input_ids, input_type_ids, input_mask) 116 | ``` 117 | Returns a single tensor of size [B, H], where 118 | B is the batch size, and H is the size of the hidden layer. 119 | 120 | ### Programmatically loading pre-trained weights 121 | To initialize `tbert.bert.Bert` or `tbert.bert.BertPooler` from a pre-trained 122 | saved checkpoint: 123 | ``` 124 | ... 125 | bert = Bert(config) 126 | bert.load_pretrained(dir_name) 127 | ``` 128 | Here, `dir_name` should be a directory containing a pre-trained tBERT model, 129 | with `bert_model.pickle` and `pooler_model.pickle` files. See below to learn how 130 | to convert published TF BERT pre-trained models to tBERT format. 131 | 132 | Similarly, the `load_pretrained` method can be used on a `tbert.bert.BertPooler` 133 | instance. 134 | 135 | ## Installing optional dependencies 136 | Optional dependencies are needed to use the CLI utilities: 137 | * to convert a TF BERT checkpoint to tBERT format 138 | * to extract features from a sequence 139 | * to run training of a classifier 140 | 141 | ``` 142 | pip install -r requirements.txt 143 | mkdir tf 144 | cd tf 145 | git clone https://github.com/google-research/bert 146 | cd ..
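# the PYTHONPATH below makes both tbert and the cloned TF BERT code (its
# `tokenization` and `modeling` modules) importable by the tbert.cli tools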
147 | export PYTHONPATH=.:tf/bert 148 | ``` 149 | 150 | Now everything is set up: 151 | ``` 152 | python -m tbert.cli.extract_features --help 153 | python -m tbert.cli.convert --help 154 | python -m tbert.cli.run_classifier --help 155 | ``` 156 | 157 | ## Running unit tests 158 | ``` 159 | pip install pytest 160 | pytest tbert/test 161 | ``` 162 | 163 | ## Converting TF BERT pre-trained checkpoint to tBERT 164 | 165 | * Download a TF BERT checkpoint and unzip it 166 | ``` 167 | mkdir data 168 | cd data 169 | wget https://storage.googleapis.com/bert_models/2018_11_03/multilingual_L-12_H-768_A-12.zip 170 | unzip multilingual_L-12_H-768_A-12.zip 171 | cd .. 172 | ``` 173 | * Run the converter 174 | ``` 175 | python -m tbert.cli.convert \ 176 | data/multilingual_L-12_H-768_A-12 \ 177 | data/tbert-multilingual_L-12_H-768_A-12 178 | ``` 179 | 180 | ## Extracting features 181 | 182 | Make sure that you have a pre-trained tBERT model (see the section above). 183 | 184 | ``` 185 | echo "Who was Jim Henson ? ||| Jim Henson was a puppeteer" > /tmp/input.txt 186 | echo "Empty answer is cool!" >> /tmp/input.txt 187 | 188 | python -m tbert.cli.extract_features \ 189 | /tmp/input.txt \ 190 | /tmp/output-tbert.jsonl \ 191 | data/tbert-multilingual_L-12_H-768_A-12 192 | ``` 193 | 194 | ## Comparing TF BERT and tBERT results 195 | 196 | Run TF BERT `extract_features`: 197 | ``` 198 | echo "Who was Jim Henson ? ||| Jim Henson was a puppeteer" > /tmp/input.txt 199 | echo "Empty answer is cool!" >> /tmp/input.txt 200 | 201 | export BERT_BASE_DIR=data/multilingual_L-12_H-768_A-12 202 | 203 | python -m extract_features \ 204 | --input_file=/tmp/input.txt \ 205 | --output_file=/tmp/output-tf.jsonl \ 206 | --vocab_file=$BERT_BASE_DIR/vocab.txt \ 207 | --bert_config_file=$BERT_BASE_DIR/bert_config.json \ 208 | --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ 209 | --layers=-1,-2,-3,-4 \ 210 | --max_seq_length=128 \ 211 | --batch_size=8 212 | ``` 213 | 214 | This creates the file `/tmp/output-tf.jsonl`. Now, compare this to the JSON-L file created 215 | by tBERT: 216 | 217 | ``` 218 | python -m tbert.cli.cmp_jsonl \ 219 | --tolerance 5e-5 \ 220 | /tmp/output-tbert.jsonl \ 221 | /tmp/output-tf.jsonl 222 | ``` 223 | 224 | Expect output similar to this: 225 | ``` 226 | Max float values delta: 3.6e-05 227 | Structure is identical 228 | ``` 229 | 230 | ## Fine-tuning a classifier 231 | 232 | Download the GLUE datasets, as explained 233 | [here](https://github.com/google-research/bert#sentence-and-sentence-pair-classification-tasks). 234 | In the following we assume that 235 | the GLUE datasets are in the `glue_data` directory. 236 | 237 | To train the MRPC task, do this: 238 | ``` 239 | python -m tbert.cli.run_classifier \ 240 | data/tbert-multilingual_L-12_H-768_A-12 \ 241 | /tmp \ 242 | --problem mrpc \ 243 | --data_dir glue_data/MRPC \ 244 | --do_train \ 245 | --num_train_steps 600 \ 246 | --num_warmup_steps 60 \ 247 | --do_eval 248 | ``` 249 | 250 | Expect to see something similar to this: 251 | ``` 252 | ...
253 | Step: 550, loss: 0.039, learning rates: 1.888888888888889e-06 254 | Step: 560, loss: 0.014, learning rates: 1.5185185185185186e-06 255 | Step: 570, loss: 0.017, learning rates: 1.1481481481481482e-06 256 | Step: 580, loss: 0.021, learning rates: 7.777777777777779e-07 257 | Step: 590, loss: 0.053, learning rates: 4.074074074074075e-07 258 | Step: 600, loss: 0.061, learning rates: 3.703703703703704e-08 259 | Saved trained model 260 | *** Evaluating *** 261 | Number of samples evaluated: 408 262 | Average per-sample loss: 0.4922609218195373 263 | Accuracy: 0.8504901960784313 264 | ``` 265 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch~=1.0.0 2 | tensorflow~=1.12.0 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from tbert import __version__, __description__, __url__, __author__, \ 3 | __author_email__, __keywords__ 4 | 5 | NAME = 'tbert' 6 | 7 | setup( 8 | name=NAME, 9 | version=__version__, 10 | description=__description__, 11 | long_description='See ' + __url__, 12 | url=__url__, 13 | author=__author__, 14 | author_email=__author_email__, 15 | keywords=__keywords__, 16 | 17 | license='MIT', 18 | classifiers=[ 19 | 'Development Status :: 4 - Beta', 20 | 'Intended Audience :: Developers', 21 | 'License :: OSI Approved :: MIT License', 22 | 'Programming Language :: Python :: 3.6', 23 | 'Programming Language :: Python :: Implementation :: CPython', 24 | 'Programming Language :: Python :: Implementation :: PyPy', 25 | 'Topic :: Software Development :: Libraries :: Python Modules', 26 | ], 27 | packages=[NAME], 28 | install_requires=[], # do not drag in any deps, to ease re-use 29 | ) -------------------------------------------------------------------------------- /tbert/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.1.8' 2 | __description__ = 'BERT Neural Net module for pytorch' 3 | __url__ = 'https://github.com/innodatalabs/tbert' 4 | __author__ = 'Mike Kroutikov' 5 | __author_email__ = 'mkroutikov@innodata.com' 6 | __keywords__ = ['nlp', 'neural networks', 'pytorch', 'tensorflow'] 7 | -------------------------------------------------------------------------------- /tbert/attention.py: -------------------------------------------------------------------------------- 1 | # The MIT License 2 | # Copyright 2019 Innodata Labs and Mike Kroutikov 3 | # 4 | # PyTorch port of 5 | # https://github.com/google-research/bert/modeling.py 6 | # 7 | # Original code copyright follows: 8 | # 9 | # Copyright 2018 The Google AI Language Team Authors. 10 | # 11 | # Licensed under the Apache License, Version 2.0 (the "License"); 12 | # you may not use this file except in compliance with the License. 13 | # You may obtain a copy of the License at 14 | # 15 | # http://www.apache.org/licenses/LICENSE-2.0 16 | # 17 | # Unless required by applicable law or agreed to in writing, software 18 | # distributed under the License is distributed on an "AS IS" BASIS, 19 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
20 | # See the License for the specific language governing permissions and 21 | # limitations under the License.import json 22 | # 23 | import torch 24 | import torch.nn.functional as F 25 | import math 26 | 27 | 28 | def init_linear(m, initializer_range): 29 | torch.nn.init.normal_(m.weight, std=initializer_range) 30 | m.bias.data.zero_() 31 | 32 | 33 | class Attention(torch.nn.Module): 34 | 35 | def __init__(self, 36 | query_size, 37 | key_size, 38 | num_heads, 39 | head_size, 40 | dropout=0.1, 41 | initializer_range=0.02 42 | ): 43 | torch.nn.Module.__init__(self) 44 | 45 | self.query_size = query_size 46 | self.key_size = key_size 47 | self.num_heads = num_heads 48 | self.head_size = head_size 49 | 50 | self.query = torch.nn.Linear(query_size, key_size) 51 | self.key = torch.nn.Linear(key_size, key_size) 52 | self.value = torch.nn.Linear(key_size, key_size) 53 | self.dropout = torch.nn.Dropout(dropout) 54 | 55 | init_linear(self.query, initializer_range) 56 | init_linear(self.key, initializer_range) 57 | init_linear(self.value, initializer_range) 58 | 59 | def forward(self, query, key, value, mask=None, batch_size=1): 60 | ''' 61 | query [B*Q, N*H] - query sequence 62 | key [B*K, N*H] - key sequence 63 | value [B*K, N*H] - value sequence 64 | mask [B, 1, Q, K] - attention mask (optional) 65 | batch_size - the batch size (for attention reshaping) 66 | 67 | where: 68 | B - batch size 69 | Q - sequence length of query 70 | K - sequence length of key and value (must be the same) 71 | N - number of heads 72 | H - size of one head 73 | 74 | returns: 75 | [B*K, N*H] - value weighted with the attention 76 | ''' 77 | B = batch_size 78 | Q = query.size(0) // batch_size 79 | K = key.size(0) // batch_size 80 | N = self.num_heads 81 | H = self.head_size 82 | 83 | q = self.query(query) # [B*Q, N*H] 84 | k = self.key(key) # [B*K, N*H] 85 | v = self.value(value) # [B*K, N*H] 86 | 87 | # [B*Q, N*H] -> [B, Q, N, H] -> [B, N, Q, H] 88 | q = q.view(B, Q, N, H).transpose(1, 2) 89 | # [B*K, N*H] -> [B, K, N, H] -> [B, N, K, H] 90 | k = k.view(B, K, N, H).transpose(1, 2) 91 | 92 | # -> [B, N, Q, K] 93 | scores = torch.matmul(q, k.transpose(2, 3)) 94 | scores *= 1. / math.sqrt(H) 95 | 96 | if mask is not None: 97 | scores += mask 98 | 99 | w = F.softmax(scores, dim=3) 100 | w = self.dropout(w) 101 | 102 | # [B*K, N*H] -> [B, K, N, H] -> [B, N, K, H] 103 | v = v.view(B, K, N, H).transpose(1, 2) 104 | 105 | # [B, N, Q, H] -> [B, Q, N, H] 106 | c = torch.matmul(w, v).transpose(1, 2).contiguous() 107 | 108 | # [B*Q, N*H] 109 | return c.view(-1, N * H) 110 | 111 | -------------------------------------------------------------------------------- /tbert/bert.py: -------------------------------------------------------------------------------- 1 | # The MIT License 2 | # Copyright 2019 Innodata Labs and Mike Kroutikov 3 | # 4 | # PyTorch port of 5 | # https://github.com/google-research/bert/modeling.py 6 | # 7 | # Original code copyright follows: 8 | # 9 | # Copyright 2018 The Google AI Language Team Authors. 10 | # 11 | # Licensed under the Apache License, Version 2.0 (the "License"); 12 | # you may not use this file except in compliance with the License. 13 | # You may obtain a copy of the License at 14 | # 15 | # http://www.apache.org/licenses/LICENSE-2.0 16 | # 17 | # Unless required by applicable law or agreed to in writing, software 18 | # distributed under the License is distributed on an "AS IS" BASIS, 19 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
20 | # See the License for the specific language governing permissions and 21 | # limitations under the License.import json 22 | # 23 | import pickle 24 | import torch 25 | from tbert.embedding import BertEmbedding 26 | from tbert.transformer import TransformerEncoder 27 | from tbert.attention import init_linear 28 | 29 | 30 | class Bert(torch.nn.Module): 31 | '''BERT Encoder model. 32 | 33 | Reference: 34 | [BERT: Pre-training of Deep Bidirectional 35 | Transformers for Language Understanding](https://arxiv.org/abs/1810.04805). 36 | ''' 37 | 38 | def __init__(self, config): 39 | torch.nn.Module.__init__(self) 40 | 41 | if config['attention_probs_dropout_prob'] != config['hidden_dropout_prob']: 42 | raise NotImplementedError() 43 | 44 | if config['hidden_act'] != 'gelu': 45 | raise NotImplementedError() 46 | 47 | dropout = config['attention_probs_dropout_prob'] 48 | 49 | self.embedding = BertEmbedding( 50 | token_vocab_size=config['vocab_size'], 51 | segment_vocab_size=config['type_vocab_size'], 52 | hidden_size=config['hidden_size'], 53 | max_position_embeddings=config['max_position_embeddings'], 54 | initializer_range=config['initializer_range'] 55 | ) 56 | 57 | self.encoder = torch.nn.ModuleList([ 58 | TransformerEncoder( 59 | hidden_size=config['hidden_size'], 60 | num_heads=config['num_attention_heads'], 61 | intermediate_size=config['intermediate_size'], 62 | dropout=dropout, 63 | initializer_range=config['initializer_range'] 64 | ) 65 | for _ in range(config['num_hidden_layers']) 66 | ]) 67 | 68 | def forward(self, input_ids, input_type_ids=None, input_mask=None): 69 | B = input_ids.size(0) # batch size 70 | 71 | if input_mask is None: 72 | input_mask = torch.ones_like(input_ids) 73 | if input_type_ids is None: 74 | input_type_ids = torch.zeros_like(input_ids) 75 | 76 | # credit to: https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/modeling.py 77 | att_mask = input_mask.unsqueeze(1).unsqueeze(2) 78 | att_mask = (1.0 - att_mask.float()) * -10000.0 79 | 80 | y = self.embedding(input_ids, input_type_ids) 81 | 82 | # reshape to matrix. Apparently for speed -MK 83 | y = y.view(-1, y.size(-1)) 84 | 85 | outputs = [] 86 | for layer in self.encoder: 87 | y = layer(y, att_mask, batch_size=B) 88 | outputs.append(y) 89 | 90 | return outputs 91 | 92 | def load_pretrained(self, dir_name): 93 | with open(f'{dir_name}/bert_model.pickle', 'rb') as f: 94 | self.load_state_dict(pickle.load(f)) 95 | 96 | def save_pretrained(self, dir_name): 97 | with open(f'{dir_name}/bert_model.pickle', 'wb') as f: 98 | pickle.dump(self.state_dict(), f) 99 | 100 | 101 | class BertPooler(torch.nn.Module): 102 | '''BERT Encoder model with pooling layer. 103 | 104 | Reference: 105 | [BERT: Pre-training of Deep Bidirectional 106 | Transformers for Language Understanding](https://arxiv.org/abs/1810.04805). 
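The pooled output is the top-layer hidden state of the first ([CLS]) token,
passed through a dense projection with tanh activation (see `forward` below).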
107 | ''' 108 | def __init__(self, config): 109 | torch.nn.Module.__init__(self) 110 | 111 | if config['attention_probs_dropout_prob'] != config['hidden_dropout_prob']: 112 | raise NotImplementedError() 113 | 114 | dropout = config['attention_probs_dropout_prob'] 115 | hidden_size = config['hidden_size'] 116 | 117 | self.bert = Bert(config) 118 | 119 | self.pooler = torch.nn.Linear(hidden_size, hidden_size) 120 | self.dropout = torch.nn.Dropout(dropout) 121 | 122 | init_linear(self.pooler, config['initializer_range']) 123 | 124 | def forward(self, input_ids, input_type_ids=None, input_mask=None): 125 | batch_size = input_ids.size(0) 126 | 127 | activations = self.bert(input_ids, input_type_ids, input_mask) 128 | 129 | x = activations[-1] # use top layer only 130 | x = x.view(batch_size, -1, x.size(-1)) # [B, S, H] 131 | # take activations of the first token (aka BERT-style "pooling") 132 | x = x[:, 0:1, :].squeeze(1) 133 | 134 | x = self.pooler(x) 135 | x = torch.tanh(x) 136 | 137 | return x 138 | 139 | def load_pretrained(self, dir_name): 140 | self.bert.load_pretrained(dir_name) 141 | 142 | with open(f'{dir_name}/pooler_model.pickle', 'rb') as f: 143 | self.pooler.load_state_dict(pickle.load(f)) 144 | 145 | def save_pretrained(self, dir_name): 146 | self.bert.save_pretrained(dir_name) 147 | 148 | with open(f'{dir_name}/pooler_model.pickle', 'wb') as f: 149 | pickle.dump(self.pooler.state_dict(), f) 150 | -------------------------------------------------------------------------------- /tbert/cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/innodatalabs/tbert/84c1c9507b3b1bffd2a08a86efaf9bc9955271e0/tbert/cli/__init__.py -------------------------------------------------------------------------------- /tbert/cli/cmp_jsonl.py: -------------------------------------------------------------------------------- 1 | # The MIT License 2 | # Copyright 2019 Innodata Labs and Mike Kroutikov 3 | # 4 | import json 5 | from types import SimpleNamespace 6 | 7 | 8 | def cmp_dict(d1, d2, ctx): 9 | k1 = set(d1.keys()) 10 | k2 = set(d2.keys()) 11 | if k1 != k2: 12 | ctx.error = 'dict keys mismatch' 13 | return 14 | 15 | for key in k1: 16 | cmp_x(d1[key], d2[key], ctx) 17 | if ctx.error is not None: 18 | ctx.path.append(key) 19 | return 20 | 21 | 22 | def cmp_int(i1, i2, ctx): 23 | if i1 != i2: 24 | ctx.error = f'value mismatch {i1} vs {i2}' 25 | 26 | 27 | def cmp_float(f1, f2, ctx): 28 | ctx.delta = max(abs(f1-f2), ctx.delta) 29 | if ctx.delta > ctx.tolerance: 30 | ctx.error = f'float value mismatch: {f1} vs {f2}' 31 | 32 | 33 | def cmp_str(s1, s2, ctx): 34 | if s1 != s2: 35 | ctx.error = f'str value mismatch: {s1[:10]} vs {s2[:10]}' 36 | 37 | 38 | def cmp_list(l1, l2, ctx): 39 | if len(l1) != len(l2): 40 | ctx.error = f'list length mismatch {len(l1)} vs {len(l2)}' 41 | return 42 | 43 | for index,(a,b) in enumerate(zip(l1, l2)): 44 | cmp_x(a, b, ctx) 45 | if ctx.error is not None: 46 | ctx.path.append(index) 47 | return 48 | 49 | 50 | _DISPATCH = { 51 | int: cmp_int, 52 | float: cmp_float, 53 | str: cmp_str, 54 | dict: cmp_dict, 55 | list: cmp_list, 56 | } 57 | 58 | def cmp_x(a, b, ctx): 59 | if type(a) is not type(b): 60 | ctx.error = f'type mismatch: {type(a)} vs {type(b)}' 61 | return 62 | 63 | _DISPATCH[type(a)](a, b, ctx) 64 | 65 | 66 | if __name__ == '__main__': 67 | import argparse 68 | 69 | parser = argparse.ArgumentParser(description='Compares two JSON-L files') 70 | parser.add_argument('jsonl1', help='Path to first 
JSON-L file') 71 | parser.add_argument('jsonl2', help='Path to second JSON-L file') 72 | parser.add_argument('--tolerance', default=1.e-5, type=float, help='Float comparison tolerance') 73 | 74 | args = parser.parse_args() 75 | 76 | with open(args.jsonl1, 'r', encoding='utf-8') as f1: 77 | with open(args.jsonl2, 'r', encoding='utf-8') as f2: 78 | 79 | f1 = iter(f1) 80 | f2 = iter(f2) 81 | 82 | ctx = SimpleNamespace( 83 | error=None, 84 | path=[], 85 | tolerance=args.tolerance, 86 | delta=0. 87 | ) 88 | 89 | while True: 90 | l1 = next(f1, None) 91 | l2 = next(f2, None) 92 | 93 | if l1 is None: 94 | if l2 is None: 95 | break 96 | print('Premature end of file 1') 97 | break 98 | elif l2 is None: 99 | print('Premature end of file 2') 100 | break 101 | 102 | cmp_x(json.loads(l1), json.loads(l2), ctx) 103 | if ctx.error is not None: 104 | path = '/'.join(str(x) for x in ctx.path) 105 | print(ctx.error, 'at', path) 106 | break 107 | 108 | print('Max float values delta:', ctx.delta) 109 | if not ctx.error: 110 | print('Structure is identical') 111 | parser.exit(-1 if ctx.error else 0) 112 | 113 | -------------------------------------------------------------------------------- /tbert/cli/convert.py: -------------------------------------------------------------------------------- 1 | # The MIT License 2 | # Copyright 2019 Innodata Labs and Mike Kroutikov 3 | # 4 | import json 5 | from tbert.tf_util import read_tf_checkpoint, make_bert_pooler_state_dict 6 | from tbert.bert import BertPooler 7 | import modeling 8 | import tensorflow as tf 9 | 10 | 11 | if __name__ == '__main__': 12 | import argparse 13 | import os 14 | import shutil 15 | import pickle 16 | 17 | parser = argparse.ArgumentParser(description='Converts a TF BERT checkpoint to a tBERT one') 18 | 19 | parser.add_argument('input_dir', help='Directory containing pre-trained TF BERT data (bert_config.json, vocab.txt, and bert_model.ckpt)') 20 | parser.add_argument('output_dir', help='Directory where to write the tBERT checkpoint (will be created if it does not exist)') 21 | 22 | args = parser.parse_args() 23 | if args.input_dir == args.output_dir: 24 | raise ValueError('Cannot write to the same directory as input_dir') 25 | 26 | src = lambda s: args.input_dir + '/' + s 27 | trg = lambda s: args.output_dir + '/' + s 28 | 29 | with open(src('bert_config.json'), 'r', encoding='utf-8') as f: 30 | config = json.load(f) 31 | 32 | print(json.dumps(config, indent=2)) 33 | 34 | os.makedirs(args.output_dir, exist_ok=True) 35 | shutil.copyfile(src('bert_config.json'), trg('bert_config.json')) 36 | shutil.copyfile(src('vocab.txt'), trg('vocab.txt')) 37 | 38 | bert_vars = read_tf_checkpoint(src('bert_model.ckpt')) 39 | 40 | bert_pooler = BertPooler(config) 41 | bert_pooler.load_state_dict( 42 | make_bert_pooler_state_dict(bert_vars, config['num_hidden_layers']) 43 | ) 44 | 45 | bert_pooler.save_pretrained(args.output_dir) 46 | 47 | print('Successfully created tBERT model in', args.output_dir) 48 | -------------------------------------------------------------------------------- /tbert/cli/extract_features.py: -------------------------------------------------------------------------------- 1 | # The MIT License 2 | # Copyright 2019 Innodata Labs and Mike Kroutikov 3 | # 4 | # PyTorch port of 5 | # https://github.com/google-research/bert/extract_features.py 6 | # 7 | # Original code copyright follows: 8 | # 9 | # Copyright 2018 The Google AI Language Team Authors.
10 | # 11 | # Licensed under the Apache License, Version 2.0 (the "License"); 12 | # you may not use this file except in compliance with the License. 13 | # You may obtain a copy of the License at 14 | # 15 | # http://www.apache.org/licenses/LICENSE-2.0 16 | # 17 | # Unless required by applicable law or agreed to in writing, software 18 | # distributed under the License is distributed on an "AS IS" BASIS, 19 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | # See the License for the specific language governing permissions and 21 | # limitations under the License. 22 | import pickle 23 | import collections 24 | from tbert.data import parse_example, example_to_feats, batcher 25 | import tokenization # from original BERT repo 26 | import torch 27 | import json # used below to read the config and write the JSON-L output 28 | 29 | def read_examples(filename, max_seq_len, tokenizer): 30 | '''Reads examples from a text file and converts them to features''' 31 | 32 | with open(filename, 'r', encoding='utf-8') as f: 33 | for line in f: 34 | line = line.strip() 35 | text_a, text_b = parse_example(line) 36 | feats = example_to_feats( 37 | text_a, 38 | text_b, 39 | max_seq_len, 40 | tokenizer 41 | ) 42 | yield feats 43 | 44 | 45 | def predict_json_features(bert, examples, batch_size=32, layer_indexes=None): 46 | '''Runs the BERT model on examples and creates a JSON output object for each''' 47 | if layer_indexes is None: 48 | layer_indexes = [-1, -2, -3, -4] 49 | 50 | unique_id = 0 51 | for b in batcher(examples, batch_size=batch_size): 52 | input_ids = torch.LongTensor(b['input_ids']) 53 | input_type_ids = torch.LongTensor(b['input_type_ids']) 54 | input_mask = torch.LongTensor(b['input_mask']) 55 | 56 | out = bert(input_ids, input_type_ids, input_mask) 57 | num_items_in_batch = input_ids.size(0) 58 | for idx in range(num_items_in_batch): 59 | all_features = [] 60 | output_json = collections.OrderedDict([ 61 | ('linex_index', unique_id), 62 | ('features', all_features), 63 | ]) 64 | tokens = b['tokens'][idx] 65 | for i, tk in enumerate(tokens): 66 | all_layers = [] 67 | all_features.append(collections.OrderedDict([ 68 | ('token', tk), 69 | ('layers', all_layers) 70 | ])) 71 | for j, layer_index in enumerate(layer_indexes): 72 | layer_output = out[layer_index] 73 | layer_output = layer_output.view(num_items_in_batch, -1, layer_output.size(-1)) 74 | values = [round(float(x), 6) for x in layer_output[idx, i, :]] 75 | all_layers.append(collections.OrderedDict([ 76 | ('index', layer_index), 77 | ('values', values), 78 | ])) 79 | yield output_json 80 | unique_id += 1 81 | 82 | 83 | if __name__ == '__main__': 84 | import argparse 85 | from tbert.bert import Bert 86 | 87 | parser = argparse.ArgumentParser(description='Reads a text file and extracts BERT features for each sample') 88 | 89 | parser.add_argument('input_file', help='Input text file - one example per line') 90 | parser.add_argument('output_file', help='Name of the output JSONL file') 91 | parser.add_argument('checkpoint_dir', help='Directory with pretrained tBERT checkpoint') 92 | parser.add_argument('--layers', default='-1,-2,-3,-4', help='List of layers to include into the output, default="%(default)s"') 93 | parser.add_argument('--batch_size', default=32, type=int, help='Batch size, default %(default)s') 94 | parser.add_argument('--max_seq_length', default=128, type=int, help='Sequence size limit (after tokenization), default is %(default)s') 95 | parser.add_argument('--do_lower_case', default=True, help='Set to false to retain case-sensitive information, default %(default)s') 96 | 97 | args =
parser.parse_args() 98 | 99 | ckpt = lambda s: args.checkpoint_dir + '/' + s 100 | 101 | with open(ckpt('bert_config.json'), 'r', encoding='utf-8') as f: 102 | config = json.load(f) 103 | print(json.dumps(config, indent=2)) 104 | 105 | if config['max_position_embeddings'] < args.max_seq_length: 106 | raise ValueError('max_seq_length parameter can not exceed config["max_position_embeddings"]') 107 | 108 | tokenizer = tokenization.FullTokenizer( 109 | vocab_file=ckpt('vocab.txt'), 110 | do_lower_case=args.do_lower_case, 111 | ) 112 | 113 | bert = Bert(config) 114 | 115 | with open(ckpt('bert_model.pickle'), 'rb') as f: 116 | bert.load_state_dict(pickle.load(f)) 117 | bert.eval() 118 | 119 | layer_indexes = eval('[' + args.layers + ']') 120 | 121 | examples = read_examples(args.input_file, args.max_seq_length, tokenizer) 122 | 123 | with open(args.output_file, 'w', encoding='utf-8') as f: 124 | for feat_json in predict_json_features( 125 | bert, 126 | examples, 127 | batch_size=args.batch_size, 128 | layer_indexes=layer_indexes): 129 | f.write(json.dumps(feat_json) + '\n') 130 | 131 | print('All done') 132 | -------------------------------------------------------------------------------- /tbert/cli/run_classifier.py: -------------------------------------------------------------------------------- 1 | # The MIT License 2 | # Copyright 2019 Innodata Labs and Mike Kroutikov 3 | # 4 | # PyTorch port of 5 | # https://github.com/google-research/bert/run_classifier.py 6 | # 7 | # Original code copyright follows: 8 | # 9 | # Copyright 2018 The Google AI Language Team Authors. 10 | # 11 | # Licensed under the Apache License, Version 2.0 (the "License"); 12 | # you may not use this file except in compliance with the License. 13 | # You may obtain a copy of the License at 14 | # 15 | # http://www.apache.org/licenses/LICENSE-2.0 16 | # 17 | # Unless required by applicable law or agreed to in writing, software 18 | # distributed under the License is distributed on an "AS IS" BASIS, 19 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
20 | # See the License for the specific language governing permissions and 21 | # limitations under the License.import json 22 | # 23 | import pickle 24 | import json 25 | import csv 26 | import itertools 27 | import torch 28 | from torch.utils import data 29 | import torch.nn.functional as F 30 | from tbert.data import example_to_feats 31 | from tbert.bert import BertPooler 32 | from tbert.attention import init_linear 33 | from tbert.data import repeating_reader, batcher, shuffler 34 | from tbert.optimization import LinearDecayWithWarmupLR 35 | import tokenization 36 | 37 | 38 | class BertClassifier(torch.nn.Module): 39 | 40 | def __init__(self, config, num_classes): 41 | torch.nn.Module.__init__(self) 42 | 43 | self.bert_pooler = BertPooler(config) 44 | self.output = torch.nn.Linear(config['hidden_size'], num_classes) 45 | self.dropout = torch.nn.Dropout(config['hidden_dropout_prob']) 46 | 47 | init_linear(self.output, config['initializer_range']) 48 | 49 | def forward(self, input_ids, input_type_ids=None, input_mask=None): 50 | 51 | x = self.bert_pooler(input_ids, input_type_ids, input_mask) 52 | x = self.output(x) 53 | x = self.dropout(x) 54 | 55 | return x 56 | 57 | def load_pretrained(self, dir_name): 58 | self.bert_pooler.load_pretrained(dir_name) 59 | 60 | 61 | def _read_tsv(input_file, quotechar=None): 62 | """Reads a tab separated value file.""" 63 | with open(input_file, 'r', encoding='utf-8') as f: 64 | yield from csv.reader(f, delimiter='\t', quotechar=quotechar) 65 | 66 | 67 | def _xnli_reader(data_dir, label_vocab, partition='train', lang='zh'): 68 | if partition == 'train': 69 | for i, line in enumerate( 70 | _read_tsv(f'{data_dir}/multinli/multinli.train.{lang}.tsv') 71 | ): 72 | if i == 0: 73 | continue 74 | guid = f'{partition}-{i}' 75 | text_a = line[0] 76 | text_b = line[1] 77 | label = line[2] 78 | if label == 'contradictory': 79 | label = 'contradiction' 80 | yield guid, text_a, text_b, label_vocab[label] 81 | 82 | elif partition == 'dev': 83 | for i, line in enumerate( 84 | _read_tsv(f'{data_dir}/xnli.dev.tsv') 85 | ): 86 | if i == 0: 87 | continue 88 | guid = f'{partition}-{i}' 89 | if line[0] != lang: 90 | continue 91 | text_a = line[6] 92 | text_b = line[7] 93 | label = line[1] 94 | yield guid, text_a, text_b, label_vocab[label] 95 | 96 | else: 97 | raise ValueError('no such partition in this dataset: %r' % partition) 98 | 99 | 100 | def _mnli_reader(data_dir, label_vocab, partition='train'): 101 | 102 | fname = { 103 | 'train': f'{data_dir}/train.tsv', 104 | 'dev' : f'{data_dir}/dev_matched.tsv', 105 | 'test' : f'{data_dir}/test_matched.tsv', 106 | }.get(partition) 107 | 108 | if fname is None: 109 | raise ValueError('no such partition in this dataset: %r' % partition) 110 | 111 | for i,line in enumerate(_read_tsv(fname)): 112 | if i == 0: 113 | continue 114 | giud = f'{partition}-{line[0]}' 115 | text_a = line[8] 116 | text_b = line[9] 117 | if partition == 'test': 118 | label = 'contradiction' 119 | else: 120 | label = line[-1] 121 | yield giud, text_a, text_b, label_vocab[label] 122 | 123 | 124 | def _mrpc_reader(data_dir, label_vocab, partition='train'): 125 | 126 | fname = { 127 | 'train': f'{data_dir}/train.tsv', 128 | 'dev' : f'{data_dir}/dev.tsv', 129 | 'test' : f'{data_dir}/test.tsv', 130 | }.get(partition) 131 | 132 | if fname is None: 133 | raise ValueError('no such partition in this dataset: %r' % partition) 134 | 135 | for i,line in enumerate(_read_tsv(fname)): 136 | if i == 0: 137 | continue 138 | giud = f'{partition}-{i}' 139 | text_a = line[3] 
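# (MRPC TSVs: column 0 holds the label; columns 3 and 4 hold the two sentences being compared)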
140 | text_b = line[4] 141 | if partition == 'test': 142 | label = '0' 143 | else: 144 | label = line[0] 145 | yield giud, text_a, text_b, label_vocab[label] 146 | 147 | 148 | def _cola_reader(data_dir, label_vocab, partition='train'): 149 | 150 | fname = { 151 | 'train': f'{data_dir}/train.tsv', 152 | 'dev' : f'{data_dir}/dev.tsv', 153 | 'test' : f'{data_dir}/test.tsv', 154 | }.get(partition) 155 | 156 | if fname is None: 157 | raise ValueError('no such partition in this dataset: %r' % partition) 158 | 159 | for i,line in enumerate(_read_tsv(fname)): 160 | if partition == 'test' and i == 0: 161 | continue 162 | giud = f'{partition}-{i}' 163 | if partition == 'test': 164 | label = '0' 165 | text_a = line[1] 166 | else: 167 | label = line[1] 168 | text_a = line[3] 169 | yield giud, text_a, None, label_vocab[label] 170 | 171 | 172 | _PROBLEMS = { 173 | 'xnli': dict( 174 | labels=['contradiction', 'entailment', 'neutral'], 175 | reader=_xnli_reader 176 | ), 177 | 'mnli': dict( 178 | labels=['contradiction', 'entailment', 'neutral'], 179 | reader=_mnli_reader 180 | ), 181 | 'mrpc': dict( 182 | labels=['0', '1'], 183 | reader=_mrpc_reader 184 | ), 185 | 'cola': dict( 186 | labels=['0', '1'], 187 | reader=_cola_reader 188 | ), 189 | } 190 | 191 | 192 | def feats_reader(reader, seq_length, tokenizer): 193 | '''Reads samples from reader and makes a feature dictionary for each''' 194 | 195 | for guid, text_a, text_b, label_id in reader: 196 | feats = example_to_feats(text_a, text_b, seq_length, tokenizer) 197 | feats.update(label_id=label_id, guid=guid) 198 | yield feats 199 | 200 | 201 | if __name__ == '__main__': 202 | import argparse 203 | from tbert.bert import Bert 204 | 205 | parser = argparse.ArgumentParser(description='Reads text file and extracts BERT features for each sample') 206 | 207 | parser.add_argument('pretrained_dir', help='Directory with pretrained tBERT checkpoint') 208 | parser.add_argument('output_dir', help='Where to save trained model (and were to load from for evaluation/prediction)') 209 | parser.add_argument('--batch_size', default=32, help='Batch size, default %(default)s') 210 | parser.add_argument('--max_seq_length', default=128, help='Sequence size limit (after tokenization), default is %(default)s') 211 | parser.add_argument('--do_lower_case', default=True, help='Set to false to retain case-sensitive information, default %(default)s') 212 | 213 | parser.add_argument('--problem', required=True, choices={'cola', 'mnli', 'mrpc', 'xnli'}, help='problem type') 214 | parser.add_argument('--data_dir', required=True, help='Directory with the data') 215 | parser.add_argument('--do_train', action='store_true', help='Set this flag to run training') 216 | parser.add_argument('--do_eval', action='store_true', help='Set this flag to run evaluation') 217 | parser.add_argument('--do_predict', action='store_true', help='Set this flag to run prediction') 218 | 219 | parser.add_argument('--learning_rate', default=2.e-5, help='Learning rate for training, default %(default)s') 220 | parser.add_argument('--num_train_steps', default=1000, type=int, help='Number of training steps, default %(default)s') 221 | parser.add_argument('--num_warmup_steps', default=200, type=int, help='Number of learning rate warmup steps, default %(default)s') 222 | parser.add_argument('--macro_batch', default=1, help='Number of batches to accumulate gradiends before optimizer does the update, default %(default)s') 223 | parser.add_argument('--print_every', default=10, help='How often to print training stats, 
default %(default)s') 224 | 225 | args = parser.parse_args() 226 | 227 | problem = _PROBLEMS[args.problem] 228 | label_vocab = { 229 | label: i 230 | for i, label in enumerate(problem['labels']) 231 | } 232 | problem_reader = problem['reader'] 233 | 234 | inp = lambda s: f'{args.pretrained_dir}/{s}' 235 | out = lambda s: f'{args.output_dir}/{s}' 236 | 237 | with open(inp('bert_config.json'), 'r', encoding='utf-8') as f: 238 | config = json.load(f) 239 | print(json.dumps(config, indent=2)) 240 | 241 | if config['max_position_embeddings'] < args.max_seq_length: 242 | raise ValueError('max_seq_length parameter can not exceed config["max_position_embeddings"]') 243 | 244 | print('Loading vocabulary...') 245 | tokenizer = tokenization.FullTokenizer( 246 | vocab_file=inp('vocab.txt'), 247 | do_lower_case=args.do_lower_case, 248 | ) 249 | print('Done loading vocabulary.') 250 | 251 | classifier = BertClassifier(config, len(label_vocab)) 252 | if args.do_train: 253 | print('Loading pre-trained weights...') 254 | classifier.load_pretrained(args.pretrained_dir) 255 | print('Done loading pre-trained weights.') 256 | 257 | device = torch.device('cpu') 258 | if torch.cuda.is_available(): 259 | device = torch.device('cuda') 260 | classifier.to(device) 261 | 262 | if args.do_train: 263 | print('*** Training ***') 264 | classifier.train() 265 | 266 | reader = repeating_reader( 267 | -1, # repeat indefinetely 268 | problem_reader, 269 | args.data_dir, 270 | label_vocab, 271 | partition='train' 272 | ) 273 | 274 | reader = shuffler(reader, buffer_size=1000) 275 | 276 | reader = feats_reader( 277 | reader, 278 | args.max_seq_length, 279 | tokenizer 280 | ) 281 | 282 | opt = torch.optim.Adam( 283 | classifier.parameters(), 284 | lr=args.learning_rate, 285 | betas=(0.9, 0.999), 286 | eps=1.e-6 287 | ) 288 | 289 | lr_schedule = LinearDecayWithWarmupLR( 290 | opt, 291 | args.num_train_steps, 292 | args.num_warmup_steps 293 | ) 294 | 295 | step = 0 296 | for b in itertools.islice( 297 | batcher(reader, batch_size=args.batch_size), 298 | args.num_train_steps*args.macro_batch): 299 | input_ids = torch.LongTensor(b['input_ids']).to(device) 300 | input_type_ids = torch.LongTensor(b['input_type_ids']).to(device) 301 | input_mask = torch.LongTensor(b['input_mask']).to(device) 302 | label_id = torch.LongTensor(b['label_id']).to(device) 303 | 304 | logits = classifier(input_ids, input_type_ids, input_mask) 305 | log_probs = F.log_softmax(logits, dim=-1) 306 | loss = F.nll_loss(log_probs, label_id, reduction='elementwise_mean') 307 | loss.backward() 308 | 309 | step += 1 310 | if step % args.macro_batch == 0: 311 | opt.step() 312 | lr_schedule.step() 313 | opt.zero_grad() 314 | 315 | if step % args.print_every == 0: 316 | lrs = [p['lr'] for p in opt.param_groups][0] 317 | print(f'Step: {step:>10}, loss: {loss.item():6.2}, learning rates: {lrs:8}') 318 | 319 | # save trained 320 | with open(f'{args.output_dir}/bert_classifier.pickle', 'wb') as f: 321 | pickle.dump(classifier.state_dict(), f) 322 | print('Saved trained model') 323 | else: 324 | # load trained 325 | with open(f'{args.output_dir}/bert_classifier.pickle', 'rb') as f: 326 | classifier.load_state_dict(pickle.load(f)) 327 | print('Loaded checkpoint') 328 | 329 | if args.do_eval: 330 | print('*** Evaluating ***') 331 | classifier.eval() 332 | 333 | reader = feats_reader( 334 | problem_reader(args.data_dir, label_vocab, partition='dev'), 335 | args.max_seq_length, 336 | tokenizer 337 | ) 338 | 339 | total_loss = 0. 
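# running totals over the entire dev set; per-sample averages are reported after the loop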
340 | total_samples = 0 341 | total_hits = 0 342 | for b in batcher(reader, batch_size=args.batch_size): 343 | input_ids = torch.LongTensor(b['input_ids']).to(device) 344 | input_type_ids = torch.LongTensor(b['input_type_ids']).to(device) 345 | input_mask = torch.LongTensor(b['input_mask']).to(device) 346 | label_id = torch.LongTensor(b['label_id']).to(device) 347 | 348 | logits = classifier(input_ids, input_type_ids, input_mask) 349 | log_probs = F.log_softmax(logits, dim=-1) 350 | loss = F.nll_loss(log_probs, label_id, reduction='sum').item() 351 | prediction = torch.argmax(log_probs, dim=-1) 352 | hits = (label_id == prediction).sum().item() 353 | 354 | total_loss += loss 355 | total_hits += hits 356 | total_samples += input_ids.size(0) 357 | 358 | print('Number of samples evaluated:', total_samples) 359 | print('Average per-sample loss:', total_loss / total_samples) 360 | print('Accuracy:', total_hits / total_samples) 361 | 362 | if args.do_predict: 363 | print('*** Predicting ***') 364 | classifier.eval() 365 | 366 | reader = feats_reader( 367 | problem_reader(args.data_dir, label_vocab, partition='test'), 368 | args.max_seq_length, 369 | tokenizer 370 | ) 371 | 372 | with open(f'{args.output_dir}/test_results.tsv', 'w') as f: 373 | for b in batcher(reader, batch_size=args.batch_size): 374 | input_ids = torch.LongTensor(b['input_ids']).to(device) 375 | input_type_ids = torch.LongTensor(b['input_type_ids']).to(device) 376 | input_mask = torch.LongTensor(b['input_mask']).to(device) 377 | 378 | logits = classifier(input_ids, input_type_ids, input_mask) 379 | prob = F.softmax(logits, dim=-1) 380 | for i in range(prob.size(0)): 381 | f.write('\t'.join(str(p) for p in prob[i].tolist()) + '\n') 382 | 383 | print('All done') 384 | -------------------------------------------------------------------------------- /tbert/data.py: -------------------------------------------------------------------------------- 1 | # The MIT License 2 | # Copyright 2019 Innodata Labs and Mike Kroutikov 3 | # 4 | # Heavily borrows from: 5 | # https://github.com/google-research/bert/extract_features.py 6 | # 7 | # Original code copyright follows: 8 | # 9 | # Copyright 2018 The Google AI Language Team Authors. 10 | # 11 | # Licensed under the Apache License, Version 2.0 (the "License"); 12 | # you may not use this file except in compliance with the License. 13 | # You may obtain a copy of the License at 14 | # 15 | # http://www.apache.org/licenses/LICENSE-2.0 16 | # 17 | # Unless required by applicable law or agreed to in writing, software 18 | # distributed under the License is distributed on an "AS IS" BASIS, 19 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | # See the License for the specific language governing permissions and 21 | # limitations under the License.import json 22 | # 23 | import re 24 | import collections 25 | import random 26 | import torch 27 | 28 | 29 | def parse_example(line): 30 | m = re.match(r"^(.*) \|\|\| (.*)$", line) 31 | if m is None: 32 | return (line, None) 33 | 34 | return m.group(1), m.group(2) 35 | 36 | 37 | def example_to_feats(text_a, text_b, seq_length, tokenizer): 38 | 39 | tokens_a = tokenizer.tokenize(text_a) 40 | 41 | tokens_b = None 42 | if text_b: 43 | tokens_b = tokenizer.tokenize(text_b) 44 | 45 | if tokens_b: 46 | # Modifies `tokens_a` and `tokens_b` in place so that the total 47 | # length is less than the specified length. 
48 | # Account for [CLS], [SEP], [SEP] with "- 3" 49 | _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3) 50 | else: 51 | # Account for [CLS] and [SEP] with "- 2" 52 | if len(tokens_a) > seq_length - 2: 53 | tokens_a = tokens_a[0:(seq_length - 2)] 54 | 55 | # The convention in BERT is: 56 | # (a) For sequence pairs: 57 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 58 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 59 | # (b) For single sequences: 60 | # tokens: [CLS] the dog is hairy . [SEP] 61 | # type_ids: 0 0 0 0 0 0 0 62 | # 63 | # Where "type_ids" are used to indicate whether this is the first 64 | # sequence or the second sequence. The embedding vectors for `type=0` and 65 | # `type=1` were learned during pre-training and are added to the wordpiece 66 | # embedding vector (and position vector). This is not *strictly* necessary 67 | # since the [SEP] token unambiguously separates the sequences, but it makes 68 | # it easier for the model to learn the concept of sequences. 69 | # 70 | # For classification tasks, the first vector (corresponding to [CLS]) is 71 | # used as as the "sentence vector". Note that this only makes sense because 72 | # the entire model is fine-tuned. 73 | tokens = ['[CLS]'] 74 | tokens.extend(tokens_a) 75 | tokens.append('[SEP]') 76 | 77 | input_type_ids = [0] * len(tokens) 78 | 79 | if tokens_b: 80 | tokens.extend(tokens_b) 81 | tokens.append('[SEP]') 82 | input_type_ids.extend([1]*(len(tokens_b)+1)) 83 | 84 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 85 | 86 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 87 | # tokens are attended to. 88 | input_mask = [1] * len(input_ids) 89 | 90 | # Zero-pad up to the sequence length. 91 | if len(input_ids) < seq_length: 92 | padding = [0] * (seq_length - len(input_ids)) 93 | input_ids.extend(padding) 94 | input_mask.extend(padding) 95 | input_type_ids.extend(padding) 96 | 97 | assert len(input_ids) == seq_length 98 | assert len(input_mask) == seq_length 99 | assert len(input_type_ids) == seq_length 100 | 101 | return dict( 102 | tokens=tokens, 103 | input_ids=input_ids, 104 | input_mask=input_mask, 105 | input_type_ids=input_type_ids 106 | ) 107 | 108 | 109 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 110 | """Truncates a sequence pair in place to the maximum length.""" 111 | 112 | # This is a simple heuristic which will always truncate the longer sequence 113 | # one token at a time. This makes more sense than truncating an equal percent 114 | # of tokens from each, since if one sequence is very short then each token 115 | # that's truncated likely contains more information than a longer sequence. 
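example_to_feats above produces the fixed-length [CLS]/[SEP] token layout, segment ids, and attention mask that BERT expects. A small worked example; ToyTokenizer is a hypothetical whitespace tokenizer standing in for tokenization.FullTokenizer, which needs a real WordPiece vocabulary file:

    from tbert.data import example_to_feats

    class ToyTokenizer:
        '''Hypothetical stand-in: splits on whitespace and grows a vocab on the fly.'''
        def __init__(self):
            self.vocab = {'[PAD]': 0}            # keep id 0 for padding
        def tokenize(self, text):
            return text.split()
        def convert_tokens_to_ids(self, tokens):
            return [self.vocab.setdefault(t, len(self.vocab)) for t in tokens]

    feats = example_to_feats(
        'is this jacksonville ?', 'no it is not .',
        seq_length=16, tokenizer=ToyTokenizer()
    )
    print(feats['tokens'])
    # ['[CLS]', 'is', 'this', 'jacksonville', '?', '[SEP]', 'no', 'it', 'is', 'not', '.', '[SEP]']
    print(feats['input_type_ids'])   # [0]*6 + [1]*6 + [0]*4   - segment a, segment b, padding
    print(feats['input_mask'])       # [1]*12 + [0]*4          - real tokens vs. zero padding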
116 | while True: 117 | total_length = len(tokens_a) + len(tokens_b) 118 | if total_length <= max_length: 119 | break 120 | if len(tokens_a) > len(tokens_b): 121 | tokens_a.pop() 122 | else: 123 | tokens_b.pop() 124 | 125 | 126 | def group(sequence, batch_size=32, allow_incomplete=True): 127 | '''Groups input stream into batches of at most batch_size''' 128 | buffer = [] 129 | for s in sequence: 130 | buffer.append(s) 131 | if len(buffer) >= batch_size: 132 | yield buffer[:] 133 | buffer.clear() 134 | 135 | if len(buffer) > 0 and allow_incomplete: 136 | yield buffer 137 | 138 | 139 | def batcher(sequence, batch_size=32, allow_incomplete=True): 140 | '''Batches input sequence of features''' 141 | 142 | def shape_batch(batch): 143 | out = collections.defaultdict(list) 144 | for seq in batch: 145 | for key, val in seq.items(): 146 | out[key].append(val) 147 | return dict(out) 148 | 149 | for batch in group( 150 | sequence, 151 | batch_size=batch_size, 152 | allow_incomplete=allow_incomplete 153 | ): 154 | yield shape_batch(batch) 155 | 156 | 157 | def shuffler(stream, buffer_size=100000): 158 | '''Shuffles stream of input samples. 159 | 160 | Uses internal buffer to hold samples delayed for shuffling. 161 | Bigger buffer size gives better shuffling. 162 | ''' 163 | buffer = [] 164 | for sample in stream: 165 | if len(buffer) >= buffer_size: 166 | random.shuffle(buffer) 167 | yield from buffer[len(buffer)//2:] 168 | del buffer[len(buffer)//2:] 169 | buffer.append(sample) 170 | 171 | random.shuffle(buffer) 172 | yield from buffer 173 | 174 | 175 | def repeating_reader(num_epochs: int, reader_factory, *av, **kav): 176 | '''Creates a bigger stream of samples by repeating data 177 | for the specified number of epochs. To repeat indefinetely 178 | use num_epochs=-1 179 | ''' 180 | while num_epochs != 0: 181 | yield from reader_factory(*av, **kav) 182 | num_epochs -= 1 183 | -------------------------------------------------------------------------------- /tbert/embedding.py: -------------------------------------------------------------------------------- 1 | # The MIT License 2 | # Copyright 2019 Innodata Labs and Mike Kroutikov 3 | # 4 | # PyTorch port of 5 | # https://github.com/google-research/bert/modeling.py 6 | # 7 | # Original code copyright follows: 8 | # 9 | # Copyright 2018 The Google AI Language Team Authors. 10 | # 11 | # Licensed under the Apache License, Version 2.0 (the "License"); 12 | # you may not use this file except in compliance with the License. 13 | # You may obtain a copy of the License at 14 | # 15 | # http://www.apache.org/licenses/LICENSE-2.0 16 | # 17 | # Unless required by applicable law or agreed to in writing, software 18 | # distributed under the License is distributed on an "AS IS" BASIS, 19 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
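The generators above (repeating_reader, shuffler, group/batcher) compose into a lazy input pipeline, which is how run_classifier.py wires them up (the embedding module's listing continues below). A toy composition with an illustrative reader factory:

    from tbert.data import repeating_reader, shuffler, batcher

    def toy_reader(n):
        '''Hypothetical reader factory: yields n tiny feature dicts.'''
        for i in range(n):
            yield dict(input_ids=[i, i + 1], label_id=i % 2)

    stream = repeating_reader(2, toy_reader, 10)     # two passes over the toy data
    stream = shuffler(stream, buffer_size=5)         # approximate shuffle via a small buffer
    for batch in batcher(stream, batch_size=4):
        print(len(batch['input_ids']), batch['label_id'])
    # five batches of four samples each; each key becomes a list ready for torch.LongTensor(...)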
20 | # See the License for the specific language governing permissions and 21 | # limitations under the License.import json 22 | # 23 | import torch 24 | 25 | 26 | class BertEmbedding(torch.nn.Module): 27 | 28 | def __init__(self, 29 | token_vocab_size=105879, 30 | segment_vocab_size=2, 31 | hidden_size=768, 32 | max_position_embeddings=512, 33 | initializer_range=0.02): 34 | ''' 35 | token_vocab_size - size of token (word pieces) vocabulary 36 | segment_vocab_size - number of segments (BERT uses 2 always, do not change) 37 | hidden_size - size of the hidden transformer layer (number of embedding dimensions) 38 | max_position_embeddings - longest sequence size this model will support 39 | ''' 40 | torch.nn.Module.__init__(self) 41 | 42 | self.token_embedding = torch.nn.Embedding(token_vocab_size, hidden_size, padding_idx=0) 43 | self.segment_embedding = torch.nn.Embedding(segment_vocab_size, hidden_size) 44 | self.position_embedding = torch.nn.Parameter( 45 | data=torch.zeros( 46 | max_position_embeddings, 47 | hidden_size, 48 | dtype=torch.float32 49 | ) 50 | ) 51 | self.layer_norm = torch.nn.LayerNorm(hidden_size, eps=1.e-12) 52 | 53 | # apply weight initialization 54 | torch.nn.init.normal_(self.token_embedding.weight, std=initializer_range) 55 | torch.nn.init.normal_(self.segment_embedding.weight, std=initializer_range) 56 | torch.nn.init.normal_(self.position_embedding, std=initializer_range) 57 | 58 | def forward(self, input_ids, input_type_ids): 59 | ''' 60 | input_ids - LongTensor of shape [B, S] containing token ids (padded with 0) 61 | input_type_ids - LongTensor of shape [B, S] containing token segment ids. 62 | These are: 0 for the tokens in first segment, and 1 for the tokens 63 | in second segment 64 | 65 | Here: B - batch size, S - sequence length 66 | ''' 67 | batch_size = input_ids.size(0) 68 | seq_len = input_ids.size(1) 69 | x = self.token_embedding(input_ids) 70 | s = self.segment_embedding(input_type_ids) 71 | p = self.position_embedding[:seq_len, :].unsqueeze(0).repeat((batch_size, 1, 1)) 72 | 73 | return self.layer_norm(x + s + p) 74 | 75 | -------------------------------------------------------------------------------- /tbert/gelu.py: -------------------------------------------------------------------------------- 1 | # The MIT License 2 | # Copyright 2019 Innodata Labs and Mike Kroutikov 3 | # 4 | # PyTorch port of 5 | # https://github.com/google-research/bert/modeling.py 6 | # 7 | # Original code copyright follows: 8 | # 9 | # Copyright 2018 The Google AI Language Team Authors. 10 | # 11 | # Licensed under the Apache License, Version 2.0 (the "License"); 12 | # you may not use this file except in compliance with the License. 13 | # You may obtain a copy of the License at 14 | # 15 | # http://www.apache.org/licenses/LICENSE-2.0 16 | # 17 | # Unless required by applicable law or agreed to in writing, software 18 | # distributed under the License is distributed on an "AS IS" BASIS, 19 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
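BertEmbedding above sums token, segment, and learned position embeddings and applies LayerNorm (the gelu module's listing continues below). A quick shape check on a toy configuration, with sizes chosen arbitrarily for illustration:

    import torch
    from tbert.embedding import BertEmbedding

    emb = BertEmbedding(token_vocab_size=100, hidden_size=16, max_position_embeddings=32)
    input_ids = torch.randint(1, 100, (2, 6))             # [B, S]; id 0 is the padding index
    input_type_ids = torch.zeros(2, 6, dtype=torch.long)  # single-segment input
    out = emb(input_ids, input_type_ids)
    print(out.shape)                                       # torch.Size([2, 6, 16]) == [B, S, hidden_size]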
20 | # See the License for the specific language governing permissions and 21 | # limitations under the License.import json 22 | # 23 | import torch 24 | import math 25 | 26 | 27 | def gelu(x): 28 | '''Gaussian Error Linear Unit - a smooth version of RELU''' 29 | cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 30 | return x * cdf 31 | -------------------------------------------------------------------------------- /tbert/optimization.py: -------------------------------------------------------------------------------- 1 | # The MIT License 2 | # Copyright 2019 Innodata Labs and Mike Kroutikov 3 | # 4 | # PyTorch port of 5 | # https://github.com/google-research/bert/modeling.py 6 | # 7 | # Original code copyright follows: 8 | # 9 | # Copyright 2018 The Google AI Language Team Authors. 10 | # 11 | # Licensed under the Apache License, Version 2.0 (the "License"); 12 | # you may not use this file except in compliance with the License. 13 | # You may obtain a copy of the License at 14 | # 15 | # http://www.apache.org/licenses/LICENSE-2.0 16 | # 17 | # Unless required by applicable law or agreed to in writing, software 18 | # distributed under the License is distributed on an "AS IS" BASIS, 19 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | # See the License for the specific language governing permissions and 21 | # limitations under the License.import json 22 | # 23 | from torch.optim.lr_scheduler import LambdaLR 24 | 25 | 26 | class LinearDecayWithWarmupLR(LambdaLR): 27 | 28 | def __init__(self, optimizer, train_steps, warmup_steps, last_epoch=-1): 29 | 30 | def schedule(step): 31 | if step <= warmup_steps: 32 | return step / warmup_steps 33 | assert step <= train_steps 34 | return (train_steps - step) / (train_steps - warmup_steps) 35 | 36 | LambdaLR.__init__(self, optimizer, schedule, last_epoch=last_epoch) 37 | -------------------------------------------------------------------------------- /tbert/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/innodatalabs/tbert/84c1c9507b3b1bffd2a08a86efaf9bc9955271e0/tbert/test/__init__.py -------------------------------------------------------------------------------- /tbert/test/test_bert.py: -------------------------------------------------------------------------------- 1 | # The MIT License 2 | # Copyright 2019 Innodata Labs and Mike Kroutikov 3 | # 4 | import tensorflow as tf 5 | import torch 6 | import numpy as np 7 | import random 8 | from tbert.tf_util import tracer_session, get_tf_bert_init_params, \ 9 | run_tf_bert_once, run_tbert_once, run_tbert_pooler_once 10 | 11 | 12 | # to get stable results 13 | tf.set_random_seed(1) 14 | torch.manual_seed(1) 15 | np.random.seed(1) 16 | random.seed(1) 17 | 18 | 19 | CONFIG_MICRO = dict( 20 | attention_probs_dropout_prob=0.1, 21 | directionality="bidi", 22 | hidden_act="gelu", 23 | hidden_dropout_prob=0.1, 24 | hidden_size=10, 25 | initializer_range=0.02, 26 | intermediate_size=10, 27 | max_position_embeddings=20, 28 | num_attention_heads=2, 29 | num_hidden_layers=1, 30 | type_vocab_size=2, 31 | vocab_size=100 32 | ) 33 | 34 | PARAMS_MICRO = get_tf_bert_init_params(CONFIG_MICRO) 35 | 36 | 37 | CONFIG_BIG = dict( 38 | attention_probs_dropout_prob = 0.1, 39 | directionality = "bidi", 40 | hidden_act = "gelu", 41 | hidden_dropout_prob = 0.1, 42 | hidden_size = 768, 43 | initializer_range = 0.02, 44 | intermediate_size = 3072, 45 | max_position_embeddings = 512, 46 | num_attention_heads = 
12, 47 | num_hidden_layers = 12, 48 | pooler_fc_size = 768, 49 | pooler_num_attention_heads = 12, 50 | pooler_num_fc_layers = 3, 51 | pooler_size_per_head = 128, 52 | pooler_type = "first_token_transform", 53 | type_vocab_size = 2, 54 | vocab_size = 105879 55 | ) 56 | 57 | PARAMS_BIG = get_tf_bert_init_params(CONFIG_BIG) 58 | 59 | 60 | def assert_same(*av, tolerance=1.e-6): 61 | tf_out, _ = run_tf_bert_once(*av) 62 | tbert_out = run_tbert_once(*av) 63 | 64 | # compare 65 | assert len(tf_out) == len(tbert_out) 66 | for x,y in zip(tf_out, tbert_out): 67 | delta = np.max(np.abs(x.flatten()-y.flatten())) 68 | assert delta < tolerance, delta 69 | 70 | 71 | def assert_same_pooler(*av, tolerance=1.e-6): 72 | _, tf_logits = run_tf_bert_once(*av) 73 | tbert_logits = run_tbert_pooler_once(*av) 74 | 75 | # compare 76 | delta = np.max(np.abs(tf_logits.flatten()-tbert_logits.flatten())) 77 | assert delta < tolerance, delta 78 | 79 | 80 | def make_random_inputs(vocab_size, shape): 81 | input_ids = np.random.randint(vocab_size, size=shape) 82 | input_type_ids = np.random.randint(2, size=shape) 83 | input_mask = np.random.randint(2, size=shape) 84 | 85 | return input_ids, input_type_ids, input_mask 86 | 87 | 88 | def test_smoke(): 89 | input_ids = np.array([[1, 2, 3, 4, 5, 0]]) 90 | input_type_ids = np.array([[0, 0, 1, 1, 1, 0]]) 91 | input_mask = np.array([[1, 1, 1, 1, 1, 0]]) 92 | 93 | assert_same(CONFIG_MICRO, PARAMS_MICRO, input_ids, input_type_ids, input_mask) 94 | 95 | 96 | def test_random(): 97 | input_ids, input_type_ids, input_mask = make_random_inputs(100, (2, 5)) 98 | 99 | assert_same(CONFIG_MICRO, PARAMS_MICRO, input_ids, input_type_ids, input_mask) 100 | 101 | 102 | def test_random_big(): 103 | input_ids, input_type_ids, input_mask = make_random_inputs(10000, (10, 128)) 104 | 105 | assert_same(CONFIG_BIG, PARAMS_BIG, input_ids, input_type_ids, input_mask, 106 | tolerance=1e-4) 107 | 108 | 109 | def test_pooler(): 110 | input_ids, input_type_ids, input_mask = make_random_inputs(100, (2, 5)) 111 | 112 | assert_same_pooler(CONFIG_MICRO, PARAMS_MICRO, input_ids, input_type_ids, input_mask) 113 | 114 | 115 | def test_pooler_big(): 116 | input_ids, input_type_ids, input_mask = make_random_inputs(10000, (10, 128)) 117 | 118 | assert_same_pooler(CONFIG_BIG, PARAMS_BIG, input_ids, input_type_ids, input_mask, 119 | tolerance=2.e-6) 120 | -------------------------------------------------------------------------------- /tbert/tf_util.py: -------------------------------------------------------------------------------- 1 | # The MIT License 2 | # Copyright 2019 Innodata Labs and Mike Kroutikov 3 | # 4 | ''' 5 | Utilities to trace TF graph execution, capture TF BERT parameters, 6 | and convert them to PyTorch. 7 | ''' 8 | import contextlib 9 | import tensorflow as tf 10 | import modeling 11 | import torch 12 | import numpy as np 13 | from tbert.bert import Bert, BertPooler 14 | 15 | 16 | def read_tf_checkpoint(init_checkpoint): 17 | '''Reads standard TF checkpoint and returns all variables. 
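LinearDecayWithWarmupLR in tbert/optimization.py above scales the optimizer's base learning rate by a factor that ramps linearly from 0 to 1 over warmup_steps and then decays linearly back to 0 at train_steps; it asserts that you never step past train_steps (tbert/tf_util.py continues below). A small sketch with a dummy parameter, names illustrative only:

    import torch
    from tbert.optimization import LinearDecayWithWarmupLR

    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.SGD([param], lr=1.0)                  # base lr 1.0 so the factor is visible
    sched = LinearDecayWithWarmupLR(opt, train_steps=100, warmup_steps=10)

    for step in range(1, 101):
        opt.step()
        sched.step()
        if step in (1, 10, 55, 100):
            print(step, opt.param_groups[0]['lr'])
    # 1 -> 0.1, 10 -> 1.0, 55 -> 0.5, 100 -> 0.0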
18 | 19 | Returns: 20 | dictionary var_name==>numpy_array 21 | ''' 22 | c = tf.train.load_checkpoint(init_checkpoint) 23 | return { 24 | name: c.get_tensor(name) 25 | for name in c.get_variable_to_shape_map().keys() 26 | } 27 | 28 | def make_state_dict(vvars, mapping, **fmt): 29 | '''Creates PyTorch *state dict* from TF variables and mapping info 30 | 31 | Mapping from TF to PyTorch is defined by a dictionary with keys 32 | being PyTorch parameter name, and values being a dictionary containing: 33 | * path - the TF variable name 34 | * transpose - optional True/False flag to indicate that TF value need 35 | to be transposed when copied to PyTorch 36 | 37 | Path may contain formatting templates, that will be expanded using (optional) 38 | "fmt" keyword parameters 39 | 40 | Returns PyTorch state dictionary 41 | ''' 42 | 43 | def make_tensor(item): 44 | var = vvars[item['path'].format(**fmt)] 45 | if item.get('transpose'): 46 | var = var.T 47 | return torch.FloatTensor(var) 48 | 49 | return { 50 | name: make_tensor(item) 51 | for name, item in mapping.items() 52 | } 53 | 54 | # Mapping spec for BERT embedder 55 | EMBED_SPEC = { 56 | 'token_embedding.weight' : { 57 | 'path': 'bert/embeddings/word_embeddings', 58 | }, 59 | 'segment_embedding.weight' : { 60 | 'path': 'bert/embeddings/token_type_embeddings', 61 | }, 62 | 'position_embedding' : { 63 | 'path': 'bert/embeddings/position_embeddings', 64 | }, 65 | 'layer_norm.weight' : { 66 | 'path': 'bert/embeddings/LayerNorm/gamma', 67 | }, 68 | 'layer_norm.bias' : { 69 | 'path': 'bert/embeddings/LayerNorm/beta', 70 | }, 71 | } 72 | 73 | # mapping spec for BERT encoder 74 | ENCODER_SPEC = { 75 | 'attention.query.weight' : { 76 | 'path': 'bert/encoder/layer_{L}/attention/self/query/kernel', 77 | 'transpose': True 78 | }, 79 | 'attention.query.bias' : { 80 | 'path': 'bert/encoder/layer_{L}/attention/self/query/bias' 81 | }, 82 | 'attention.key.weight' : { 83 | 'path': 'bert/encoder/layer_{L}/attention/self/key/kernel', 84 | 'transpose': True 85 | }, 86 | 'attention.key.bias' : { 87 | 'path': 'bert/encoder/layer_{L}/attention/self/key/bias' 88 | }, 89 | 'attention.value.weight' : { 90 | 'path': 'bert/encoder/layer_{L}/attention/self/value/kernel', 91 | 'transpose': True 92 | }, 93 | 'attention.value.bias' : { 94 | 'path': 'bert/encoder/layer_{L}/attention/self/value/bias' 95 | }, 96 | 'dense.weight' : { 97 | 'path': 'bert/encoder/layer_{L}/attention/output/dense/kernel', 98 | 'transpose': True 99 | }, 100 | 'dense.bias' : { 101 | 'path': 'bert/encoder/layer_{L}/attention/output/dense/bias' 102 | }, 103 | 'dense_layer_norm.weight' : { 104 | 'path': 'bert/encoder/layer_{L}/attention/output/LayerNorm/gamma' 105 | }, 106 | 'dense_layer_norm.bias' : { 107 | 'path': 'bert/encoder/layer_{L}/attention/output/LayerNorm/beta' 108 | }, 109 | 'intermediate.weight' : { 110 | 'path': 'bert/encoder/layer_{L}/intermediate/dense/kernel', 111 | 'transpose': True 112 | }, 113 | 'intermediate.bias' : { 114 | 'path': 'bert/encoder/layer_{L}/intermediate/dense/bias' 115 | }, 116 | 'output.weight' : { 117 | 'path': 'bert/encoder/layer_{L}/output/dense/kernel', 118 | 'transpose': True 119 | }, 120 | 'output.bias' : { 121 | 'path': 'bert/encoder/layer_{L}/output/dense/bias' 122 | }, 123 | 'output_layer_norm.weight' : { 124 | 'path': 'bert/encoder/layer_{L}/output/LayerNorm/gamma' 125 | }, 126 | 'output_layer_norm.bias' : { 127 | 'path': 'bert/encoder/layer_{L}/output/LayerNorm/beta' 128 | }, 129 | } 130 | 131 | 132 | POOLER_SPEC = { 133 | 'weight' : { 134 | 'path': 
'bert/pooler/dense/kernel', 135 | 'transpose': True 136 | }, 137 | 'bias' : { 138 | 'path': 'bert/pooler/dense/bias' 139 | }, 140 | } 141 | 142 | 143 | def make_bert_state_dict(vvars, num_hidden_layers=12): 144 | '''Creates tBERT *state dict* from TF BERT parameters''' 145 | 146 | state_dict = {} 147 | 148 | state_dict.update({ 149 | f'embedding.{name}': array 150 | for name, array in make_state_dict(vvars, EMBED_SPEC).items() 151 | }) 152 | 153 | for layer in range(num_hidden_layers): 154 | layer_state = make_state_dict(vvars, ENCODER_SPEC, L=layer) 155 | state_dict.update({ 156 | f'encoder.{layer}.{name}': array 157 | for name, array in layer_state.items() 158 | }) 159 | 160 | return state_dict 161 | 162 | 163 | def make_bert_pooler_state_dict(vvars, num_hidden_layers=12): 164 | '''Creates tBERT *state dict* from TF BERT parameters''' 165 | state_dict = {} 166 | 167 | state_dict.update({ 168 | f'bert.{name}': array 169 | for name, array in make_bert_state_dict(vvars, num_hidden_layers).items() 170 | }) 171 | 172 | state_dict.update({ 173 | f'pooler.{name}': array 174 | for name, array in make_state_dict(vvars, POOLER_SPEC).items() 175 | }) 176 | 177 | return state_dict 178 | 179 | 180 | class Tracer: 181 | '''Wraps tf.Session to provide convenience methods to set 182 | trainable parameters from numpy array, and to read trainable 183 | parameters as numpy arrays. 184 | ''' 185 | 186 | def __init__(self, sess): 187 | self.sess = sess 188 | self._tracer_init_ops = {} 189 | 190 | @property 191 | def graph(self): 192 | return self.sess.graph 193 | 194 | def run(self, *av, **kav): 195 | return self.sess.run(*av, **kav) 196 | 197 | def getmany(self, names): 198 | graph = self.graph 199 | vv = [graph.get_tensor_by_name(name+':0') for name in names] 200 | 201 | return dict(zip(names, self.run(vv))) 202 | 203 | def __getitem__(self, name): 204 | assert type(name_or_names) is str 205 | return self.getmany([name])[name] 206 | 207 | def update(self, params): 208 | feed_dict = {} 209 | to_run = [] 210 | graph = self.graph 211 | for name, array in params.items(): 212 | if name not in self._tracer_init_ops: 213 | with tf.name_scope('initTracer'): 214 | var = graph.get_tensor_by_name(name+':0') 215 | assert array.dtype == np.float32 216 | p = tf.placeholder(tf.float32, name=name) 217 | op = tf.assign(var, p) 218 | self._tracer_init_ops[name] = (op, p) 219 | op, p = self._tracer_init_ops[name] 220 | to_run.append(op) 221 | feed_dict[p] = array 222 | self.run(to_run, feed_dict=feed_dict) 223 | 224 | def __setitem__(self, name, array): 225 | self.update({name: array}) 226 | 227 | def trainable_variables(self): 228 | names = [v.name.rstrip(':0') for v in tf.trainable_variables()] 229 | return self.getmany(names) 230 | 231 | 232 | @contextlib.contextmanager 233 | def tracer_session(): 234 | with tf.Graph().as_default(), tf.Session() as _sess: 235 | tracer = Tracer(_sess) 236 | yield tracer 237 | 238 | 239 | def run_tf_bert_once(config, params, input_ids, input_type_ids=None, input_mask=None): 240 | '''Created TF BERT model from config and params, and runs it on the provided inputs 241 | 242 | If initialization params are not provided, then TF BERT is randomly initialized 243 | 244 | Retuns the array of activations of all encoder layers. 
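make_state_dict and the *_SPEC mappings above translate TF variable names into tBERT parameter names, transposing dense kernels because TF stores them as [in_features, out_features] while torch.nn.Linear.weight is [out_features, in_features]. A minimal conversion sketch (the checkpoint path is hypothetical, and importing tbert.tf_util requires tensorflow plus Google's BERT `modeling` module on PYTHONPATH; the tf_util listing continues below):

    from tbert.tf_util import read_tf_checkpoint, make_bert_state_dict

    # hypothetical path to a downloaded Google BERT-Base checkpoint
    vvars = read_tf_checkpoint('uncased_L-12_H-768_A-12/bert_model.ckpt')
    state_dict = make_bert_state_dict(vvars, num_hidden_layers=12)
    print(len(state_dict))   # 5 embedding tensors + 16 per encoder layer = 197 for 12 layers

run_tbert_once below builds Bert(config) and loads exactly this kind of state dict into it.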
245 |     '''
246 |     with tracer_session() as sess:
247 |         if input_type_ids is None:
248 |             input_type_ids = np.zeros_like(input_ids)
249 |         if input_mask is None:
250 |             input_mask = np.ones_like(input_ids)
251 | 
252 |         pinput_ids = tf.placeholder(dtype=tf.int32, shape=input_ids.shape, name='input_ids')
253 |         pinput_mask = tf.placeholder(dtype=tf.int32, shape=input_mask.shape, name='input_mask')
254 |         pinput_type_ids = tf.placeholder(dtype=tf.int32, shape=input_type_ids.shape, name='input_type_ids')
255 | 
256 |         model = modeling.BertModel(
257 |             modeling.BertConfig.from_dict(config),
258 |             is_training=False,
259 |             input_ids=pinput_ids,
260 |             input_mask=pinput_mask,
261 |             token_type_ids=pinput_type_ids,
262 |             use_one_hot_embeddings=False,
263 |         )
264 | 
265 |         sess.run(tf.global_variables_initializer())
266 |         sess.update(params)  # set all trainable params
267 | 
268 |         num_hidden_layers = config['num_hidden_layers']
269 |         to_eval = []
270 |         for reshape_id in range(2, 2 + num_hidden_layers):
271 |             to_eval.append(
272 |                 sess.graph.get_tensor_by_name(f'bert/encoder/Reshape_{reshape_id}:0')
273 |             )
274 |         pooler_output = sess.graph.get_tensor_by_name('bert/pooler/dense/Tanh:0')
275 | 
276 |         out, pout = sess.run((to_eval, pooler_output), feed_dict={
277 |             pinput_ids: input_ids,
278 |             pinput_type_ids: input_type_ids,
279 |             pinput_mask: input_mask
280 |         })
281 | 
282 |     return out, pout
283 | 
284 | def get_tf_bert_init_params(config):
285 |     '''Creates a TF BERT model from config, with random initialization,
286 |     and captures its trainable parameters.
287 | 
288 |     Returns a dictionary var_name ==> numpy_array with one entry per
289 |     trainable TF variable (see Tracer.trainable_variables).
290 |     '''
291 |     input_ids = np.zeros(dtype=np.int32, shape=(1, 20))
292 |     input_type_ids = np.zeros_like(input_ids)
293 |     input_mask = np.ones_like(input_ids)
294 | 
295 |     with tracer_session() as sess:
296 |         pinput_ids = tf.placeholder(dtype=tf.int32, shape=input_ids.shape, name='input_ids')
297 |         pinput_mask = tf.placeholder(dtype=tf.int32, shape=input_mask.shape, name='input_mask')
298 |         pinput_type_ids = tf.placeholder(dtype=tf.int32, shape=input_type_ids.shape, name='input_type_ids')
299 | 
300 |         model = modeling.BertModel(
301 |             modeling.BertConfig.from_dict(config),
302 |             is_training=False,
303 |             input_ids=pinput_ids,
304 |             input_mask=pinput_mask,
305 |             token_type_ids=pinput_type_ids,
306 |             use_one_hot_embeddings=False,
307 |         )
308 | 
309 |         sess.run(tf.global_variables_initializer())
310 | 
311 |     return sess.trainable_variables()
312 | 
313 | 
314 | def run_tbert_once(config, params, input_ids, input_type_ids, input_mask):
315 |     '''Runs tBERT model using TF parameters'''
316 | 
317 |     # init tBERT model
318 |     state_dict = make_bert_state_dict(params, config['num_hidden_layers'])
319 |     bert = Bert(config)
320 |     bert.load_state_dict(state_dict)
321 | 
322 |     # run tBERT on the same input
323 |     bert.eval()
324 |     with torch.no_grad():
325 |         out = bert(
326 |             torch.LongTensor(input_ids),
327 |             torch.LongTensor(input_type_ids),
328 |             torch.LongTensor(input_mask)
329 |         )
330 | 
331 |     return [v.data.numpy() for v in out]
332 | 
333 | 
334 | def run_tbert_pooler_once(config, params, input_ids, input_type_ids, input_mask):
335 |     '''Runs tBERT pooler model using TF parameters'''
336 | 
337 |     # init tBERT model
338 |     state_dict = make_bert_pooler_state_dict(params, config['num_hidden_layers'])
339 | 
340 |     classifier = BertPooler(config)
341 |     classifier.load_state_dict(state_dict)
342 | 
343 |     # run tBERT on the same input
344 |     classifier.eval()
345 |     with
torch.no_grad(): 346 | out = classifier( 347 | torch.LongTensor(input_ids), 348 | torch.LongTensor(input_type_ids), 349 | torch.LongTensor(input_mask) 350 | ) 351 | 352 | return out.data.numpy() 353 | -------------------------------------------------------------------------------- /tbert/transformer.py: -------------------------------------------------------------------------------- 1 | # The MIT License 2 | # Copyright 2019 Innodata Labs and Mike Kroutikov 3 | # 4 | # PyTorch port of 5 | # https://github.com/google-research/bert/modeling.py 6 | # 7 | # Original code copyright follows: 8 | # 9 | # Copyright 2018 The Google AI Language Team Authors. 10 | # 11 | # Licensed under the Apache License, Version 2.0 (the "License"); 12 | # you may not use this file except in compliance with the License. 13 | # You may obtain a copy of the License at 14 | # 15 | # http://www.apache.org/licenses/LICENSE-2.0 16 | # 17 | # Unless required by applicable law or agreed to in writing, software 18 | # distributed under the License is distributed on an "AS IS" BASIS, 19 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | # See the License for the specific language governing permissions and 21 | # limitations under the License.import json 22 | # 23 | import torch 24 | from tbert.gelu import gelu 25 | from tbert.attention import Attention, init_linear 26 | 27 | 28 | class TransformerEncoder(torch.nn.Module): 29 | 30 | def __init__(self, 31 | hidden_size=768, 32 | num_heads=12, 33 | intermediate_size=3072, 34 | dropout=0.1, 35 | initializer_range=0.02): 36 | ''' 37 | hidden_size - hidden size, must be multiple of num_heads 38 | num_heads - number of attention heads. 39 | intermediate_size - size of the intermediate dense layer 40 | dropout - dropout probability (0. 
means "no dropout") 41 | initializer_range - stddev for random weight matrix initialization 42 | ''' 43 | torch.nn.Module.__init__(self) 44 | 45 | if hidden_size % num_heads: 46 | raise ValueError( 47 | 'hidden size must be a multiple of the number of attention heads' 48 | ) 49 | 50 | self.attention = Attention( 51 | hidden_size, 52 | hidden_size, 53 | num_heads, 54 | hidden_size // num_heads, 55 | dropout=dropout, 56 | initializer_range=initializer_range 57 | ) 58 | 59 | self.dense = torch.nn.Linear(hidden_size, hidden_size) 60 | self.dropout = torch.nn.Dropout(dropout) 61 | self.dense_layer_norm = torch.nn.LayerNorm(hidden_size, eps=1.e-12) 62 | self.intermediate = torch.nn.Linear(hidden_size, intermediate_size) 63 | self.output = torch.nn.Linear(intermediate_size, hidden_size) 64 | self.output_layer_norm = torch.nn.LayerNorm(hidden_size, eps=1.e-12) 65 | 66 | init_linear(self.dense, initializer_range) 67 | init_linear(self.intermediate, initializer_range) 68 | init_linear(self.output, initializer_range) 69 | 70 | def forward(self, inp, att_mask=None, batch_size=1): 71 | ''' 72 | B - batch size 73 | S - sequence length 74 | H - hidden size 75 | 76 | inp - a float matrix with embedded input sequences, shape [B*S, H] 77 | att_mask - an int tensor of shape [B, 1, S, S] - the self-attention mask 78 | batch_size - batch size 79 | 80 | Returns: a matrix of the same dims as inp (so that encoders are 81 | stackable) 82 | ''' 83 | # --> [B*S, H] 84 | x = self.attention(inp, inp, inp, att_mask, batch_size=batch_size) 85 | # --> [B*S, H] 86 | x = self.dense(x) 87 | x = self.dropout(x) 88 | x = self.dense_layer_norm(inp + x) 89 | x2 = self.output(gelu(self.intermediate(x))) 90 | x = self.output_layer_norm(x + x2) 91 | 92 | return x 93 | 94 | 95 | class TransformerDecoder(torch.nn.Module): 96 | 97 | def __init__(self, 98 | hidden_size=768, 99 | num_heads=12, 100 | intermediate_size=3072, 101 | dropout=0.1, 102 | initializer_range=0.02): 103 | ''' 104 | hidden_size - hidden size, must be multiple of num_heads 105 | num_heads - number of attention heads. 106 | intermediate_size - size of the intermediate dense layer 107 | dropout - dropout probability (0. 
means "no dropout") 108 | ''' 109 | torch.nn.Module.__init__(self) 110 | 111 | if hidden_size % num_heads: 112 | raise ValueError( 113 | 'hidden size must be a multiple of the number of attention heads' 114 | ) 115 | 116 | self.attention = Attention( 117 | hidden_size, 118 | hidden_size, 119 | num_heads, 120 | hidden_size // num_heads, 121 | dropout=dropout, 122 | initializer_range=initializer_range 123 | ) 124 | 125 | self.encoder_attention = Attention( 126 | hidden_size, 127 | hidden_size, 128 | num_heads, 129 | hidden_size // num_heads, 130 | dropout=dropout, 131 | initializer_range=initializer_range 132 | ) 133 | 134 | self.dense = torch.nn.Linear(hidden_size, hidden_size) 135 | self.dropout = torch.nn.Dropout(dropout) 136 | self.dense_layer_norm = torch.nn.LayerNorm(hidden_size, eps=1.e-12) 137 | self.intermediate = torch.nn.Linear(hidden_size, intermediate_size) 138 | self.output = torch.nn.Linear(intermediate_size, hidden_size) 139 | self.output_layer_norm = torch.nn.LayerNorm(hidden_size, eps=1.e-12) 140 | 141 | init_linear(self.dense, initializer_range) 142 | init_linear(self.intermediate, initializer_range) 143 | init_linear(self.output, initializer_range) 144 | 145 | def forward(self, inp, enc_inp, att_mask=None, enc_att_mask=None, batch_size=1): 146 | ''' 147 | B - batch size 148 | S - sequence length 149 | E - encoder sequence length 150 | H - hidden size 151 | 152 | inp - a float matrix with embedded input sequences, shape [B*S, H] 153 | enc_inp - a float matrix with embedded activations from encoder layer, shape [B*E, H] 154 | att_mask - an int tensor of shape [B, 1, S, S] - the self-attention mask 155 | enc_att_mask - an int tensor of shape [B, 1, E, S] - the attention mask from encoder data 156 | batch_size - batch size 157 | 158 | Returns: a matrix of the same dims as inp (so that decoders are 159 | stackable) 160 | ''' 161 | # --> [B*S, H] 162 | x = self.attention(inp, inp, inp, att_mask, batch_size=batch_size) 163 | 164 | # apply attention on encoder 165 | x = self.encoder_attention(enc_inp, x, x, enc_att_mask, batch_size=batch_size) 166 | 167 | # --> [B*S, H] 168 | x = self.dense(x) 169 | x = self.dropout(x) 170 | x = self.dense_layer_norm(inp + x) 171 | x2 = self.output(gelu(self.intermediate(x))) 172 | x = self.output_layer_norm(x + x2) 173 | 174 | return x 175 | --------------------------------------------------------------------------------
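The encoder and decoder blocks above operate on inputs flattened to [B*S, H] and return the same shape, so layers stack. A quick smoke test of TransformerEncoder with a toy configuration, assuming, as the forward() default suggests, that att_mask=None disables masking in the underlying Attention module (tbert/attention.py, not reproduced here):

    import torch
    from tbert.transformer import TransformerEncoder

    enc = TransformerEncoder(hidden_size=16, num_heads=2, intermediate_size=32, dropout=0.1)
    enc.eval()                            # turn off dropout for a deterministic pass

    B, S, H = 3, 7, 16
    x = torch.randn(B * S, H)             # embedded tokens, flattened as the docstring requires
    with torch.no_grad():
        y = enc(x, att_mask=None, batch_size=B)
    print(y.shape)                         # torch.Size([21, 16]) - same as the input, so encoders stack

TransformerDecoder follows the same flattened convention but additionally attends over an encoder activation matrix enc_inp of shape [B*E, H].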