├── .gitignore ├── .travis.yml ├── LICENSE ├── MANIFEST.in ├── README.rst ├── bert ├── __init__.py ├── attention.py ├── embeddings.py ├── layer.py ├── loader.py ├── loader_albert.py ├── model.py ├── tokenization │ ├── __init__.py │ ├── albert_tokenization.py │ └── bert_tokenization.py ├── transformer.py └── version.py ├── check-before-commit.sh ├── examples ├── gpu_movie_reviews.ipynb └── tpu_movie_reviews.ipynb ├── requirements-dev.txt ├── requirements.txt ├── setup.py └── tests ├── __init__.py ├── ext ├── __init__.py └── modeling.py ├── nonci ├── __init__.py ├── test_attention.py ├── test_bert.py ├── test_compare_pretrained.py ├── test_load_pretrained_weights.py ├── test_multi_lang.py ├── test_stock_weights.py └── test_transformer.py ├── test_adapter_finetune.py ├── test_adapter_freeze.py ├── test_albert_create.py ├── test_attention.py ├── test_common.py ├── test_compare_activations.py ├── test_eager.py ├── test_extend_segments.py └── test_extend_tokens.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .python-version 3 | .venv 4 | **/__pycache__/ 5 | *.egg-info/ 6 | *.pyc 7 | *.coverage 8 | build/ 9 | dist/ 10 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: false 2 | language: python 3 | python: 4 | - "3.6" 5 | dist: trusty 6 | 7 | # Enable 3.7 without globally enabling sudo and dist: xenial for other build jobs 8 | matrix: 9 | include: 10 | - python: 3.7 11 | dist: xenial 12 | sudo: true 13 | 14 | env: 15 | - PEP8_IGNORE="E221,E501,W504,W391,E241" 16 | 17 | # command to install dependencies 18 | install: 19 | - pip install --upgrade pip setuptools 20 | - pip install tensorflow 21 | - pip install -r requirements.txt 22 | - pip install -r requirements-dev.txt 23 | 24 | # command to run tests 25 | # require 100% coverage (not including test files) to pass Travis CI test 26 | # To skip pypy: - if [[ $TRAVIS_PYTHON_VERSION != 'pypy' ]]; then DOSTUFF ; fi 27 | script: 28 | - export MAJOR_PYTHON_VERSION=`echo $TRAVIS_PYTHON_VERSION | cut -c 1` 29 | - coverage run --source=bert 30 | $(which nosetests) -v 31 | --with-doctest tests/ 32 | --exclude-dir tests/nonci/ 33 | - if [[ $TRAVIS_PYTHON_VERSION == 3.7 ]]; then coverage report --show-missing --fail-under=60 --omit "bert/tokenization/*_tokenization.py" ; fi 34 | - if [[ $TRAVIS_PYTHON_VERSION == 3.7 ]]; then pycodestyle --ignore=$PEP8_IGNORE --exclude=tests,.venv -r --show-source . 
; fi 35 | # For convenience, make sure simple test commands work 36 | - python setup.py develop 37 | # - py.test 38 | # - nosetests -e tests.nonci.* 39 | 40 | # load coverage status to https://coveralls.io 41 | after_success: 42 | - 'echo travis-python: $TRAVIS_PYTHON_VERSION' 43 | - if [[ $TRAVIS_PYTHON_VERSION == 3.7 ]]; then pip install coveralls; COVERALLS_REPO_TOKEN=$COVERALLS_REPO_TOKEN coveralls ; fi 44 | 45 | notifications: 46 | email: false 47 | 48 | deploy: 49 | provider: pypi 50 | username: "__token__" 51 | password: 52 | secure: vk0//gdQP/beFjOZ2S9nBOgCwTFGROPKvzjZUfZCoE1eeI6w+PKY37dqYhglglum0Hdz8MJUqpzNkmLroPbS/7J7g+EnSMXV/aPmteCPv5xwwBbUPTlaJN3Amx3MNERiOXjliDUeDXyBFZOO55pF5Ytqv2PKH6StAaNAxtOETb23tEA7xlC2wX1/lhKA31jGqTwmEQKQMc77mvXjTugvjjwrWe0C7ijOjzGu8TypR52VEFdM+m1/KFlU1sajD6BwWcYk15uLbrJS87bN37RtQWnzJo39sQvckj397TLlSaroALp9cLONZBriEZOgEYK3k4KxdbzhiT9HYKGLO050LmGzH1bXOeZ0FTrqejcUiOOjCkBQGSlOyiOaZcWXXrMgwlgyQY+Nsl40hzUhNfkJIahoUsg/LlTjkPd5GWCOaNrbfdU2Q2hHAPRWJh03mg1uE5nXZbqeMEf3wjgfjqGQHQP6aWt/tDjOM7flnjvJOaKZRQxscTwGzMCDH1aZEsDny+d8TmnENb7/pcGHL811HlQe+wzMAND9si+BrbPet3vIi1HuJKfwT+DJWDi/Jwur+vpgSxBgeQDlobPAm06hRxpqhTZS0xda48PVII07wDJl+L5Q3r8fM6eV9akj3Xx+zOdUiqO74pFQDF+UhCiEXroR8CMilJzTt2WJIFPtu8E= 53 | on: 54 | tags: true 55 | branch: master 56 | condition: "$TRAVIS_PYTHON_VERSION = 3.7" 57 | skip_existing: true 58 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 kpe 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements-dev.txt 2 | include requirements.txt 3 | include README.rst 4 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | BERT for TensorFlow v2 2 | ====================== 3 | 4 | |Build Status| |Coverage Status| |Version Status| |Python Versions| |Downloads| 5 | 6 | This repo contains a `TensorFlow 2.0`_ `Keras`_ implementation of `google-research/bert`_ 7 | with support for loading of the original `pre-trained weights`_, 8 | and producing activations **numerically identical** to the one calculated by the original model. 
9 | 10 | `ALBERT`_ and `adapter-BERT`_ are also supported by setting the corresponding 11 | configuration parameters (``shared_layer=True``, ``embedding_size`` for `ALBERT`_ 12 | and ``adapter_size`` for `adapter-BERT`_). Setting both will result in an adapter-ALBERT 13 | by sharing the BERT parameters across all layers while adapting every layer with a layer-specific adapter. 14 | 15 | The implementation is built from scratch using only basic TensorFlow operations, 16 | following the code in `google-research/bert/modeling.py`_ 17 | (but skipping dead code and applying some simplifications). It also utilizes `kpe/params-flow`_ to reduce 18 | common Keras boilerplate code (related to passing model and layer configuration arguments). 19 | 20 | `bert-for-tf2`_ should work with both `TensorFlow 2.0`_ and `TensorFlow 1.14`_ or newer. 21 | 22 | NEWS 23 | ---- 24 | - **30.Jul.2020** - ``VERBOSE=0`` env variable for suppressing stdout output. 25 | - **06.Apr.2020** - using latest ``py-params`` introducing ``WithParams`` base for ``Layer`` 26 | and ``Model``. See news in `kpe/py-params`_ for how to update (the ``_construct()`` signature has changed and 27 | requires calling ``super()._construct()``). 28 | - **06.Jan.2020** - support for loading the tar format weights from `google-research/ALBERT`_. 29 | - **18.Nov.2019** - ALBERT tokenization added (make sure to import as ``from bert import albert_tokenization`` or ``from bert import bert_tokenization``). 30 | 31 | - **08.Nov.2019** - using v2 by default when loading the `TFHub/albert`_ weights of `google-research/ALBERT`_. 32 | 33 | - **05.Nov.2019** - minor ALBERT word embeddings refactoring (``word_embeddings_2`` -> ``word_embeddings_projector``) and related parameter freezing fixes. 34 | 35 | - **04.Nov.2019** - support for extra (task specific) token embeddings using negative token ids. 36 | 37 | - **29.Oct.2019** - support for loading of the pre-trained ALBERT weights released by `google-research/ALBERT`_ at `TFHub/albert`_. 38 | 39 | - **11.Oct.2019** - support for loading of the pre-trained ALBERT weights released by `brightmart/albert_zh ALBERT for Chinese`_. 40 | 41 | - **10.Oct.2019** - support for `ALBERT`_ through the ``shared_layer=True`` 42 | and ``embedding_size=128`` params. 43 | 44 | - **03.Sep.2019** - walkthrough on fine-tuning with adapter-BERT and storing the 45 | fine-tuned fraction of the weights in a separate checkpoint (see ``tests/test_adapter_finetune.py``). 46 | 47 | - **02.Sep.2019** - support for extending the token type embeddings of a pre-trained model 48 | by returning the mismatched weights in ``load_stock_weights()`` (see ``tests/test_extend_segments.py``). 49 | 50 | - **25.Jul.2019** - there are now two colab notebooks under ``examples/`` showing how to 51 | fine-tune an IMDB Movie Reviews sentiment classifier from pre-trained BERT weights 52 | using an `adapter-BERT`_ model architecture on a GPU or TPU in Google Colab. 53 | 54 | - **28.Jun.2019** - v.0.3.0 supports `adapter-BERT`_ (`google-research/adapter-bert`_) 55 | for "Parameter-Efficient Transfer Learning for NLP", i.e. fine-tuning small overlay adapter 56 | layers over BERT's transformer encoders without changing the frozen BERT weights. 57 | 58 | 59 | 60 | LICENSE 61 | ------- 62 | 63 | MIT. See `License File `_. 64 | 65 | Install 66 | ------- 67 | 68 | ``bert-for-tf2`` is on the Python Package Index (PyPI): 69 | 70 | :: 71 | 72 | pip install bert-for-tf2 73 | 74 | 75 | Usage 76 | ----- 77 | 78 | BERT in `bert-for-tf2` is implemented as a Keras layer.
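For example (a sketch only, with illustrative sizes not taken from any released checkpoint), the ``shared_layer``, ``embedding_size`` and ``adapter_size`` options described above can be combined into an adapter-ALBERT configuration; the plain BERT configuration follows in the next example:

.. code:: python

    from bert import BertModelLayer

    # adapter-ALBERT sketch: one shared transformer layer (ALBERT), factorized
    # word embeddings (ALBERT) and small bottleneck adapters in every layer (adapter-BERT)
    l_bert = BertModelLayer(**BertModelLayer.Params(
        vocab_size              = 30000,    # illustrative embedding params
        use_token_type          = True,
        token_type_vocab_size   = 2,

        num_layers              = 12,       # illustrative transformer encoder params
        hidden_size             = 768,
        intermediate_size       = 4*768,
        intermediate_activation = "gelu",

        shared_layer            = True,     # ALBERT: share the parameters across all layers
        embedding_size          = 128,      # ALBERT: wordpiece embedding size
        adapter_size            = 64,       # adapter-BERT: adapter bottleneck size

        name                    = "albert"
    ))

With ``adapter_size`` set, you would typically also call ``l_bert.apply_adapter_freeze()`` once the model is built, as described further below.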
You could instantiate it like this: 79 | 80 | .. code:: python 81 | 82 | from bert import BertModelLayer 83 | 84 | l_bert = BertModelLayer(**BertModelLayer.Params( 85 | vocab_size = 16000, # embedding params 86 | use_token_type = True, 87 | use_position_embeddings = True, 88 | token_type_vocab_size = 2, 89 | 90 | num_layers = 12, # transformer encoder params 91 | hidden_size = 768, 92 | hidden_dropout = 0.1, 93 | intermediate_size = 4*768, 94 | intermediate_activation = "gelu", 95 | 96 | adapter_size = None, # see arXiv:1902.00751 (adapter-BERT) 97 | 98 | shared_layer = False, # True for ALBERT (arXiv:1909.11942) 99 | embedding_size = None, # None for BERT, wordpiece embedding size for ALBERT 100 | 101 | name = "bert" # any other Keras layer params 102 | )) 103 | 104 | or by using the ``bert_config.json`` from a `pre-trained google model`_: 105 | 106 | .. code:: python 107 | 108 | import bert 109 | 110 | model_dir = ".models/uncased_L-12_H-768_A-12" 111 | 112 | bert_params = bert.params_from_pretrained_ckpt(model_dir) 113 | l_bert = bert.BertModelLayer.from_params(bert_params, name="bert") 114 | 115 | 116 | now you can use the BERT layer in your Keras model like this: 117 | 118 | .. code:: python 119 | 120 | from tensorflow import keras 121 | 122 | max_seq_len = 128 123 | l_input_ids = keras.layers.Input(shape=(max_seq_len,), dtype='int32') 124 | l_token_type_ids = keras.layers.Input(shape=(max_seq_len,), dtype='int32') 125 | 126 | # using the default token_type/segment id 0 127 | output = l_bert(l_input_ids) # output: [batch_size, max_seq_len, hidden_size] 128 | model = keras.Model(inputs=l_input_ids, outputs=output) 129 | model.build(input_shape=(None, max_seq_len)) 130 | 131 | # provide a custom token_type/segment id as a layer input 132 | output = l_bert([l_input_ids, l_token_type_ids]) # [batch_size, max_seq_len, hidden_size] 133 | model = keras.Model(inputs=[l_input_ids, l_token_type_ids], outputs=output) 134 | model.build(input_shape=[(None, max_seq_len), (None, max_seq_len)]) 135 | 136 | if you choose to use `adapter-BERT`_ by setting the `adapter_size` parameter, 137 | you will most likely also want to freeze all the original BERT layers by calling: 138 | 139 | .. code:: python 140 | 141 | l_bert.apply_adapter_freeze() 142 | 143 | and once the model has been built or compiled, the original pre-trained weights 144 | can be loaded into the BERT layer: 145 | 146 | .. code:: python 147 | 148 | import bert 149 | 150 | bert_ckpt_file = os.path.join(model_dir, "bert_model.ckpt") 151 | bert.load_stock_weights(l_bert, bert_ckpt_file) 152 | 153 | **N.B.** see `tests/test_bert_activations.py`_ for a complete example. 154 | 155 | FAQ 156 | --- 157 | 0. In all the examples below, **please note** the line: 158 | 159 | .. code:: python 160 | 161 | # use in a Keras Model here, and call model.build() 162 | 163 | for a quick test, you can replace it with something like: 164 | 165 | .. code:: python 166 | 167 | model = keras.models.Sequential([ 168 | keras.layers.InputLayer(input_shape=(128,)), 169 | l_bert, 170 | keras.layers.Lambda(lambda x: x[:, 0, :]), 171 | keras.layers.Dense(2) 172 | ]) 173 | model.build(input_shape=(None, 128)) 174 | 175 | 
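Once such a quick-test model is built (and, ideally, the pre-trained weights have been loaded as shown in the items below), a dummy forward pass is an easy sanity check. This is only a sketch - the token ids are random, and the upper bound of 16000 simply matches the illustrative ``vocab_size`` from the usage example above:

.. code:: python

    import numpy as np

    # a dummy batch of 2 sequences with 128 token ids each (ids must be < vocab_size)
    fake_token_ids = np.random.randint(0, 16000, size=(2, 128))
    logits = model.predict(fake_token_ids)   # -> shape (2, 2): one row of class logits per sequence
    print(logits.shape)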
176 | 1. How to use BERT with the `google-research/bert`_ pre-trained weights? 177 | 178 | .. code:: python 179 | 180 | model_name = "uncased_L-12_H-768_A-12" 181 | model_dir = bert.fetch_google_bert_model(model_name, ".models") 182 | model_ckpt = os.path.join(model_dir, "bert_model.ckpt") 183 | 184 | bert_params = bert.params_from_pretrained_ckpt(model_dir) 185 | l_bert = bert.BertModelLayer.from_params(bert_params, name="bert") 186 | 187 | # use in a Keras Model here, and call model.build() 188 | 189 | bert.load_bert_weights(l_bert, model_ckpt) # should be called after model.build() 190 | 191 | 2. How to use ALBERT with the `google-research/ALBERT`_ pre-trained weights (fetching from TFHub)? 192 | 193 | see `tests/nonci/test_load_pretrained_weights.py `_: 194 | 195 | .. code:: python 196 | 197 | model_name = "albert_base" 198 | model_dir = bert.fetch_tfhub_albert_model(model_name, ".models") 199 | model_params = bert.albert_params(model_name) 200 | l_bert = bert.BertModelLayer.from_params(model_params, name="albert") 201 | 202 | # use in a Keras Model here, and call model.build() 203 | 204 | bert.load_albert_weights(l_bert, model_dir) # should be called after model.build() 205 | 206 | 3. How to use ALBERT with the `google-research/ALBERT`_ pre-trained weights (non-TFHub)? 207 | 208 | see `tests/nonci/test_load_pretrained_weights.py `_: 209 | 210 | .. code:: python 211 | 212 | model_name = "albert_base_v2" 213 | model_dir = bert.fetch_google_albert_model(model_name, ".models") 214 | model_ckpt = os.path.join(model_dir, "model.ckpt-best") 215 | 216 | model_params = bert.albert_params(model_dir) 217 | l_bert = bert.BertModelLayer.from_params(model_params, name="albert") 218 | 219 | # use in a Keras Model here, and call model.build() 220 | 221 | bert.load_albert_weights(l_bert, model_ckpt) # should be called after model.build() 222 | 223 | 4. How to use ALBERT with the `brightmart/albert_zh`_ pre-trained weights? 224 | 225 | see `tests/nonci/test_albert.py `_: 226 | 227 | .. code:: python 228 | 229 | model_name = "albert_base" 230 | model_dir = bert.fetch_brightmart_albert_model(model_name, ".models") 231 | model_ckpt = os.path.join(model_dir, "albert_model.ckpt") 232 | 233 | bert_params = bert.params_from_pretrained_ckpt(model_dir) 234 | l_bert = bert.BertModelLayer.from_params(bert_params, name="bert") 235 | 236 | # use in a Keras Model here, and call model.build() 237 | 238 | bert.load_albert_weights(l_bert, model_ckpt) # should be called after model.build() 239 | 240 | 5. How to tokenize the input for the `google-research/bert`_ models? 241 | 242 | .. code:: python 243 | 244 | do_lower_case = not (model_name.find("cased") == 0 or model_name.find("multi_cased") == 0) 245 | bert.bert_tokenization.validate_case_matches_checkpoint(do_lower_case, model_ckpt) 246 | vocab_file = os.path.join(model_dir, "vocab.txt") 247 | tokenizer = bert.bert_tokenization.FullTokenizer(vocab_file, do_lower_case) 248 | tokens = tokenizer.tokenize("Hello, BERT-World!") 249 | token_ids = tokenizer.convert_tokens_to_ids(tokens) 250 | 251 | 6. How to tokenize the input for `brightmart/albert_zh`_? 252 | 253 | .. code:: python 254 | 255 | import params_flow as pf 256 | 257 | # fetch the vocab file 258 | albert_zh_vocab_url = "https://raw.githubusercontent.com/brightmart/albert_zh/master/albert_config/vocab.txt" 259 | vocab_file = pf.utils.fetch_url(albert_zh_vocab_url, model_dir) 260 | 261 | tokenizer = bert.albert_tokenization.FullTokenizer(vocab_file) 262 | tokens = tokenizer.tokenize("你好世界") 263 | token_ids = tokenizer.convert_tokens_to_ids(tokens) 264 | 265 | 7. 
How to tokenize the input for the `google-research/ALBERT`_ models? 266 | 267 | .. code:: python 268 | 269 | import sentencepiece as spm 270 | 271 | spm_model = os.path.join(model_dir, "assets", "30k-clean.model") 272 | sp = spm.SentencePieceProcessor() 273 | sp.load(spm_model) 274 | do_lower_case = True 275 | 276 | processed_text = bert.albert_tokenization.preprocess_text("Hello, World!", lower=do_lower_case) 277 | token_ids = bert.albert_tokenization.encode_ids(sp, processed_text) 278 | 279 | 8. How to tokenize the input for the Chinese `google-research/ALBERT`_ models? 280 | 281 | .. code:: python 282 | 283 | import bert 284 | 285 | vocab_file = os.path.join(model_dir, "vocab.txt") 286 | tokenizer = bert.albert_tokenization.FullTokenizer(vocab_file=vocab_file) 287 | tokens = tokenizer.tokenize(u"你好世界") 288 | token_ids = tokenizer.convert_tokens_to_ids(tokens) 289 | 290 | Resources 291 | --------- 292 | 293 | - `BERT`_ - BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding 294 | - `adapter-BERT`_ - adapter-BERT: Parameter-Efficient Transfer Learning for NLP 295 | - `ALBERT`_ - ALBERT: A Lite BERT for Self-Supervised Learning of Language Representations 296 | - `google-research/bert`_ - the original `BERT`_ implementation 297 | - `google-research/ALBERT`_ - the original `ALBERT`_ implementation by Google 298 | - `google-research/albert(old)`_ - the old location of the original `ALBERT`_ implementation by Google 299 | - `brightmart/albert_zh`_ - pre-trained `ALBERT`_ weights for Chinese 300 | - `kpe/params-flow`_ - A Keras coding style for reducing `Keras`_ boilerplate code in custom layers by utilizing `kpe/py-params`_ 301 | 302 | .. _`kpe/params-flow`: https://github.com/kpe/params-flow 303 | .. _`kpe/py-params`: https://github.com/kpe/py-params 304 | .. _`bert-for-tf2`: https://github.com/kpe/bert-for-tf2 305 | 306 | .. _`Keras`: https://keras.io 307 | .. _`pre-trained weights`: https://github.com/google-research/bert#pre-trained-models 308 | .. _`google-research/bert`: https://github.com/google-research/bert 309 | .. _`google-research/bert/modeling.py`: https://github.com/google-research/bert/blob/master/modeling.py 310 | .. _`BERT`: https://arxiv.org/abs/1810.04805 311 | .. _`pre-trained google model`: https://github.com/google-research/bert 312 | .. _`tests/test_bert_activations.py`: https://github.com/kpe/bert-for-tf2/blob/master/tests/test_compare_activations.py 313 | .. _`TensorFlow 2.0`: https://www.tensorflow.org/versions/r2.0/api_docs/python/tf 314 | .. _`TensorFlow 1.14`: https://www.tensorflow.org/versions/r1.14/api_docs/python/tf 315 | 316 | .. _`google-research/adapter-bert`: https://github.com/google-research/adapter-bert/ 317 | .. _`adapter-BERT`: https://arxiv.org/abs/1902.00751 318 | .. _`ALBERT`: https://arxiv.org/abs/1909.11942 319 | .. _`brightmart/albert_zh ALBERT for Chinese`: https://github.com/brightmart/albert_zh 320 | .. _`brightmart/albert_zh`: https://github.com/brightmart/albert_zh 321 | .. _`google ALBERT weights`: https://github.com/google-research/google-research/tree/master/albert 322 | .. _`google-research/albert(old)`: https://github.com/google-research/google-research/tree/master/albert 323 | .. _`google-research/ALBERT`: https://github.com/google-research/ALBERT 324 | .. _`TFHub/albert`: https://tfhub.dev/google/albert_base/2 325 | 326 | .. |Build Status| image:: https://travis-ci.com/kpe/bert-for-tf2.svg?branch=master 327 | :target: https://travis-ci.com/kpe/bert-for-tf2 328 | .. 
|Coverage Status| image:: https://coveralls.io/repos/kpe/bert-for-tf2/badge.svg?branch=master 329 | :target: https://coveralls.io/r/kpe/bert-for-tf2?branch=master 330 | .. |Version Status| image:: https://badge.fury.io/py/bert-for-tf2.svg 331 | :target: https://badge.fury.io/py/bert-for-tf2 332 | .. |Python Versions| image:: https://img.shields.io/pypi/pyversions/bert-for-tf2.svg 333 | .. |Downloads| image:: https://img.shields.io/pypi/dm/bert-for-tf2.svg 334 | .. |Twitter| image:: https://img.shields.io/twitter/follow/siddhadev?logo=twitter&label=&style= 335 | :target: https://twitter.com/intent/user?screen_name=siddhadev 336 | -------------------------------------------------------------------------------- /bert/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 15.Mar.2019 at 15:28 4 | # 5 | from __future__ import division, absolute_import, print_function 6 | 7 | from .version import __version__ 8 | 9 | from .attention import AttentionLayer 10 | from .layer import Layer 11 | from .model import BertModelLayer 12 | 13 | from .tokenization import bert_tokenization 14 | from .tokenization import albert_tokenization 15 | 16 | from .loader import StockBertConfig, load_stock_weights, params_from_pretrained_ckpt 17 | from .loader import load_stock_weights as load_bert_weights 18 | from .loader import bert_models_google, fetch_google_bert_model 19 | from .loader_albert import load_albert_weights, albert_params 20 | from .loader_albert import albert_models_tfhub, albert_models_brightmart 21 | from .loader_albert import fetch_tfhub_albert_model, fetch_brightmart_albert_model, fetch_google_albert_model 22 | -------------------------------------------------------------------------------- /bert/attention.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 15.Mar.2019 at 12:52 4 | # 5 | 6 | from __future__ import absolute_import, division, print_function 7 | 8 | import tensorflow as tf 9 | from tensorflow.python import keras 10 | from tensorflow.python.keras import backend as K 11 | 12 | from bert.layer import Layer 13 | 14 | 15 | class AttentionLayer(Layer): 16 | class Params(Layer.Params): 17 | num_heads = None 18 | size_per_head = None 19 | initializer_range = 0.02 20 | query_activation = None 21 | key_activation = None 22 | value_activation = None 23 | attention_dropout = 0.1 24 | negative_infinity = -10000.0 # used for attention scores before softmax 25 | 26 | @staticmethod 27 | def create_attention_mask(from_shape, input_mask): 28 | """ 29 | Creates 3D attention. 30 | :param from_shape: [batch_size, from_seq_len, ...] 
31 | :param input_mask: [batch_size, seq_len] 32 | :return: [batch_size, from_seq_len, seq_len] 33 | """ 34 | 35 | mask = tf.cast(tf.expand_dims(input_mask, axis=1), tf.float32) # [B, 1, T] 36 | ones = tf.expand_dims(tf.ones(shape=from_shape[:2], dtype=tf.float32), axis=-1) # [B, F, 1] 37 | mask = ones * mask # broadcast along two dimensions 38 | 39 | return mask # [B, F, T] 40 | 41 | def _construct(self, **kwargs): 42 | super()._construct(**kwargs) 43 | self.query_activation = self.params.query_activation 44 | self.key_activation = self.params.key_activation 45 | self.value_activation = self.params.value_activation 46 | 47 | self.query_layer = None 48 | self.key_layer = None 49 | self.value_layer = None 50 | 51 | self.supports_masking = True 52 | 53 | # noinspection PyAttributeOutsideInit 54 | def build(self, input_shape): 55 | self.input_spec = keras.layers.InputSpec(shape=input_shape) 56 | 57 | dense_units = self.params.num_heads * self.params.size_per_head # N*H 58 | # 59 | # B, F, T, N, H - batch, from_seq_len, to_seq_len, num_heads, size_per_head 60 | # 61 | self.query_layer = keras.layers.Dense(units=dense_units, activation=self.query_activation, 62 | kernel_initializer=self.create_initializer(), 63 | name="query") 64 | self.key_layer = keras.layers.Dense(units=dense_units, activation=self.key_activation, 65 | kernel_initializer=self.create_initializer(), 66 | name="key") 67 | self.value_layer = keras.layers.Dense(units=dense_units, activation=self.value_activation, 68 | kernel_initializer=self.create_initializer(), 69 | name="value") 70 | self.dropout_layer = keras.layers.Dropout(self.params.attention_dropout) 71 | 72 | super(AttentionLayer, self).build(input_shape) 73 | 74 | def compute_output_shape(self, input_shape): 75 | from_shape = input_shape 76 | 77 | # from_shape # [B, F, W] [batch_size, from_seq_length, from_width] 78 | # input_mask_shape # [B, F] 79 | 80 | output_shape = [from_shape[0], from_shape[1], self.params.num_heads * self.params.size_per_head] 81 | 82 | return output_shape # [B, F, N*H] 83 | 84 | # noinspection PyUnusedLocal 85 | def call(self, inputs, mask=None, training=None, **kwargs): 86 | from_tensor = inputs 87 | to_tensor = inputs 88 | if mask is None: 89 | sh = self.get_shape_list(from_tensor) 90 | mask = tf.ones(sh[:2], dtype=tf.int32) 91 | attention_mask = AttentionLayer.create_attention_mask(tf.shape(input=from_tensor), mask) 92 | 93 | # from_tensor shape - [batch_size, from_seq_length, from_width] 94 | input_shape = tf.shape(input=from_tensor) 95 | batch_size, from_seq_len, from_width = input_shape[0], input_shape[1], input_shape[2] 96 | to_seq_len = from_seq_len 97 | 98 | # [B, F, N*H] -> [B, N, F, H] 99 | def transpose_for_scores(input_tensor, seq_len): 100 | output_shape = [batch_size, seq_len, 101 | self.params.num_heads, self.params.size_per_head] 102 | output_tensor = K.reshape(input_tensor, output_shape) 103 | return tf.transpose(a=output_tensor, perm=[0, 2, 1, 3]) # [B,N,F,H] 104 | 105 | query = self.query_layer(from_tensor) # [B,F, N*H] [batch_size, from_seq_len, N*H] 106 | key = self.key_layer(to_tensor) # [B,T, N*H] 107 | value = self.value_layer(to_tensor) # [B,T, N*H] 108 | 109 | query = transpose_for_scores(query, from_seq_len) # [B, N, F, H] 110 | key = transpose_for_scores(key, to_seq_len) # [B, N, T, H] 111 | 112 | attention_scores = tf.matmul(query, key, transpose_b=True) # [B, N, F, T] 113 | attention_scores = attention_scores / tf.sqrt(float(self.params.size_per_head)) 114 | 115 | if attention_mask is not None: 116 | attention_mask 
= tf.expand_dims(attention_mask, axis=1) # [B, 1, F, T] 117 | # {1, 0} -> {0.0, -inf} 118 | adder = (1.0 - tf.cast(attention_mask, tf.float32)) * self.params.negative_infinity 119 | attention_scores = tf.add(attention_scores, adder) # adding to softmax -> its like removing them entirely 120 | 121 | # scores to probabilities 122 | attention_probs = tf.nn.softmax(attention_scores) # [B, N, F, T] 123 | 124 | # This is actually dropping out entire tokens to attend to, which might 125 | # seem a bit unusual, but is taken from the original Transformer paper. 126 | attention_probs = self.dropout_layer(attention_probs, 127 | training=training) # [B, N, F, T] 128 | 129 | # [B,T,N,H] 130 | value = tf.reshape(value, [batch_size, to_seq_len, 131 | self.params.num_heads, self.params.size_per_head]) 132 | value = tf.transpose(a=value, perm=[0, 2, 1, 3]) # [B, N, T, H] 133 | 134 | context_layer = tf.matmul(attention_probs, value) # [B, N, F, H] 135 | context_layer = tf.transpose(a=context_layer, perm=[0, 2, 1, 3]) # [B, F, N, H] 136 | 137 | output_shape = [batch_size, from_seq_len, 138 | self.params.num_heads * self.params.size_per_head] 139 | context_layer = tf.reshape(context_layer, output_shape) 140 | return context_layer # [B, F, N*H] 141 | 142 | # noinspection PyUnusedLocal 143 | def compute_mask(self, inputs, mask=None): 144 | return mask # [B, F] 145 | 146 | -------------------------------------------------------------------------------- /bert/embeddings.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 28.Mar.2019 at 12:33 4 | # 5 | 6 | from __future__ import absolute_import, division, print_function 7 | 8 | import tensorflow as tf 9 | import params_flow as pf 10 | 11 | from tensorflow import keras 12 | from tensorflow.keras import backend as K 13 | 14 | import bert 15 | 16 | 17 | class PositionEmbeddingLayer(bert.Layer): 18 | class Params(bert.Layer.Params): 19 | max_position_embeddings = 512 20 | hidden_size = 128 21 | 22 | # noinspection PyUnusedLocal 23 | def _construct(self, **kwargs): 24 | super()._construct(**kwargs) 25 | self.embedding_table = None 26 | 27 | # noinspection PyAttributeOutsideInit 28 | def build(self, input_shape): 29 | # input_shape: () of seq_len 30 | if input_shape is not None: 31 | assert input_shape.ndims == 0 32 | self.input_spec = keras.layers.InputSpec(shape=input_shape, dtype='int32') 33 | else: 34 | self.input_spec = keras.layers.InputSpec(shape=(), dtype='int32') 35 | 36 | self.embedding_table = self.add_weight(name="embeddings", 37 | dtype=K.floatx(), 38 | shape=[self.params.max_position_embeddings, self.params.hidden_size], 39 | initializer=self.create_initializer()) 40 | super(PositionEmbeddingLayer, self).build(input_shape) 41 | 42 | # noinspection PyUnusedLocal 43 | def call(self, inputs, **kwargs): 44 | # just return the embedding after verifying 45 | # that seq_len is less than max_position_embeddings 46 | seq_len = inputs 47 | 48 | assert_op = tf.compat.v2.debugging.assert_less_equal(seq_len, self.params.max_position_embeddings) 49 | 50 | with tf.control_dependencies([assert_op]): 51 | # slice to seq_len 52 | full_position_embeddings = tf.slice(self.embedding_table, 53 | [0, 0], 54 | [seq_len, -1]) 55 | output = full_position_embeddings 56 | return output 57 | 58 | 59 | class EmbeddingsProjector(bert.Layer): 60 | class Params(bert.Layer.Params): 61 | hidden_size = 768 62 | embedding_size = None # None for BERT, not None for ALBERT 63 | project_embeddings_with_bias = True # 
in ALBERT - True for Google, False for brightmart/albert_zh 64 | 65 | # noinspection PyUnusedLocal 66 | def _construct(self, **kwargs): 67 | super()._construct(**kwargs) 68 | self.projector_layer = None # for ALBERT 69 | self.projector_bias_layer = None # for ALBERT 70 | 71 | def build(self, input_shape): 72 | emb_shape = input_shape 73 | self.input_spec = keras.layers.InputSpec(shape=emb_shape) 74 | assert emb_shape[-1] == self.params.embedding_size 75 | 76 | # ALBERT word embeddings projection 77 | self.projector_layer = self.add_weight(name="projector", 78 | shape=[self.params.embedding_size, 79 | self.params.hidden_size], 80 | dtype=K.floatx()) 81 | if self.params.project_embeddings_with_bias: 82 | self.projector_bias_layer = self.add_weight(name="bias", 83 | shape=[self.params.hidden_size], 84 | dtype=K.floatx()) 85 | super(EmbeddingsProjector, self).build(input_shape) 86 | 87 | def call(self, inputs, **kwargs): 88 | input_embedding = inputs 89 | assert input_embedding.shape[-1] == self.params.embedding_size 90 | 91 | # ALBERT: project embedding to hidden_size 92 | output = tf.matmul(input_embedding, self.projector_layer) 93 | if self.projector_bias_layer is not None: 94 | output = tf.add(output, self.projector_bias_layer) 95 | 96 | return output 97 | 98 | 99 | class BertEmbeddingsLayer(bert.Layer): 100 | class Params(PositionEmbeddingLayer.Params, 101 | EmbeddingsProjector.Params): 102 | vocab_size = None 103 | use_token_type = True 104 | use_position_embeddings = True 105 | token_type_vocab_size = 2 106 | hidden_size = 768 107 | hidden_dropout = 0.1 108 | 109 | extra_tokens_vocab_size = None # size of the extra (task specific) token vocabulary (using negative token ids) 110 | 111 | # 112 | # ALBERT support - set embedding_size (or None for BERT) 113 | # 114 | embedding_size = None # None for BERT, not None for ALBERT 115 | project_embeddings_with_bias = True # in ALBERT - True for Google, False for brightmart/albert_zh 116 | project_position_embeddings = True # in ALEBRT - True for Google, False for brightmart/albert_zh 117 | 118 | mask_zero = False 119 | 120 | # noinspection PyUnusedLocal 121 | def _construct(self, **kwargs): 122 | super()._construct(**kwargs) 123 | self.word_embeddings_layer = None 124 | self.extra_word_embeddings_layer = None # for task specific tokens (negative token ids) 125 | self.token_type_embeddings_layer = None 126 | self.position_embeddings_layer = None 127 | self.word_embeddings_projector_layer = None # for ALBERT 128 | self.layer_norm_layer = None 129 | self.dropout_layer = None 130 | 131 | self.support_masking = self.params.mask_zero 132 | 133 | # noinspection PyAttributeOutsideInit 134 | def build(self, input_shape): 135 | if isinstance(input_shape, list): 136 | assert len(input_shape) == 2 137 | input_ids_shape, token_type_ids_shape = input_shape 138 | self.input_spec = [keras.layers.InputSpec(shape=input_ids_shape), 139 | keras.layers.InputSpec(shape=token_type_ids_shape)] 140 | else: 141 | input_ids_shape = input_shape 142 | self.input_spec = keras.layers.InputSpec(shape=input_ids_shape) 143 | 144 | # use either hidden_size for BERT or embedding_size for ALBERT 145 | embedding_size = self.params.hidden_size if self.params.embedding_size is None else self.params.embedding_size 146 | 147 | self.word_embeddings_layer = keras.layers.Embedding( 148 | input_dim=self.params.vocab_size, 149 | output_dim=embedding_size, 150 | mask_zero=self.params.mask_zero, 151 | name="word_embeddings" 152 | ) 153 | if self.params.extra_tokens_vocab_size is not None: 154 | 
self.extra_word_embeddings_layer = keras.layers.Embedding( 155 | input_dim=self.params.extra_tokens_vocab_size + 1, # +1 is for a /0 vector 156 | output_dim=embedding_size, 157 | mask_zero=self.params.mask_zero, 158 | embeddings_initializer=self.create_initializer(), 159 | name="extra_word_embeddings" 160 | ) 161 | 162 | # ALBERT word embeddings projection 163 | if self.params.embedding_size is not None: 164 | self.word_embeddings_projector_layer = EmbeddingsProjector.from_params( 165 | self.params, name="word_embeddings_projector") 166 | 167 | position_embedding_size = embedding_size if self.params.project_position_embeddings else self.params.hidden_size 168 | 169 | if self.params.use_token_type: 170 | self.token_type_embeddings_layer = keras.layers.Embedding( 171 | input_dim=self.params.token_type_vocab_size, 172 | output_dim=position_embedding_size, 173 | mask_zero=False, 174 | name="token_type_embeddings" 175 | ) 176 | if self.params.use_position_embeddings: 177 | self.position_embeddings_layer = PositionEmbeddingLayer.from_params( 178 | self.params, 179 | name="position_embeddings", 180 | hidden_size=position_embedding_size 181 | ) 182 | 183 | self.layer_norm_layer = pf.LayerNormalization(name="LayerNorm") 184 | self.dropout_layer = keras.layers.Dropout(rate=self.params.hidden_dropout) 185 | 186 | super(BertEmbeddingsLayer, self).build(input_shape) 187 | 188 | def call(self, inputs, mask=None, training=None): 189 | if isinstance(inputs, list): 190 | assert 2 == len(inputs), "Expecting inputs to be a [input_ids, token_type_ids] list" 191 | input_ids, token_type_ids = inputs 192 | else: 193 | input_ids = inputs 194 | token_type_ids = None 195 | 196 | input_ids = tf.cast(input_ids, dtype=tf.int32) 197 | 198 | if self.extra_word_embeddings_layer is not None: 199 | token_mask = tf.cast(tf.greater_equal(input_ids, 0), tf.int32) 200 | extra_mask = tf.cast(tf.less(input_ids, 0), tf.int32) 201 | token_ids = token_mask * input_ids 202 | extra_tokens = extra_mask * (-input_ids) 203 | token_output = self.word_embeddings_layer(token_ids) 204 | extra_output = self.extra_word_embeddings_layer(extra_tokens) 205 | embedding_output = tf.add(token_output, 206 | extra_output * tf.expand_dims(tf.cast(extra_mask, K.floatx()), axis=-1)) 207 | else: 208 | embedding_output = self.word_embeddings_layer(input_ids) 209 | 210 | # ALBERT: for brightmart/albert_zh weights - project only token embeddings 211 | if not self.params.project_position_embeddings: 212 | if self.word_embeddings_projector_layer: 213 | embedding_output = self.word_embeddings_projector_layer(embedding_output) 214 | 215 | if token_type_ids is not None: 216 | token_type_ids = tf.cast(token_type_ids, dtype=tf.int32) 217 | embedding_output += self.token_type_embeddings_layer(token_type_ids) 218 | 219 | if self.position_embeddings_layer is not None: 220 | seq_len = input_ids.shape.as_list()[1] 221 | emb_size = embedding_output.shape[-1] 222 | 223 | pos_embeddings = self.position_embeddings_layer(seq_len) 224 | # broadcast over all dimension except the last two [..., seq_len, width] 225 | broadcast_shape = [1] * (embedding_output.shape.ndims - 2) + [seq_len, emb_size] 226 | embedding_output += tf.reshape(pos_embeddings, broadcast_shape) 227 | 228 | embedding_output = self.layer_norm_layer(embedding_output) 229 | embedding_output = self.dropout_layer(embedding_output, training=training) 230 | 231 | # ALBERT: for google-research/albert weights - project all embeddings 232 | if self.params.project_position_embeddings: 233 | if 
self.word_embeddings_projector_layer: 234 | embedding_output = self.word_embeddings_projector_layer(embedding_output) 235 | 236 | return embedding_output # [B, seq_len, hidden_size] 237 | 238 | def compute_mask(self, inputs, mask=None): 239 | if isinstance(inputs, list): 240 | assert 2 == len(inputs), "Expecting inputs to be a [input_ids, token_type_ids] list" 241 | input_ids, token_type_ids = inputs 242 | else: 243 | input_ids = inputs 244 | token_type_ids = None 245 | 246 | if not self.support_masking: 247 | return None 248 | 249 | return tf.not_equal(input_ids, 0) 250 | -------------------------------------------------------------------------------- /bert/layer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 28.Mar.2019 at 12:46 4 | # 5 | 6 | from __future__ import absolute_import, division, print_function 7 | 8 | import tensorflow as tf 9 | 10 | import params_flow as pf 11 | from params_flow.activations import gelu 12 | 13 | 14 | class Layer(pf.Layer): 15 | """ Common abstract base layer for all BERT layers. """ 16 | class Params(pf.Layer.Params): 17 | initializer_range = 0.02 18 | 19 | def create_initializer(self): 20 | return tf.keras.initializers.TruncatedNormal(stddev=self.params.initializer_range) 21 | 22 | @staticmethod 23 | def get_activation(activation_string): 24 | if not isinstance(activation_string, str): 25 | return activation_string 26 | if not activation_string: 27 | return None 28 | 29 | act = activation_string.lower() 30 | if act == "linear": 31 | return None 32 | elif act == "relu": 33 | return tf.nn.relu 34 | elif act == "gelu": 35 | return gelu 36 | elif act == "tanh": 37 | return tf.tanh 38 | else: 39 | raise ValueError("Unsupported activation: %s" % act) 40 | -------------------------------------------------------------------------------- /bert/loader.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 28.Mar.2019 at 14:01 4 | # 5 | 6 | from __future__ import absolute_import, division, print_function 7 | 8 | import os 9 | import re 10 | 11 | import tensorflow as tf 12 | from tensorflow import keras 13 | 14 | import params_flow as pf 15 | import params 16 | 17 | from bert.model import BertModelLayer 18 | 19 | _verbose = os.environ.get('VERBOSE', 1) # verbose print per default 20 | trace = print if int(_verbose) else lambda *a, **k: None 21 | 22 | bert_models_google = { 23 | "uncased_L-12_H-768_A-12": "https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip", 24 | "uncased_L-24_H-1024_A-16": "https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-24_H-1024_A-16.zip", 25 | "cased_L-12_H-768_A-12": "https://storage.googleapis.com/bert_models/2018_10_18/cased_L-12_H-768_A-12.zip", 26 | "cased_L-24_H-1024_A-16": "https://storage.googleapis.com/bert_models/2018_10_18/cased_L-24_H-1024_A-16.zip", 27 | "multi_cased_L-12_H-768_A-12": "https://storage.googleapis.com/bert_models/2018_11_23/multi_cased_L-12_H-768_A-12.zip", 28 | "multilingual_L-12_H-768_A-12": "https://storage.googleapis.com/bert_models/2018_11_03/multilingual_L-12_H-768_A-12.zip", 29 | "chinese_L-12_H-768_A-12": "https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip", 30 | "wwm_uncased_L-24_H-1024_A-16": "https://storage.googleapis.com/bert_models/2019_05_30/wwm_uncased_L-24_H-1024_A-16.zip", 31 | "wwm_cased_L-24_H-1024_A-16": 
"https://storage.googleapis.com/bert_models/2019_05_30/wwm_cased_L-24_H-1024_A-16.zip", 32 | } 33 | 34 | 35 | def fetch_google_bert_model(model_name: str, fetch_dir: str): 36 | if model_name not in bert_models_google: 37 | raise ValueError("BERT model with name:[{}] not found, try one of:{}".format( 38 | model_name, bert_models_google)) 39 | else: 40 | fetch_url = bert_models_google[model_name] 41 | 42 | fetched_file = pf.utils.fetch_url(fetch_url, fetch_dir=fetch_dir) 43 | fetched_dir = pf.utils.unpack_archive(fetched_file) 44 | fetched_dir = os.path.join(fetched_dir, model_name) 45 | return fetched_dir 46 | 47 | 48 | def map_from_stock_variale_name(name, prefix="bert"): 49 | name = name.split(":")[0] 50 | ns = name.split("/") 51 | pns = prefix.split("/") 52 | 53 | # assert ns[0] == "bert" 54 | 55 | name = "/".join(pns + ns[1:]) 56 | ns = name.split("/") 57 | 58 | if ns[1] not in ["encoder", "embeddings"]: 59 | return None 60 | if ns[1] == "embeddings": 61 | if ns[2] == "LayerNorm": 62 | return name 63 | else: 64 | return name + "/embeddings" 65 | if ns[1] == "encoder": 66 | if ns[3] == "intermediate": 67 | return "/".join(ns[:4] + ns[5:]) 68 | else: 69 | return name 70 | return None 71 | 72 | 73 | def map_to_stock_variable_name(name, prefix="bert"): 74 | name = name.split(":")[0] 75 | ns = name.split("/") 76 | pns = prefix.split("/") 77 | 78 | if ns[:len(pns)] != pns: 79 | return None 80 | 81 | name = "/".join(["bert"] + ns[len(pns):]) 82 | ns = name.split("/") 83 | 84 | if ns[1] not in ["encoder", "embeddings"]: 85 | return None 86 | if ns[1] == "embeddings": 87 | if ns[2] == "LayerNorm": 88 | return name 89 | elif ns[2] == "word_embeddings_projector": 90 | ns[2] = "word_embeddings_2" 91 | if ns[3] == "projector": 92 | ns[3] = "embeddings" 93 | return "/".join(ns[:-1]) 94 | return "/".join(ns) 95 | else: 96 | return "/".join(ns[:-1]) 97 | if ns[1] == "encoder": 98 | if ns[3] == "intermediate": 99 | return "/".join(ns[:4] + ["dense"] + ns[4:]) 100 | else: 101 | return name 102 | return None 103 | 104 | 105 | class StockBertConfig(params.Params): 106 | attention_probs_dropout_prob = None # 0.1 107 | hidden_act = None # "gelu" 108 | hidden_dropout_prob = None # 0.1, 109 | hidden_size = None # 768, 110 | initializer_range = None # 0.02, 111 | intermediate_size = None # 3072, 112 | max_position_embeddings = None # 512, 113 | num_attention_heads = None # 12, 114 | num_hidden_layers = None # 12, 115 | type_vocab_size = None # 2, 116 | vocab_size = None # 30522 117 | 118 | # ALBERT params 119 | # directionality = None # "bidi" 120 | # pooler_fc_size = None # 768, 121 | # pooler_num_attention_heads = None # 12, 122 | # pooler_num_fc_layers = None # 3, 123 | # pooler_size_per_head = None # 128, 124 | # pooler_type = None # "first_token_transform", 125 | ln_type = None # "postln" # used for detecting brightmarts weights 126 | embedding_size = None # 128 127 | 128 | def to_bert_model_layer_params(self): 129 | return map_stock_config_to_params(self) 130 | 131 | 132 | def map_stock_config_to_params(bc): 133 | """ 134 | Converts the original BERT or ALBERT config dictionary 135 | to a `BertModelLayer.Params` instance. 136 | :return: a `BertModelLayer.Params` instance. 
137 | """ 138 | bert_params = BertModelLayer.Params( 139 | num_layers=bc.num_hidden_layers, 140 | num_heads=bc.num_attention_heads, 141 | hidden_size=bc.hidden_size, 142 | hidden_dropout=bc.hidden_dropout_prob, 143 | attention_dropout=bc.attention_probs_dropout_prob, 144 | 145 | intermediate_size=bc.intermediate_size, 146 | intermediate_activation=bc.hidden_act, 147 | 148 | vocab_size=bc.vocab_size, 149 | use_token_type=True, 150 | use_position_embeddings=True, 151 | token_type_vocab_size=bc.type_vocab_size, 152 | max_position_embeddings=bc.max_position_embeddings, 153 | 154 | embedding_size=bc.embedding_size, 155 | shared_layer=bc.embedding_size is not None, 156 | ) 157 | return bert_params 158 | 159 | 160 | def params_from_pretrained_ckpt(bert_ckpt_dir): 161 | json_config_files = tf.io.gfile.glob(os.path.join(bert_ckpt_dir, "*_config*.json")) 162 | if len(json_config_files) != 1: 163 | raise ValueError("Can't glob for BERT config json at: {}/*_config*.json".format(bert_ckpt_dir)) 164 | 165 | config_file_name = os.path.basename(json_config_files[0]) 166 | bert_config_file = os.path.join(bert_ckpt_dir, config_file_name) 167 | 168 | with tf.io.gfile.GFile(bert_config_file, "r") as reader: 169 | bc = StockBertConfig.from_json_string(reader.read()) 170 | bert_params = map_stock_config_to_params(bc) 171 | is_brightmart_weights = bc["ln_type"] is not None 172 | bert_params.project_position_embeddings = not is_brightmart_weights # ALBERT: False for brightmart/weights 173 | bert_params.project_embeddings_with_bias = not is_brightmart_weights # ALBERT: False for brightmart/weights 174 | 175 | return bert_params 176 | 177 | 178 | def _checkpoint_exists(ckpt_path): 179 | cktp_files = tf.io.gfile.glob(ckpt_path + "*") 180 | return len(cktp_files) > 0 181 | 182 | 183 | def bert_prefix(bert: BertModelLayer): 184 | re_bert = re.compile(r'(.*)/(embeddings|encoder)/(.+):0') 185 | match = re_bert.match(bert.weights[0].name) 186 | assert match, "Unexpected bert layer: {} weight:{}".format(bert, bert.weights[0].name) 187 | prefix = match.group(1) 188 | return prefix 189 | 190 | 191 | def load_stock_weights(bert: BertModelLayer, ckpt_path, map_to_stock_fn=map_to_stock_variable_name): 192 | """ 193 | Use this method to load the weights from a pre-trained BERT checkpoint into a bert layer. 194 | 195 | :param bert: a BertModelLayer instance within a built keras model. 196 | :param ckpt_path: checkpoint path, i.e. `uncased_L-12_H-768_A-12/bert_model.ckpt` or `albert_base_zh/albert_model.ckpt` 197 | :return: list of weights with mismatched shapes. This can be used to extend 198 | the segment/token_type embeddings. 199 | """ 200 | assert isinstance(bert, BertModelLayer), "Expecting a BertModelLayer instance as first argument" 201 | assert _checkpoint_exists(ckpt_path), "Checkpoint does not exist: {}".format(ckpt_path) 202 | assert len(bert.weights) > 0, "BertModelLayer weights have not been instantiated yet. " \ 203 | "Please add the layer in a Keras model and call model.build() first!" 
204 | 205 | ckpt_reader = tf.train.load_checkpoint(ckpt_path) 206 | 207 | stock_weights = set(ckpt_reader.get_variable_to_dtype_map().keys()) 208 | 209 | prefix = bert_prefix(bert) 210 | 211 | loaded_weights = set() 212 | skip_count = 0 213 | weight_value_tuples = [] 214 | skipped_weight_value_tuples = [] 215 | 216 | bert_params = bert.weights 217 | param_values = keras.backend.batch_get_value(bert.weights) 218 | for ndx, (param_value, param) in enumerate(zip(param_values, bert_params)): 219 | stock_name = map_to_stock_fn(param.name, prefix) 220 | 221 | if ckpt_reader.has_tensor(stock_name): 222 | ckpt_value = ckpt_reader.get_tensor(stock_name) 223 | 224 | if param_value.shape != ckpt_value.shape: 225 | trace("loader: Skipping weight:[{}] as the weight shape:[{}] is not compatible " 226 | "with the checkpoint:[{}] shape:{}".format(param.name, param.shape, 227 | stock_name, ckpt_value.shape)) 228 | skipped_weight_value_tuples.append((param, ckpt_value)) 229 | continue 230 | 231 | weight_value_tuples.append((param, ckpt_value)) 232 | loaded_weights.add(stock_name) 233 | else: 234 | trace("loader: No value for:[{}], i.e.:[{}] in:[{}]".format(param.name, stock_name, ckpt_path)) 235 | skip_count += 1 236 | keras.backend.batch_set_value(weight_value_tuples) 237 | 238 | trace("Done loading {} BERT weights from: {} into {} (prefix:{}). " 239 | "Count of weights not found in the checkpoint was: [{}]. " 240 | "Count of weights with mismatched shape: [{}]".format( 241 | len(weight_value_tuples), ckpt_path, bert, prefix, skip_count, len(skipped_weight_value_tuples))) 242 | 243 | trace("Unused weights from checkpoint:", 244 | "\n\t" + "\n\t".join(sorted(stock_weights.difference(loaded_weights)))) 245 | 246 | return skipped_weight_value_tuples # (bert_weight, value_from_ckpt) 247 | -------------------------------------------------------------------------------- /bert/loader_albert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 28.10.2019 at 2:02 PM 4 | # 5 | 6 | from __future__ import division, absolute_import, print_function 7 | 8 | import os 9 | import re 10 | import urllib 11 | import params_flow as pf 12 | 13 | import tensorflow as tf 14 | from tensorflow import keras 15 | 16 | from bert import BertModelLayer, loader 17 | 18 | _verbose = os.environ.get('VERBOSE', 1) # verbose print per default 19 | trace = print if int(_verbose) else lambda *a, **k: None 20 | 21 | albert_models_tfhub = { 22 | "albert_base": "https://tfhub.dev/google/albert_base/{version}?tf-hub-format=compressed", 23 | "albert_large": "https://tfhub.dev/google/albert_large/{version}?tf-hub-format=compressed", 24 | "albert_xlarge": "https://tfhub.dev/google/albert_xlarge/{version}?tf-hub-format=compressed", 25 | "albert_xxlarge": "https://tfhub.dev/google/albert_xxlarge/{version}?tf-hub-format=compressed", 26 | } 27 | 28 | albert_models_brightmart = { 29 | "albert_tiny": "https://storage.googleapis.com/albert_zh/albert_tiny.zip", 30 | "albert_tiny_489k": "https://storage.googleapis.com/albert_zh/albert_tiny_489k.zip", 31 | "albert_base": "https://storage.googleapis.com/albert_zh/albert_base_zh.zip", 32 | "albert_base_36k": "https://storage.googleapis.com/albert_zh/albert_base_zh_additional_36k_steps.zip", 33 | "albert_large": "https://storage.googleapis.com/albert_zh/albert_large_zh.zip", 34 | "albert_xlarge": "https://storage.googleapis.com/albert_zh/albert_xlarge_zh_177k.zip", 35 | "albert_xlarge_183k": 
"https://storage.googleapis.com/albert_zh/albert_xlarge_zh_183k.zip", 36 | } 37 | 38 | albert_models_google = { 39 | "albert_base_zh": "https://storage.googleapis.com/albert_models/albert_base_zh.tar.gz", 40 | "albert_large_zh": "https://storage.googleapis.com/albert_models/albert_large_zh.tar.gz", 41 | "albert_xlarge_zh": "https://storage.googleapis.com/albert_models/albert_xlarge_zh.tar.gz", 42 | "albert_xxlarge_zh": "https://storage.googleapis.com/albert_models/albert_xxlarge_zh.tar.gz", 43 | 44 | "albert_base_v2": "https://storage.googleapis.com/albert_models/albert_base_v2.tar.gz", 45 | "albert_large_v2": "https://storage.googleapis.com/albert_models/albert_large_v2.tar.gz", 46 | "albert_xlarge_v2": "https://storage.googleapis.com/albert_models/albert_xlarge_v2.tar.gz", 47 | "albert_xxlarge_v2": "https://storage.googleapis.com/albert_models/albert_xxlarge_v2.tar.gz" 48 | } 49 | 50 | config_albert_base = { 51 | "attention_probs_dropout_prob": 0.1, 52 | "hidden_act": "gelu", 53 | "hidden_dropout_prob": 0.1, 54 | "embedding_size": 128, 55 | "hidden_size": 768, 56 | "initializer_range": 0.02, 57 | "intermediate_size": 3072, 58 | "max_position_embeddings": 512, 59 | "num_attention_heads": 12, 60 | "num_hidden_layers": 12, 61 | "num_hidden_groups": 1, 62 | "net_structure_type": 0, 63 | "gap_size": 0, 64 | "num_memory_blocks": 0, 65 | "inner_group_num": 1, 66 | "down_scale_factor": 1, 67 | "type_vocab_size": 2, 68 | "vocab_size": 30000 69 | } 70 | 71 | config_albert_large = { 72 | "attention_probs_dropout_prob": 0.1, 73 | "hidden_act": "gelu", 74 | "hidden_dropout_prob": 0.1, 75 | "embedding_size": 128, 76 | "hidden_size": 1024, 77 | "initializer_range": 0.02, 78 | "intermediate_size": 4096, 79 | "max_position_embeddings": 512, 80 | "num_attention_heads": 16, 81 | "num_hidden_layers": 24, 82 | "num_hidden_groups": 1, 83 | "net_structure_type": 0, 84 | "gap_size": 0, 85 | "num_memory_blocks": 0, 86 | "inner_group_num": 1, 87 | "down_scale_factor": 1, 88 | "type_vocab_size": 2, 89 | "vocab_size": 30000 90 | } 91 | config_albert_xlarge = { 92 | "attention_probs_dropout_prob": 0.1, 93 | "hidden_act": "gelu", 94 | "hidden_dropout_prob": 0.1, 95 | "embedding_size": 128, 96 | "hidden_size": 2048, 97 | "initializer_range": 0.02, 98 | "intermediate_size": 8192, 99 | "max_position_embeddings": 512, 100 | "num_attention_heads": 16, 101 | "num_hidden_layers": 24, 102 | "num_hidden_groups": 1, 103 | "net_structure_type": 0, 104 | "gap_size": 0, 105 | "num_memory_blocks": 0, 106 | "inner_group_num": 1, 107 | "down_scale_factor": 1, 108 | "type_vocab_size": 2, 109 | "vocab_size": 30000 110 | } 111 | 112 | config_albert_xxlarge = { 113 | "attention_probs_dropout_prob": 0, 114 | "hidden_act": "gelu", 115 | "hidden_dropout_prob": 0, 116 | "embedding_size": 128, 117 | "hidden_size": 4096, 118 | "initializer_range": 0.02, 119 | "intermediate_size": 16384, 120 | "max_position_embeddings": 512, 121 | "num_attention_heads": 64, 122 | "num_hidden_layers": 12, 123 | "num_hidden_groups": 1, 124 | "net_structure_type": 0, 125 | "layers_to_keep": [], 126 | "gap_size": 0, 127 | "num_memory_blocks": 0, 128 | "inner_group_num": 1, 129 | "down_scale_factor": 1, 130 | "type_vocab_size": 2, 131 | "vocab_size": 30000 132 | } 133 | 134 | albert_models_config = { 135 | "albert_base": config_albert_base, 136 | "albert_large": config_albert_large, 137 | "albert_xlarge": config_albert_xlarge, 138 | "albert_xxlarge": config_albert_xxlarge, 139 | } 140 | 141 | 142 | def albert_params(albert_model: str): 143 | """Returns the ALBERT 
params for the specified TFHub model. 144 | 145 | :param albert_model: either a model name or a checkpoint directory 146 | containing an assets/albert_config.json 147 | """ 148 | if tf.io.gfile.isdir(albert_model): 149 | config_file = os.path.join(albert_model, "assets", "albert_config.json") # google tfhub v2 weights 150 | if not tf.io.gfile.exists(config_file): 151 | config_file = os.path.join(albert_model, "albert_config.json") # google non-tfhub v2 weights 152 | if tf.io.gfile.exists(config_file): 153 | stock_config = loader.StockBertConfig.from_json_file(config_file) 154 | else: 155 | raise ValueError("No google-research ALBERT model found under:[{}] expecting albert_config.json or assets/albert_config.json".format(albert_model)) 156 | else: 157 | if albert_model in albert_models_config: # google tfhub v1 weights 158 | albert_config = albert_models_config[albert_model] 159 | stock_config = loader.StockBertConfig.from_dict(albert_config, return_instance=True, return_unused=False) 160 | else: 161 | raise ValueError("ALBERT model with name:[{}] not one of tfhub/google-research albert models, try one of:{}".format( 162 | albert_model, albert_models_tfhub)) 163 | 164 | params = loader.map_stock_config_to_params(stock_config) 165 | return params 166 | 167 | 168 | def fetch_brightmart_albert_model(model_name: str, fetch_dir: str): 169 | if model_name not in albert_models_brightmart: 170 | raise ValueError("ALBERT model with name:[{}] not found at brightmart/albert_zh, try one of:{}".format( 171 | model_name, albert_models_brightmart)) 172 | else: 173 | fetch_url = albert_models_brightmart[model_name] 174 | 175 | fetched_file = pf.utils.fetch_url(fetch_url, fetch_dir=fetch_dir) 176 | fetched_dir = pf.utils.unpack_archive(fetched_file) 177 | return fetched_dir 178 | 179 | 180 | def fetch_google_albert_model(model_name: str, fetch_dir: str): 181 | if model_name not in albert_models_google: 182 | raise ValueError("ALBERT model with name:[{}] not found at google-research/ALBERT, try one of:{}".format( 183 | model_name, albert_models_google)) 184 | else: 185 | fetch_url = albert_models_google[model_name] 186 | 187 | fetched_file = pf.utils.fetch_url(fetch_url, fetch_dir=fetch_dir) 188 | fetched_dir = pf.utils.unpack_archive(fetched_file) 189 | fetched_dir = tf.io.gfile.glob(os.path.join(fetched_dir, "*", "model.ckpt-best.meta"))[0] 190 | fetched_dir = os.path.dirname(fetched_dir) 191 | return fetched_dir 192 | 193 | 194 | def fetch_tfhub_albert_model(albert_model: str, fetch_dir: str, version="2"): 195 | """ 196 | Fetches a pre-trained ALBERT model from TFHub. 197 | :param albert_model: TFHub model URL or a model name like albert_base, albert_large, etc. 
198 | :param fetch_dir: 199 | :return: 200 | """ 201 | if albert_model.startswith("http"): 202 | fetch_url = albert_model 203 | elif albert_model not in albert_models_tfhub: 204 | raise ValueError("ALBERT model with name:[{}] not found in tfhub/google, try one of:{}".format( 205 | albert_model, albert_models_tfhub)) 206 | else: 207 | fetch_url = albert_models_tfhub[albert_model].format(version=version) 208 | 209 | name, version = urllib.parse.urlparse(fetch_url).path.split("/")[-2:] 210 | local_file_name = "{}.tar.gz".format(name) 211 | 212 | trace("Fetching ALBERT model: {} version: {}".format(name, version)) 213 | 214 | fetched_file = pf.utils.fetch_url(fetch_url, fetch_dir=fetch_dir, local_file_name=local_file_name) 215 | fetched_dir = pf.utils.unpack_archive(fetched_file) 216 | 217 | return fetched_dir 218 | 219 | 220 | def map_to_tfhub_albert_variable_name(name, prefix="bert"): 221 | 222 | name = re.compile("encoder/layer_shared/intermediate/(?=kernel|bias)").sub( 223 | "encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/dense/", name) 224 | name = re.compile("encoder/layer_shared/output/dense/(?=kernel|bias)").sub( 225 | "encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/output/dense/", name) 226 | 227 | name = name.replace("encoder/layer_shared/output/dense", "encoder/transformer/group_0/inner_group_0/ffn_1/intermediate/output/dense") 228 | name = name.replace("encoder/layer_shared/attention/output/LayerNorm", "encoder/transformer/group_0/inner_group_0/LayerNorm") 229 | name = name.replace("encoder/layer_shared/output/LayerNorm", "encoder/transformer/group_0/inner_group_0/LayerNorm_1") 230 | name = name.replace("encoder/layer_shared/attention", "encoder/transformer/group_0/inner_group_0/attention_1") 231 | 232 | name = name.replace("embeddings/word_embeddings_projector/projector", 233 | "encoder/embedding_hidden_mapping_in/kernel") 234 | name = name.replace("embeddings/word_embeddings_projector/bias", 235 | "encoder/embedding_hidden_mapping_in/bias") 236 | 237 | name = name.split(":")[0] 238 | ns = name.split("/") 239 | pns = prefix.split("/") 240 | 241 | if ns[:len(pns)] != pns: 242 | return None 243 | 244 | name = "/".join(["bert"] + ns[len(pns):]) 245 | ns = name.split("/") 246 | 247 | if ns[1] not in ["encoder", "embeddings"]: 248 | return None 249 | if ns[1] == "embeddings": 250 | if ns[2] == "LayerNorm": 251 | return name 252 | else: 253 | return "/".join(ns[:-1]) 254 | if ns[1] == "encoder": 255 | if ns[3] == "intermediate": 256 | return "/".join(ns[:4] + ["dense"] + ns[4:]) 257 | else: 258 | return name 259 | return None 260 | 261 | 262 | def _is_tfhub_model(tfhub_model_path): 263 | try: 264 | assets_files = tf.io.gfile.glob(os.path.join(tfhub_model_path, "assets/*")) 265 | variables_files = tf.io.gfile.glob(os.path.join(tfhub_model_path, "variables/variables.*")) 266 | pb_files = tf.io.gfile.glob(os.path.join(tfhub_model_path, "*.pb")) 267 | except tf.errors.NotFoundError: 268 | assets_files, variables_files, pb_files = [], [], [] 269 | 270 | return len(pb_files) >= 2 and len(assets_files) >= 1 and len(variables_files) >= 2 271 | 272 | 273 | def _is_google_model(ckpt_path): 274 | ckpt_index = os.path.isfile(ckpt_path + ".index") 275 | if ckpt_index: 276 | config_path = tf.io.gfile.glob(os.path.join(os.path.dirname(ckpt_path), "albert_config.json")) 277 | ckpt_meta = tf.io.gfile.glob(os.path.join(os.path.dirname(ckpt_path), "model.ckpt-best.meta")) 278 | return config_path and ckpt_meta 279 | return False 280 | 281 | 282 | def load_albert_weights(bert: 
BertModelLayer, tfhub_model_path, tags=[]): 283 | """ 284 | Use this method to load the weights from a pre-trained BERT checkpoint into a bert layer. 285 | 286 | :param bert: a BertModelLayer instance within a built keras model. 287 | :param ckpt_path: checkpoint path, i.e. `uncased_L-12_H-768_A-12/bert_model.ckpt` or `albert_base_zh/albert_model.ckpt` 288 | :return: list of weights with mismatched shapes. This can be used to extend 289 | the segment/token_type embeddings. 290 | """ 291 | 292 | if not _is_tfhub_model(tfhub_model_path): 293 | if _is_google_model(tfhub_model_path): 294 | trace("Loading google-research/ALBERT weights...") 295 | map_to_stock_fn = map_to_tfhub_albert_variable_name 296 | else: 297 | trace("Loading brightmart/albert_zh weights...") 298 | map_to_stock_fn = loader.map_to_stock_variable_name 299 | return loader.load_stock_weights(bert, tfhub_model_path, map_to_stock_fn=map_to_stock_fn) 300 | 301 | assert isinstance(bert, BertModelLayer), "Expecting a BertModelLayer instance as first argument" 302 | prefix = loader.bert_prefix(bert) 303 | 304 | with tf.Graph().as_default(): 305 | sm = tf.compat.v2.saved_model.load(tfhub_model_path, tags=tags) 306 | with tf.compat.v1.Session() as sess: 307 | sess.run(tf.compat.v1.global_variables_initializer()) 308 | stock_values = {v.name.split(":")[0]: v.read_value() for v in sm.variables} 309 | stock_values = sess.run(stock_values) 310 | 311 | # print("\n".join([str((n, v.shape)) for n,v in stock_values.items()])) 312 | 313 | loaded_weights = set() 314 | skip_count = 0 315 | weight_value_tuples = [] 316 | skipped_weight_value_tuples = [] 317 | 318 | bert_params = bert.weights 319 | param_values = keras.backend.batch_get_value(bert.weights) 320 | for ndx, (param_value, param) in enumerate(zip(param_values, bert_params)): 321 | stock_name = map_to_tfhub_albert_variable_name(param.name, prefix) 322 | 323 | if stock_name in stock_values: 324 | ckpt_value = stock_values[stock_name] 325 | 326 | if param_value.shape != ckpt_value.shape: 327 | trace("loader: Skipping weight:[{}] as the weight shape:[{}] is not compatible " 328 | "with the checkpoint:[{}] shape:{}".format(param.name, param.shape, 329 | stock_name, ckpt_value.shape)) 330 | skipped_weight_value_tuples.append((param, ckpt_value)) 331 | continue 332 | 333 | weight_value_tuples.append((param, ckpt_value)) 334 | loaded_weights.add(stock_name) 335 | else: 336 | trace("loader: No value for:[{}], i.e.:[{}] in:[{}]".format(param.name, stock_name, tfhub_model_path)) 337 | skip_count += 1 338 | keras.backend.batch_set_value(weight_value_tuples) 339 | 340 | trace("Done loading {} BERT weights from: {} into {} (prefix:{}). " 341 | "Count of weights not found in the checkpoint was: [{}]. 
" 342 | "Count of weights with mismatched shape: [{}]".format( 343 | len(weight_value_tuples), tfhub_model_path, bert, prefix, skip_count, len(skipped_weight_value_tuples))) 344 | trace("Unused weights from saved model:", 345 | "\n\t" + "\n\t".join(sorted(set(stock_values.keys()).difference(loaded_weights)))) 346 | 347 | return skipped_weight_value_tuples # (bert_weight, value_from_ckpt) 348 | 349 | -------------------------------------------------------------------------------- /bert/model.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 28.Mar.2019 at 12:33 4 | # 5 | 6 | from __future__ import absolute_import, division, print_function 7 | 8 | from tensorflow import keras 9 | import params_flow as pf 10 | 11 | from bert.layer import Layer 12 | from bert.embeddings import BertEmbeddingsLayer 13 | from bert.transformer import TransformerEncoderLayer 14 | 15 | 16 | class BertModelLayer(Layer): 17 | """ 18 | Implementation of BERT (arXiv:1810.04805), adapter-BERT (arXiv:1902.00751) and ALBERT (arXiv:1909.11942). 19 | 20 | See: https://arxiv.org/pdf/1810.04805.pdf - BERT 21 | https://arxiv.org/pdf/1902.00751.pdf - adapter-BERT 22 | https://arxiv.org/pdf/1909.11942.pdf - ALBERT 23 | 24 | """ 25 | class Params(BertEmbeddingsLayer.Params, 26 | TransformerEncoderLayer.Params): 27 | pass 28 | 29 | # noinspection PyUnusedLocal 30 | def _construct(self, **kwargs): 31 | super()._construct(**kwargs) 32 | self.embeddings_layer = BertEmbeddingsLayer.from_params( 33 | self.params, 34 | name="embeddings" 35 | ) 36 | # create all transformer encoder sub-layers 37 | self.encoders_layer = TransformerEncoderLayer.from_params( 38 | self.params, 39 | name="encoder" 40 | ) 41 | 42 | self.support_masking = True 43 | 44 | # noinspection PyAttributeOutsideInit 45 | def build(self, input_shape): 46 | if isinstance(input_shape, list): 47 | assert len(input_shape) == 2 48 | input_ids_shape, token_type_ids_shape = input_shape 49 | self.input_spec = [keras.layers.InputSpec(shape=input_ids_shape), 50 | keras.layers.InputSpec(shape=token_type_ids_shape)] 51 | else: 52 | input_ids_shape = input_shape 53 | self.input_spec = keras.layers.InputSpec(shape=input_ids_shape) 54 | super(BertModelLayer, self).build(input_shape) 55 | 56 | def compute_output_shape(self, input_shape): 57 | if isinstance(input_shape, list): 58 | assert len(input_shape) == 2 59 | input_ids_shape, _ = input_shape 60 | else: 61 | input_ids_shape = input_shape 62 | 63 | output_shape = list(input_ids_shape) + [self.params.hidden_size] 64 | return output_shape 65 | 66 | def apply_adapter_freeze(self): 67 | """ Should be called once the model has been built to freeze 68 | all bet the adapter and layer normalization layers in BERT. 
69 | """ 70 | if self.params.adapter_size is not None: 71 | def freeze_selector(layer): 72 | return layer.name not in ["adapter-up", "adapter-down", "LayerNorm", "extra_word_embeddings"] 73 | pf.utils.freeze_leaf_layers(self, freeze_selector) 74 | 75 | def call(self, inputs, mask=None, training=None): 76 | if mask is None: 77 | mask = self.embeddings_layer.compute_mask(inputs) 78 | 79 | embedding_output = self.embeddings_layer(inputs, mask=mask, training=training) 80 | output = self.encoders_layer(embedding_output, mask=mask, training=training) 81 | return output # [B, seq_len, hidden_size] 82 | 83 | -------------------------------------------------------------------------------- /bert/tokenization/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 18.Nov.2019 at 07:18 4 | # 5 | from __future__ import division, absolute_import, print_function 6 | 7 | -------------------------------------------------------------------------------- /bert/tokenization/albert_tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # Lint as: python2, python3 16 | # coding=utf-8 17 | """Tokenization classes.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import collections 24 | import re 25 | import unicodedata 26 | import six 27 | from six.moves import range 28 | import tensorflow.compat.v1 as tf 29 | 30 | SPIECE_UNDERLINE = u"▁".encode("utf-8") 31 | 32 | 33 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 34 | """Checks whether the casing config is consistent with the checkpoint name.""" 35 | 36 | # The casing has to be passed in by the user and there is no explicit check 37 | # as to whether it matches the checkpoint. The casing information probably 38 | # should have been stored in the bert_config.json file, but it's not, so 39 | # we have to heuristically detect it to validate. 
40 | 41 | if not init_checkpoint: 42 | return 43 | 44 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", 45 | six.ensure_str(init_checkpoint)) 46 | if m is None: 47 | return 48 | 49 | model_name = m.group(1) 50 | 51 | lower_models = [ 52 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 53 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 54 | ] 55 | 56 | cased_models = [ 57 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 58 | "multi_cased_L-12_H-768_A-12" 59 | ] 60 | 61 | is_bad_config = False 62 | if model_name in lower_models and not do_lower_case: 63 | is_bad_config = True 64 | actual_flag = "False" 65 | case_name = "lowercased" 66 | opposite_flag = "True" 67 | 68 | if model_name in cased_models and do_lower_case: 69 | is_bad_config = True 70 | actual_flag = "True" 71 | case_name = "cased" 72 | opposite_flag = "False" 73 | 74 | if is_bad_config: 75 | raise ValueError( 76 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 77 | "However, `%s` seems to be a %s model, so you " 78 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 79 | "how the model was pre-training. If this error is wrong, please " 80 | "just comment out this check." % (actual_flag, init_checkpoint, 81 | model_name, case_name, opposite_flag)) 82 | 83 | 84 | def preprocess_text(inputs, remove_space=True, lower=False): 85 | """preprocess data by removing extra space and normalize data.""" 86 | outputs = inputs 87 | if remove_space: 88 | outputs = " ".join(inputs.strip().split()) 89 | 90 | if six.PY2 and isinstance(outputs, str): 91 | try: 92 | outputs = six.ensure_text(outputs, "utf-8") 93 | except UnicodeDecodeError: 94 | outputs = six.ensure_text(outputs, "latin-1") 95 | 96 | outputs = unicodedata.normalize("NFKD", outputs) 97 | outputs = "".join([c for c in outputs if not unicodedata.combining(c)]) 98 | if lower: 99 | outputs = outputs.lower() 100 | 101 | return outputs 102 | 103 | 104 | def encode_pieces(sp_model, text, return_unicode=True, sample=False): 105 | """turn sentences into word pieces.""" 106 | 107 | if six.PY2 and isinstance(text, six.text_type): 108 | text = six.ensure_binary(text, "utf-8") 109 | 110 | if not sample: 111 | pieces = sp_model.EncodeAsPieces(text) 112 | else: 113 | pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1) 114 | new_pieces = [] 115 | for piece in pieces: 116 | piece = printable_text(piece) 117 | if len(piece) > 1 and piece[-1] == "," and piece[-2].isdigit(): 118 | cur_pieces = sp_model.EncodeAsPieces( 119 | six.ensure_binary(piece[:-1]).replace(SPIECE_UNDERLINE, b"")) 120 | if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE: 121 | if len(cur_pieces[0]) == 1: 122 | cur_pieces = cur_pieces[1:] 123 | else: 124 | cur_pieces[0] = cur_pieces[0][1:] 125 | cur_pieces.append(piece[-1]) 126 | new_pieces.extend(cur_pieces) 127 | else: 128 | new_pieces.append(piece) 129 | 130 | # note(zhiliny): convert back to unicode for py2 131 | if six.PY2 and return_unicode: 132 | ret_pieces = [] 133 | for piece in new_pieces: 134 | if isinstance(piece, str): 135 | piece = six.ensure_text(piece, "utf-8") 136 | ret_pieces.append(piece) 137 | new_pieces = ret_pieces 138 | 139 | return new_pieces 140 | 141 | 142 | def encode_ids(sp_model, text, sample=False): 143 | pieces = encode_pieces(sp_model, text, return_unicode=False, sample=sample) 144 | ids = [sp_model.PieceToId(piece) for piece in pieces] 145 | return ids 146 | 147 | 148 | def convert_to_unicode(text): 149 | """Converts `text` to Unicode (if it's not already), 
assuming utf-8 input.""" 150 | if six.PY3: 151 | if isinstance(text, str): 152 | return text 153 | elif isinstance(text, bytes): 154 | return six.ensure_text(text, "utf-8", "ignore") 155 | else: 156 | raise ValueError("Unsupported string type: %s" % (type(text))) 157 | elif six.PY2: 158 | if isinstance(text, str): 159 | return six.ensure_text(text, "utf-8", "ignore") 160 | elif isinstance(text, six.text_type): 161 | return text 162 | else: 163 | raise ValueError("Unsupported string type: %s" % (type(text))) 164 | else: 165 | raise ValueError("Not running on Python2 or Python 3?") 166 | 167 | 168 | def printable_text(text): 169 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 170 | 171 | # These functions want `str` for both Python2 and Python3, but in one case 172 | # it's a Unicode string and in the other it's a byte string. 173 | if six.PY3: 174 | if isinstance(text, str): 175 | return text 176 | elif isinstance(text, bytes): 177 | return six.ensure_text(text, "utf-8", "ignore") 178 | else: 179 | raise ValueError("Unsupported string type: %s" % (type(text))) 180 | elif six.PY2: 181 | if isinstance(text, str): 182 | return text 183 | elif isinstance(text, six.text_type): 184 | return six.ensure_binary(text, "utf-8") 185 | else: 186 | raise ValueError("Unsupported string type: %s" % (type(text))) 187 | else: 188 | raise ValueError("Not running on Python2 or Python 3?") 189 | 190 | 191 | def load_vocab(vocab_file): 192 | """Loads a vocabulary file into a dictionary.""" 193 | vocab = collections.OrderedDict() 194 | with tf.io.gfile.GFile(vocab_file, "r") as reader: 195 | while True: 196 | token = convert_to_unicode(reader.readline()) 197 | if not token: 198 | break 199 | token = token.strip() 200 | if token: 201 | token = token.split()[0] 202 | if token not in vocab: 203 | vocab[token] = len(vocab) 204 | return vocab 205 | 206 | 207 | def convert_by_vocab(vocab, items): 208 | """Converts a sequence of [tokens|ids] using the vocab.""" 209 | output = [] 210 | for item in items: 211 | output.append(vocab[item]) 212 | return output 213 | 214 | 215 | def convert_tokens_to_ids(vocab, tokens): 216 | return convert_by_vocab(vocab, tokens) 217 | 218 | 219 | def convert_ids_to_tokens(inv_vocab, ids): 220 | return convert_by_vocab(inv_vocab, ids) 221 | 222 | 223 | def whitespace_tokenize(text): 224 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 225 | text = text.strip() 226 | if not text: 227 | return [] 228 | tokens = text.split() 229 | return tokens 230 | 231 | 232 | class FullTokenizer(object): 233 | """Runs end-to-end tokenziation.""" 234 | 235 | def __init__(self, vocab_file, do_lower_case=True, spm_model_file=None): 236 | self.vocab = None 237 | self.sp_model = None 238 | if spm_model_file: 239 | import sentencepiece as spm 240 | 241 | self.sp_model = spm.SentencePieceProcessor() 242 | tf.compat.v1.logging.info("loading sentence piece model") 243 | self.sp_model.Load(spm_model_file) 244 | # Note(mingdachen): For the purpose of consisent API, we are 245 | # generating a vocabulary for the sentence piece tokenizer. 
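            # The comprehension below builds {piece: id} for every piece known to
            # the SentencePiece model, so self.vocab and self.inv_vocab exist in
            # the same form as in the plain vocab-file case.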
246 | self.vocab = {self.sp_model.IdToPiece(i): i for i 247 | in range(self.sp_model.GetPieceSize())} 248 | else: 249 | self.vocab = load_vocab(vocab_file) 250 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 251 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 252 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 253 | 254 | @classmethod 255 | def from_scratch(cls, vocab_file, do_lower_case, spm_model_file): 256 | return FullTokenizer(vocab_file, do_lower_case, spm_model_file) 257 | 258 | @classmethod 259 | def from_hub_module(cls, hub_module, spm_model_file): 260 | """Get the vocab file and casing info from the Hub module.""" 261 | import tensorflow_hub as hub 262 | with tf.Graph().as_default(): 263 | albert_module = hub.Module(hub_module) 264 | tokenization_info = albert_module(signature="tokenization_info", 265 | as_dict=True) 266 | with tf.Session() as sess: 267 | vocab_file, do_lower_case = sess.run( 268 | [tokenization_info["vocab_file"], 269 | tokenization_info["do_lower_case"]]) 270 | return FullTokenizer( 271 | vocab_file=vocab_file, do_lower_case=do_lower_case, 272 | spm_model_file=spm_model_file) 273 | 274 | def tokenize(self, text): 275 | if self.sp_model: 276 | split_tokens = encode_pieces(self.sp_model, text, return_unicode=False) 277 | else: 278 | split_tokens = [] 279 | for token in self.basic_tokenizer.tokenize(text): 280 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 281 | split_tokens.append(sub_token) 282 | 283 | return split_tokens 284 | 285 | def convert_tokens_to_ids(self, tokens): 286 | if self.sp_model: 287 | tf.compat.v1.logging.info("using sentence piece tokenzier.") 288 | return [self.sp_model.PieceToId( 289 | printable_text(token)) for token in tokens] 290 | else: 291 | return convert_by_vocab(self.vocab, tokens) 292 | 293 | def convert_ids_to_tokens(self, ids): 294 | if self.sp_model: 295 | tf.compat.v1.logging.info("using sentence piece tokenzier.") 296 | return [self.sp_model.IdToPiece(id_) for id_ in ids] 297 | else: 298 | return convert_by_vocab(self.inv_vocab, ids) 299 | 300 | 301 | class BasicTokenizer(object): 302 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 303 | 304 | def __init__(self, do_lower_case=True): 305 | """Constructs a BasicTokenizer. 306 | 307 | Args: 308 | do_lower_case: Whether to lower case the input. 309 | """ 310 | self.do_lower_case = do_lower_case 311 | 312 | def tokenize(self, text): 313 | """Tokenizes a piece of text.""" 314 | text = convert_to_unicode(text) 315 | text = self._clean_text(text) 316 | 317 | # This was added on November 1st, 2018 for the multilingual and Chinese 318 | # models. This is also applied to the English models now, but it doesn't 319 | # matter since the English models were not trained on any Chinese data 320 | # and generally don't have any Chinese data in them (there are Chinese 321 | # characters in the vocabulary because Wikipedia does have some Chinese 322 | # words in the English Wikipedia.). 
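        # e.g. "ab中c" becomes "ab 中 c", so each CJK ideograph ends up as its own
        # token after the whitespace tokenization below.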
323 | text = self._tokenize_chinese_chars(text) 324 | 325 | orig_tokens = whitespace_tokenize(text) 326 | split_tokens = [] 327 | for token in orig_tokens: 328 | if self.do_lower_case: 329 | token = token.lower() 330 | token = self._run_strip_accents(token) 331 | split_tokens.extend(self._run_split_on_punc(token)) 332 | 333 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 334 | return output_tokens 335 | 336 | def _run_strip_accents(self, text): 337 | """Strips accents from a piece of text.""" 338 | text = unicodedata.normalize("NFD", text) 339 | output = [] 340 | for char in text: 341 | cat = unicodedata.category(char) 342 | if cat == "Mn": 343 | continue 344 | output.append(char) 345 | return "".join(output) 346 | 347 | def _run_split_on_punc(self, text): 348 | """Splits punctuation on a piece of text.""" 349 | chars = list(text) 350 | i = 0 351 | start_new_word = True 352 | output = [] 353 | while i < len(chars): 354 | char = chars[i] 355 | if _is_punctuation(char): 356 | output.append([char]) 357 | start_new_word = True 358 | else: 359 | if start_new_word: 360 | output.append([]) 361 | start_new_word = False 362 | output[-1].append(char) 363 | i += 1 364 | 365 | return ["".join(x) for x in output] 366 | 367 | def _tokenize_chinese_chars(self, text): 368 | """Adds whitespace around any CJK character.""" 369 | output = [] 370 | for char in text: 371 | cp = ord(char) 372 | if self._is_chinese_char(cp): 373 | output.append(" ") 374 | output.append(char) 375 | output.append(" ") 376 | else: 377 | output.append(char) 378 | return "".join(output) 379 | 380 | def _is_chinese_char(self, cp): 381 | """Checks whether CP is the codepoint of a CJK character.""" 382 | # This defines a "chinese character" as anything in the CJK Unicode block: 383 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 384 | # 385 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 386 | # despite its name. The modern Korean Hangul alphabet is a different block, 387 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 388 | # space-separated words, so they are not treated specially and handled 389 | # like the all of the other languages. 390 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 391 | (cp >= 0x3400 and cp <= 0x4DBF) or # 392 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 393 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 394 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 395 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 396 | (cp >= 0xF900 and cp <= 0xFAFF) or # 397 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 398 | return True 399 | 400 | return False 401 | 402 | def _clean_text(self, text): 403 | """Performs invalid character removal and whitespace cleanup on text.""" 404 | output = [] 405 | for char in text: 406 | cp = ord(char) 407 | if cp == 0 or cp == 0xfffd or _is_control(char): 408 | continue 409 | if _is_whitespace(char): 410 | output.append(" ") 411 | else: 412 | output.append(char) 413 | return "".join(output) 414 | 415 | 416 | class WordpieceTokenizer(object): 417 | """Runs WordPiece tokenziation.""" 418 | 419 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 420 | self.vocab = vocab 421 | self.unk_token = unk_token 422 | self.max_input_chars_per_word = max_input_chars_per_word 423 | 424 | def tokenize(self, text): 425 | """Tokenizes a piece of text into its word pieces. 426 | 427 | This uses a greedy longest-match-first algorithm to perform tokenization 428 | using the given vocabulary. 
429 | 430 | For example: 431 | input = "unaffable" 432 | output = ["un", "##aff", "##able"] 433 | 434 | Args: 435 | text: A single token or whitespace separated tokens. This should have 436 | already been passed through `BasicTokenizer. 437 | 438 | Returns: 439 | A list of wordpiece tokens. 440 | """ 441 | 442 | text = convert_to_unicode(text) 443 | 444 | output_tokens = [] 445 | for token in whitespace_tokenize(text): 446 | chars = list(token) 447 | if len(chars) > self.max_input_chars_per_word: 448 | output_tokens.append(self.unk_token) 449 | continue 450 | 451 | is_bad = False 452 | start = 0 453 | sub_tokens = [] 454 | while start < len(chars): 455 | end = len(chars) 456 | cur_substr = None 457 | while start < end: 458 | substr = "".join(chars[start:end]) 459 | if start > 0: 460 | substr = "##" + six.ensure_str(substr) 461 | if substr in self.vocab: 462 | cur_substr = substr 463 | break 464 | end -= 1 465 | if cur_substr is None: 466 | is_bad = True 467 | break 468 | sub_tokens.append(cur_substr) 469 | start = end 470 | 471 | if is_bad: 472 | output_tokens.append(self.unk_token) 473 | else: 474 | output_tokens.extend(sub_tokens) 475 | return output_tokens 476 | 477 | 478 | def _is_whitespace(char): 479 | """Checks whether `chars` is a whitespace character.""" 480 | # \t, \n, and \r are technically control characters but we treat them 481 | # as whitespace since they are generally considered as such. 482 | if char == " " or char == "\t" or char == "\n" or char == "\r": 483 | return True 484 | cat = unicodedata.category(char) 485 | if cat == "Zs": 486 | return True 487 | return False 488 | 489 | 490 | def _is_control(char): 491 | """Checks whether `chars` is a control character.""" 492 | # These are technically control characters but we count them as whitespace 493 | # characters. 494 | if char == "\t" or char == "\n" or char == "\r": 495 | return False 496 | cat = unicodedata.category(char) 497 | if cat in ("Cc", "Cf"): 498 | return True 499 | return False 500 | 501 | 502 | def _is_punctuation(char): 503 | """Checks whether `chars` is a punctuation character.""" 504 | cp = ord(char) 505 | # We treat all non-letter/number ASCII as punctuation. 506 | # Characters such as "^", "$", and "`" are not in the Unicode 507 | # Punctuation class but we treat them as punctuation anyways, for 508 | # consistency. 509 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 510 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 511 | return True 512 | cat = unicodedata.category(char) 513 | if cat.startswith("P"): 514 | return True 515 | return False 516 | -------------------------------------------------------------------------------- /bert/tokenization/bert_tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 29 | """Checks whether the casing config is consistent with the checkpoint name.""" 30 | 31 | # The casing has to be passed in by the user and there is no explicit check 32 | # as to whether it matches the checkpoint. The casing information probably 33 | # should have been stored in the bert_config.json file, but it's not, so 34 | # we have to heuristically detect it to validate. 35 | 36 | if not init_checkpoint: 37 | return 38 | 39 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 40 | if m is None: 41 | return 42 | 43 | model_name = m.group(1) 44 | 45 | lower_models = [ 46 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 47 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 48 | ] 49 | 50 | cased_models = [ 51 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 52 | "multi_cased_L-12_H-768_A-12" 53 | ] 54 | 55 | is_bad_config = False 56 | if model_name in lower_models and not do_lower_case: 57 | is_bad_config = True 58 | actual_flag = "False" 59 | case_name = "lowercased" 60 | opposite_flag = "True" 61 | 62 | if model_name in cased_models and do_lower_case: 63 | is_bad_config = True 64 | actual_flag = "True" 65 | case_name = "cased" 66 | opposite_flag = "False" 67 | 68 | if is_bad_config: 69 | raise ValueError( 70 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 71 | "However, `%s` seems to be a %s model, so you " 72 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 73 | "how the model was pre-training. If this error is wrong, please " 74 | "just comment out this check." % (actual_flag, init_checkpoint, 75 | model_name, case_name, opposite_flag)) 76 | 77 | 78 | def convert_to_unicode(text): 79 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 80 | if six.PY3: 81 | if isinstance(text, str): 82 | return text 83 | elif isinstance(text, bytes): 84 | return text.decode("utf-8", "ignore") 85 | else: 86 | raise ValueError("Unsupported string type: %s" % (type(text))) 87 | elif six.PY2: 88 | if isinstance(text, str): 89 | return text.decode("utf-8", "ignore") 90 | elif isinstance(text, unicode): 91 | return text 92 | else: 93 | raise ValueError("Unsupported string type: %s" % (type(text))) 94 | else: 95 | raise ValueError("Not running on Python2 or Python 3?") 96 | 97 | 98 | def printable_text(text): 99 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 100 | 101 | # These functions want `str` for both Python2 and Python3, but in one case 102 | # it's a Unicode string and in the other it's a byte string. 
103 | if six.PY3: 104 | if isinstance(text, str): 105 | return text 106 | elif isinstance(text, bytes): 107 | return text.decode("utf-8", "ignore") 108 | else: 109 | raise ValueError("Unsupported string type: %s" % (type(text))) 110 | elif six.PY2: 111 | if isinstance(text, str): 112 | return text 113 | elif isinstance(text, unicode): 114 | return text.encode("utf-8") 115 | else: 116 | raise ValueError("Unsupported string type: %s" % (type(text))) 117 | else: 118 | raise ValueError("Not running on Python2 or Python 3?") 119 | 120 | 121 | def load_vocab(vocab_file): 122 | """Loads a vocabulary file into a dictionary.""" 123 | vocab = collections.OrderedDict() 124 | index = 0 125 | with tf.io.gfile.GFile(vocab_file, "r") as reader: 126 | while True: 127 | token = convert_to_unicode(reader.readline()) 128 | if not token: 129 | break 130 | token = token.strip() 131 | vocab[token] = index 132 | index += 1 133 | return vocab 134 | 135 | 136 | def convert_by_vocab(vocab, items): 137 | """Converts a sequence of [tokens|ids] using the vocab.""" 138 | output = [] 139 | for item in items: 140 | output.append(vocab[item]) 141 | return output 142 | 143 | 144 | def convert_tokens_to_ids(vocab, tokens): 145 | return convert_by_vocab(vocab, tokens) 146 | 147 | 148 | def convert_ids_to_tokens(inv_vocab, ids): 149 | return convert_by_vocab(inv_vocab, ids) 150 | 151 | 152 | def whitespace_tokenize(text): 153 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 154 | text = text.strip() 155 | if not text: 156 | return [] 157 | tokens = text.split() 158 | return tokens 159 | 160 | 161 | class FullTokenizer(object): 162 | """Runs end-to-end tokenziation.""" 163 | 164 | def __init__(self, vocab_file, do_lower_case=True): 165 | self.vocab = load_vocab(vocab_file) 166 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 167 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 168 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 169 | 170 | def tokenize(self, text): 171 | split_tokens = [] 172 | for token in self.basic_tokenizer.tokenize(text): 173 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 174 | split_tokens.append(sub_token) 175 | 176 | return split_tokens 177 | 178 | def convert_tokens_to_ids(self, tokens): 179 | return convert_by_vocab(self.vocab, tokens) 180 | 181 | def convert_ids_to_tokens(self, ids): 182 | return convert_by_vocab(self.inv_vocab, ids) 183 | 184 | 185 | class BasicTokenizer(object): 186 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 187 | 188 | def __init__(self, do_lower_case=True): 189 | """Constructs a BasicTokenizer. 190 | 191 | Args: 192 | do_lower_case: Whether to lower case the input. 193 | """ 194 | self.do_lower_case = do_lower_case 195 | 196 | def tokenize(self, text): 197 | """Tokenizes a piece of text.""" 198 | text = convert_to_unicode(text) 199 | text = self._clean_text(text) 200 | 201 | # This was added on November 1st, 2018 for the multilingual and Chinese 202 | # models. This is also applied to the English models now, but it doesn't 203 | # matter since the English models were not trained on any Chinese data 204 | # and generally don't have any Chinese data in them (there are Chinese 205 | # characters in the vocabulary because Wikipedia does have some Chinese 206 | # words in the English Wikipedia.). 
207 | text = self._tokenize_chinese_chars(text) 208 | 209 | orig_tokens = whitespace_tokenize(text) 210 | split_tokens = [] 211 | for token in orig_tokens: 212 | if self.do_lower_case: 213 | token = token.lower() 214 | token = self._run_strip_accents(token) 215 | split_tokens.extend(self._run_split_on_punc(token)) 216 | 217 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 218 | return output_tokens 219 | 220 | def _run_strip_accents(self, text): 221 | """Strips accents from a piece of text.""" 222 | text = unicodedata.normalize("NFD", text) 223 | output = [] 224 | for char in text: 225 | cat = unicodedata.category(char) 226 | if cat == "Mn": 227 | continue 228 | output.append(char) 229 | return "".join(output) 230 | 231 | def _run_split_on_punc(self, text): 232 | """Splits punctuation on a piece of text.""" 233 | chars = list(text) 234 | i = 0 235 | start_new_word = True 236 | output = [] 237 | while i < len(chars): 238 | char = chars[i] 239 | if _is_punctuation(char): 240 | output.append([char]) 241 | start_new_word = True 242 | else: 243 | if start_new_word: 244 | output.append([]) 245 | start_new_word = False 246 | output[-1].append(char) 247 | i += 1 248 | 249 | return ["".join(x) for x in output] 250 | 251 | def _tokenize_chinese_chars(self, text): 252 | """Adds whitespace around any CJK character.""" 253 | output = [] 254 | for char in text: 255 | cp = ord(char) 256 | if self._is_chinese_char(cp): 257 | output.append(" ") 258 | output.append(char) 259 | output.append(" ") 260 | else: 261 | output.append(char) 262 | return "".join(output) 263 | 264 | def _is_chinese_char(self, cp): 265 | """Checks whether CP is the codepoint of a CJK character.""" 266 | # This defines a "chinese character" as anything in the CJK Unicode block: 267 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 268 | # 269 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 270 | # despite its name. The modern Korean Hangul alphabet is a different block, 271 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 272 | # space-separated words, so they are not treated specially and handled 273 | # like the all of the other languages. 274 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 275 | (cp >= 0x3400 and cp <= 0x4DBF) or # 276 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 277 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 278 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 279 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 280 | (cp >= 0xF900 and cp <= 0xFAFF) or # 281 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 282 | return True 283 | 284 | return False 285 | 286 | def _clean_text(self, text): 287 | """Performs invalid character removal and whitespace cleanup on text.""" 288 | output = [] 289 | for char in text: 290 | cp = ord(char) 291 | if cp == 0 or cp == 0xfffd or _is_control(char): 292 | continue 293 | if _is_whitespace(char): 294 | output.append(" ") 295 | else: 296 | output.append(char) 297 | return "".join(output) 298 | 299 | 300 | class WordpieceTokenizer(object): 301 | """Runs WordPiece tokenziation.""" 302 | 303 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 304 | self.vocab = vocab 305 | self.unk_token = unk_token 306 | self.max_input_chars_per_word = max_input_chars_per_word 307 | 308 | def tokenize(self, text): 309 | """Tokenizes a piece of text into its word pieces. 310 | 311 | This uses a greedy longest-match-first algorithm to perform tokenization 312 | using the given vocabulary. 
313 | 314 | For example: 315 | input = "unaffable" 316 | output = ["un", "##aff", "##able"] 317 | 318 | Args: 319 | text: A single token or whitespace separated tokens. This should have 320 | already been passed through `BasicTokenizer. 321 | 322 | Returns: 323 | A list of wordpiece tokens. 324 | """ 325 | 326 | text = convert_to_unicode(text) 327 | 328 | output_tokens = [] 329 | for token in whitespace_tokenize(text): 330 | chars = list(token) 331 | if len(chars) > self.max_input_chars_per_word: 332 | output_tokens.append(self.unk_token) 333 | continue 334 | 335 | is_bad = False 336 | start = 0 337 | sub_tokens = [] 338 | while start < len(chars): 339 | end = len(chars) 340 | cur_substr = None 341 | while start < end: 342 | substr = "".join(chars[start:end]) 343 | if start > 0: 344 | substr = "##" + substr 345 | if substr in self.vocab: 346 | cur_substr = substr 347 | break 348 | end -= 1 349 | if cur_substr is None: 350 | is_bad = True 351 | break 352 | sub_tokens.append(cur_substr) 353 | start = end 354 | 355 | if is_bad: 356 | output_tokens.append(self.unk_token) 357 | else: 358 | output_tokens.extend(sub_tokens) 359 | return output_tokens 360 | 361 | 362 | def _is_whitespace(char): 363 | """Checks whether `chars` is a whitespace character.""" 364 | # \t, \n, and \r are technically contorl characters but we treat them 365 | # as whitespace since they are generally considered as such. 366 | if char == " " or char == "\t" or char == "\n" or char == "\r": 367 | return True 368 | cat = unicodedata.category(char) 369 | if cat == "Zs": 370 | return True 371 | return False 372 | 373 | 374 | def _is_control(char): 375 | """Checks whether `chars` is a control character.""" 376 | # These are technically control characters but we count them as whitespace 377 | # characters. 378 | if char == "\t" or char == "\n" or char == "\r": 379 | return False 380 | cat = unicodedata.category(char) 381 | if cat in ("Cc", "Cf"): 382 | return True 383 | return False 384 | 385 | 386 | def _is_punctuation(char): 387 | """Checks whether `chars` is a punctuation character.""" 388 | cp = ord(char) 389 | # We treat all non-letter/number ASCII as punctuation. 390 | # Characters such as "^", "$", and "`" are not in the Unicode 391 | # Punctuation class but we treat them as punctuation anyways, for 392 | # consistency. 
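    # The ranges below are the printable ASCII characters that are neither letters,
    # digits nor space: 33-47 ("!".."/"), 58-64 (":".."@"), 91-96 ("[".."`")
    # and 123-126 ("{".."~").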
393 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 394 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 395 | return True 396 | cat = unicodedata.category(char) 397 | if cat.startswith("P"): 398 | return True 399 | return False 400 | -------------------------------------------------------------------------------- /bert/transformer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 20.Mar.2019 at 16:30 4 | # 5 | 6 | from __future__ import absolute_import, division, print_function 7 | 8 | import tensorflow as tf 9 | from tensorflow.python import keras 10 | 11 | from params_flow import LayerNormalization 12 | 13 | from bert.attention import AttentionLayer 14 | from bert.layer import Layer 15 | 16 | 17 | class ProjectionLayer(Layer): 18 | class Params(Layer.Params): 19 | hidden_size = None 20 | hidden_dropout = 0.1 21 | initializer_range = 0.02 22 | adapter_size = None # bottleneck size of the adapter - arXiv:1902.00751 23 | adapter_activation = "gelu" 24 | adapter_init_scale = 1e-3 25 | 26 | def _construct(self, **kwargs): 27 | super()._construct(**kwargs) 28 | self.dense = None 29 | self.dropout = None 30 | self.layer_norm = None 31 | 32 | self.adapter_down = None 33 | self.adapter_up = None 34 | 35 | self.supports_masking = True 36 | 37 | # noinspection PyAttributeOutsideInit 38 | def build(self, input_shape): 39 | assert isinstance(input_shape, list) and 2 == len(input_shape) 40 | out_shape, residual_shape = input_shape 41 | self.input_spec = [keras.layers.InputSpec(shape=out_shape), 42 | keras.layers.InputSpec(shape=residual_shape)] 43 | 44 | self.dense = keras.layers.Dense(units=self.params.hidden_size, 45 | kernel_initializer=self.create_initializer(), 46 | name="dense") 47 | self.dropout = keras.layers.Dropout(rate=self.params.hidden_dropout) 48 | self.layer_norm = LayerNormalization(name="LayerNorm") 49 | 50 | if self.params.adapter_size is not None: 51 | self.adapter_down = keras.layers.Dense(units=self.params.adapter_size, 52 | kernel_initializer=tf.keras.initializers.TruncatedNormal( 53 | stddev=self.params.adapter_init_scale), 54 | activation=self.get_activation(self.params.adapter_activation), 55 | name="adapter-down") 56 | self.adapter_up = keras.layers.Dense(units=self.params.hidden_size, 57 | kernel_initializer=tf.keras.initializers.TruncatedNormal( 58 | stddev=self.params.adapter_init_scale), 59 | name="adapter-up") 60 | 61 | super(ProjectionLayer, self).build(input_shape) 62 | 63 | def call(self, inputs, mask=None, training=None, **kwargs): 64 | output, residual = inputs 65 | output = self.dense(output) 66 | output = self.dropout(output, training=training) 67 | 68 | if self.adapter_down is not None: 69 | adapted = self.adapter_down(output) 70 | adapted = self.adapter_up(adapted) 71 | output = tf.add(output, adapted) 72 | 73 | output = self.layer_norm(tf.add(output, residual)) 74 | return output 75 | 76 | 77 | class TransformerSelfAttentionLayer(Layer): 78 | class Params(ProjectionLayer.Params, 79 | AttentionLayer.Params): 80 | hidden_size = None 81 | num_heads = None 82 | hidden_dropout = None 83 | attention_dropout = 0.1 84 | initializer_range = 0.02 85 | 86 | def _construct(self, **kwargs): 87 | super()._construct(**kwargs) 88 | params = self.params 89 | if params.hidden_size % params.num_heads != 0: 90 | raise ValueError("The hidden_size:[{}] is not a multiple of num_heads:[{}]".format(params.hidden_size, 91 | params.num_heads)) 92 | self.size_per_head = 
params.hidden_size // params.num_heads 93 | assert params.size_per_head is None or self.size_per_head == params.size_per_head 94 | 95 | self.attention_layer = None 96 | self.attention_projector = None 97 | 98 | self.supports_masking = True 99 | 100 | def build(self, input_shape): 101 | self.input_spec = keras.layers.InputSpec(shape=input_shape) 102 | 103 | self.attention_layer = AttentionLayer.from_params( 104 | self.params, 105 | size_per_head=self.size_per_head, 106 | name="self", 107 | ) 108 | self.attention_projector = ProjectionLayer.from_params( 109 | self.params, 110 | name="output", 111 | ) 112 | 113 | super(TransformerSelfAttentionLayer, self).build(input_shape) 114 | 115 | def call(self, inputs, mask=None, training=None): 116 | layer_input = inputs 117 | 118 | # 119 | # TODO: is it OK to recompute the 3D attention mask in each attention layer 120 | # 121 | attention_head = self.attention_layer(layer_input, mask=mask, training=training) 122 | attention_output = self.attention_projector([attention_head, layer_input], mask=mask, training=training) 123 | 124 | return attention_output 125 | 126 | 127 | class SingleTransformerEncoderLayer(Layer): 128 | """ 129 | Multi-headed, single layer for the Transformer from 'Attention is All You Need' (arXiv: 1706.03762). 130 | 131 | See also: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py 132 | """ 133 | 134 | class Params(TransformerSelfAttentionLayer.Params, 135 | ProjectionLayer.Params): 136 | intermediate_size = None 137 | intermediate_activation = "gelu" 138 | 139 | def _construct(self, **kwargs): 140 | super()._construct(**kwargs) 141 | params = self.params 142 | if params.hidden_size % params.num_heads != 0: 143 | raise ValueError("The hidden_size:[{}] is not a multiple of num_heads:[{}]".format(params.hidden_size, 144 | params.num_heads)) 145 | self.size_per_head = params.hidden_size // params.num_heads 146 | 147 | self.self_attention_layer = None 148 | self.intermediate_layer = None 149 | self.output_projector = None 150 | 151 | self.supports_masking = True 152 | 153 | def build(self, input_shape): 154 | self.input_spec = keras.layers.InputSpec(shape=input_shape) # [B, seq_len, hidden_size] 155 | 156 | self.self_attention_layer = TransformerSelfAttentionLayer.from_params( 157 | self.params, 158 | name="attention" 159 | ) 160 | self.intermediate_layer = keras.layers.Dense( 161 | name="intermediate", 162 | units=self.params.intermediate_size, 163 | activation=self.get_activation(self.params.intermediate_activation), 164 | kernel_initializer=self.create_initializer() 165 | ) 166 | self.output_projector = ProjectionLayer.from_params( 167 | self.params, 168 | name="output", 169 | ) 170 | 171 | super(SingleTransformerEncoderLayer, self).build(input_shape) 172 | 173 | def call(self, inputs, mask=None, training=None): 174 | layer_input = inputs 175 | 176 | attention_output = self.self_attention_layer(layer_input, mask=mask, training=training) 177 | 178 | # intermediate 179 | intermediate_output = self.intermediate_layer(attention_output) 180 | 181 | # output 182 | layer_output = self.output_projector([intermediate_output, attention_output], mask=mask) 183 | 184 | return layer_output 185 | 186 | 187 | class TransformerEncoderLayer(Layer): 188 | """ 189 | Multi-headed, multi-layer Transformer from 'Attention is All You Need' (arXiv: 1706.03762). 190 | 191 | Implemented for BERT, with support for ALBERT (sharing encoder layer params). 
192 | 193 | See also: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py 194 | """ 195 | 196 | class Params(SingleTransformerEncoderLayer.Params): 197 | num_layers = None 198 | out_layer_ndxs = None # [-1] 199 | 200 | shared_layer = False # False for BERT, True for ALBERT 201 | 202 | def _construct(self, **kwargs): 203 | super()._construct(**kwargs) 204 | self.encoder_layers = [] 205 | self.shared_layer = None # for ALBERT 206 | self.supports_masking = True 207 | 208 | def build(self, input_shape): 209 | self.input_spec = keras.layers.InputSpec(shape=input_shape) 210 | 211 | # create all transformer encoder sub-layers 212 | if self.params.shared_layer: 213 | # ALBERT: share params 214 | self.shared_layer = SingleTransformerEncoderLayer.from_params(self.params, name="layer_shared") 215 | else: 216 | # BERT 217 | for layer_ndx in range(self.params.num_layers): 218 | encoder_layer = SingleTransformerEncoderLayer.from_params( 219 | self.params, 220 | name="layer_{}".format(layer_ndx), 221 | ) 222 | self.encoder_layers.append(encoder_layer) 223 | 224 | super(TransformerEncoderLayer, self).build(input_shape) 225 | 226 | def call(self, inputs, mask=None, training=None): 227 | layer_output = inputs 228 | 229 | layer_outputs = [] 230 | for layer_ndx in range(self.params.num_layers): 231 | encoder_layer = self.encoder_layers[layer_ndx] if self.encoder_layers else self.shared_layer 232 | layer_input = layer_output 233 | 234 | layer_output = encoder_layer(layer_input, mask=mask, training=training) 235 | layer_outputs.append(layer_output) 236 | 237 | if self.params.out_layer_ndxs is None: 238 | # return the final layer only 239 | final_output = layer_output 240 | else: 241 | final_output = [] 242 | for ndx in self.params.out_layer_ndxs: 243 | final_output.append(layer_outputs[ndx]) 244 | final_output = tuple(final_output) 245 | 246 | return final_output 247 | 248 | 249 | -------------------------------------------------------------------------------- /bert/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.14.10" 2 | -------------------------------------------------------------------------------- /check-before-commit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | PEP8_IGNORE=E221,E501,W504,W391,E241 4 | 5 | pycodestyle --ignore=${PEP8_IGNORE} --exclude=tests,.venv -r --show-source tests bert 6 | 7 | coverage run --source=bert $(which nosetests) -v --with-doctest tests/ --exclude-dir tests/nonci/ 8 | coverage report --show-missing --fail-under=60 --omit "bert/tokenization/*_tokenization.py" 9 | 10 | python setup.py sdist bdist_wheel 11 | twine check dist/* -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | coverage 2 | docutils 3 | nose 4 | nose-exclude 5 | Pygments 6 | pycodestyle 7 | twine 8 | sentencepiece 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | 2 | # tensorflow >= 2.0.0 # (should be installed manually) 3 | 4 | py-params >= 0.9.6 5 | params-flow >= 0.8.0 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 
| 3 | # 4 | # created by kpe on 22.10.2018 at 11:46 5 | # 6 | 7 | from setuptools import setup, find_packages, convert_path 8 | 9 | 10 | def _version(): 11 | ns = {} 12 | with open(convert_path("bert/version.py"), "r") as fh: 13 | exec(fh.read(), ns) 14 | return ns['__version__'] 15 | 16 | 17 | __version__ = _version() 18 | 19 | 20 | with open("README.rst", "r", encoding="utf-8") as fh: 21 | long_description = fh.read() 22 | 23 | with open("requirements.txt", "r") as reader: 24 | install_requires = list(map(lambda x: x.strip(), reader.readlines())) 25 | 26 | setup(name="bert-for-tf2", 27 | version=__version__, 28 | url="https://github.com/kpe/bert-for-tf2/", 29 | description="A TensorFlow 2.0 Keras implementation of BERT.", 30 | long_description=long_description, 31 | long_description_content_type="text/x-rst", 32 | keywords="tensorflow keras bert", 33 | license="MIT", 34 | 35 | author="kpe", 36 | author_email="kpe.git@gmailbox.org", 37 | packages=find_packages(exclude=["tests"]), 38 | package_data={"": ["*.txt", "*.rst"]}, 39 | 40 | zip_safe=True, 41 | install_requires=install_requires, 42 | python_requires=">=3.5", 43 | classifiers=[ 44 | "Development Status :: 5 - Production/Stable", 45 | "License :: OSI Approved :: MIT License", 46 | "Programming Language :: Python", 47 | "Programming Language :: Python :: 3", 48 | "Programming Language :: Python :: 3.5", 49 | "Programming Language :: Python :: 3.6", 50 | "Programming Language :: Python :: 3.7", 51 | "Programming Language :: Python :: Implementation :: CPython", 52 | "Programming Language :: Python :: Implementation :: PyPy"]) 53 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 15.Mar.2019 at 12:57 4 | # 5 | from __future__ import division, absolute_import, print_function 6 | -------------------------------------------------------------------------------- /tests/ext/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 28.Mar.2019 at 15:56 4 | # 5 | 6 | from __future__ import absolute_import, division, print_function 7 | -------------------------------------------------------------------------------- /tests/nonci/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 23.May.2019 at 16:05 4 | # 5 | 6 | from __future__ import absolute_import, division, print_function 7 | -------------------------------------------------------------------------------- /tests/nonci/test_attention.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 15.Mar.2019 at 15:30 4 | # 5 | 6 | from __future__ import absolute_import, division, print_function 7 | 8 | 9 | import unittest 10 | 11 | import random 12 | import numpy as np 13 | 14 | 15 | import tensorflow as tf 16 | 17 | from bert.attention import AttentionLayer 18 | 19 | 20 | # tf.enable_v2_behavior() 21 | # tf.enable_eager_execution() 22 | 23 | 24 | class MaskFlatten(tf.keras.layers.Flatten): 25 | 26 | def __init__(self, **kwargs): 27 | self.supports_masking = True 28 | super(MaskFlatten, self).__init__(**kwargs) 29 | 30 | def compute_mask(self, _, mask=None): 31 | return mask 32 | 33 | 34 | class BertAttentionTest(unittest.TestCase): 35 | 36 | @staticmethod 37 | def 
data_generator(batch_size=32, max_len=10): # ([batch_size, 10], [10]) 38 | while True: 39 | data = np.zeros((batch_size, max_len)) 40 | tag = np.zeros(batch_size, dtype='int32') 41 | for i in range(batch_size): 42 | datum_len = random.randint(1, max_len - 1) 43 | total = 0 44 | for j in range(datum_len): 45 | data[i, j] = random.randint(1, 4) 46 | total += data[i, j] 47 | tag[i] = total % 2 48 | yield data, tag 49 | 50 | def test_attention(self): 51 | max_seq_len = random.randint(5, 10) 52 | count = 0 53 | for data, tag in self.data_generator(4, max_seq_len): 54 | count += 1 55 | print(data, tag) 56 | if count > 2: 57 | break 58 | 59 | class AModel(tf.keras.models.Model): 60 | def __init__(self, **kwargs): 61 | super(AModel, self).__init__(**kwargs) 62 | self.embedding = tf.keras.layers.Embedding(input_dim=5, output_dim=3, mask_zero=True) 63 | self.attention = AttentionLayer(num_heads=5, size_per_head=3) 64 | self.timedist = tf.keras.layers.TimeDistributed(MaskFlatten()) 65 | self.bigru = tf.keras.layers.Bidirectional(tf.keras.layers.GRU(units=8)) 66 | self.softmax = tf.keras.layers.Dense(units=2, activation="softmax") 67 | 68 | #def build(self, input_shape): 69 | # super(AModel,self).build(input_shape) 70 | 71 | def call(self, inputs, training=None, mask=None): 72 | out = inputs 73 | out = self.embedding(out) 74 | out = self.attention(out) 75 | out = self.timedist(out) 76 | out = self.bigru(out) 77 | out = self.softmax(out) 78 | return out 79 | 80 | model = tf.keras.models.Sequential([ 81 | tf.keras.layers.Embedding(input_dim=5, output_dim=3, mask_zero=True), 82 | AttentionLayer(num_heads=5, size_per_head=3), 83 | tf.keras.layers.TimeDistributed(MaskFlatten()), 84 | tf.keras.layers.Bidirectional(tf.keras.layers.GRU(units=8)), 85 | tf.keras.layers.Dense(units=2, activation="softmax") 86 | ]) 87 | 88 | #model = AModel() 89 | model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.003), 90 | loss=tf.keras.losses.SparseCategoricalCrossentropy(), 91 | metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]) 92 | # model.build(input_shape=(None, max_seq_len)) 93 | 94 | model.build() 95 | model.summary() 96 | 97 | model.fit_generator( 98 | generator=self.data_generator(64, max_seq_len), 99 | steps_per_epoch=100, 100 | epochs=10, 101 | validation_data=self.data_generator(8, max_seq_len), 102 | validation_steps=10, 103 | #callbacks=[ 104 | # keras.callbacks.EarlyStopping(monitor='val_sparse_categorical_accuracy', patience=5), 105 | #], 106 | ) 107 | -------------------------------------------------------------------------------- /tests/nonci/test_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 26.Mar.2019 at 14:11 4 | # 5 | 6 | from __future__ import absolute_import, division, print_function 7 | 8 | import random 9 | 10 | import unittest 11 | 12 | import numpy as np 13 | import tensorflow as tf 14 | 15 | from tensorflow import keras 16 | 17 | from bert import BertModelLayer 18 | 19 | 20 | tf.compat.v1.enable_eager_execution() 21 | 22 | 23 | class MaskFlatten(keras.layers.Flatten): 24 | 25 | def __init__(self, **kwargs): 26 | self.supports_masking = True 27 | super(MaskFlatten, self).__init__(**kwargs) 28 | 29 | def compute_mask(self, _, mask=None): 30 | return mask 31 | 32 | 33 | def parity_ds_generator(batch_size=32, max_len=10, max_int=4, modulus=2): 34 | """ 35 | Generates a parity calculation dataset (seq -> sum(seq) mod 2), 36 | where seq is a sequence of length less than max_len 37 | of integers in 
[1..max_int). 38 | """ 39 | while True: 40 | data = np.zeros((batch_size, max_len)) 41 | tag = np.zeros(batch_size, dtype='int32') 42 | for i in range(batch_size): 43 | datum_len = random.randint(1, max_len - 1) 44 | total = 0 45 | for j in range(datum_len): 46 | data[i, j] = random.randint(1, max_int) 47 | total += data[i, j] 48 | tag[i] = total % modulus 49 | yield data, tag # ([batch_size, max_len], [max_len]) 50 | 51 | 52 | class RawBertTest(unittest.TestCase): 53 | 54 | def test_simple(self): 55 | max_seq_len = 10 56 | bert = BertModelLayer( 57 | vocab_size=5, 58 | max_position_embeddings=10, 59 | hidden_size=15, 60 | num_layers=2, 61 | num_heads=5, 62 | intermediate_size=4, 63 | use_token_type=False 64 | ) 65 | model = keras.Sequential([ 66 | bert, 67 | keras.layers.Lambda(lambda x: x[:, -0, ...]), # [B, 2] 68 | keras.layers.Dense(units=2, activation="softmax"), # [B, 10, 2] 69 | ]) 70 | 71 | model.build(input_shape=(None, max_seq_len)) 72 | 73 | model.compile(optimizer=keras.optimizers.Adam(lr=0.002), 74 | loss=keras.losses.sparse_categorical_crossentropy, 75 | metrics=[keras.metrics.sparse_categorical_accuracy] 76 | ) 77 | 78 | model.summary(line_length=120) 79 | 80 | for ndx, var in enumerate(model.trainable_variables): 81 | print("{:5d}".format(ndx), var.name, var.shape, var.dtype) 82 | 83 | model.fit_generator(generator=parity_ds_generator(64, max_seq_len), 84 | steps_per_epoch=100, 85 | epochs=10, 86 | validation_data=parity_ds_generator(32, max_seq_len), # TODO: can't change max_seq_len (but transformer alone can) 87 | validation_steps=10, 88 | callbacks=[ 89 | keras.callbacks.EarlyStopping(monitor='val_sparse_categorical_accuracy', patience=5), 90 | ], 91 | ) 92 | 93 | 94 | 95 | -------------------------------------------------------------------------------- /tests/nonci/test_compare_pretrained.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 27.Mar.2019 at 15:37 4 | # 5 | 6 | from __future__ import absolute_import, division, print_function 7 | 8 | import unittest 9 | import os 10 | 11 | import numpy as np 12 | 13 | import tensorflow as tf 14 | from tensorflow.python import keras 15 | 16 | 17 | import bert 18 | from bert.tokenization.bert_tokenization import FullTokenizer 19 | 20 | tf.compat.v1.disable_eager_execution() 21 | 22 | 23 | class TestCompareBertsOnPretrainedWeight(unittest.TestCase): 24 | def setUp(self) -> None: 25 | self.bert_name = "uncased_L-12_H-768_A-12" 26 | self.bert_ckpt_dir = bert.fetch_google_bert_model(self.bert_name, fetch_dir=".models") 27 | self.bert_ckpt_file = os.path.join(self.bert_ckpt_dir, "bert_model.ckpt") 28 | self.bert_config_file = os.path.join(self.bert_ckpt_dir, "bert_config.json") 29 | 30 | def test_bert_original_weights(self): 31 | print("bert checkpoint: ", self.bert_ckpt_file) 32 | bert_vars = tf.train.list_variables(self.bert_ckpt_file) 33 | for ndx, var in enumerate(bert_vars): 34 | print("{:3d}".format(ndx), var) 35 | 36 | def create_bert_model(self, max_seq_len=18): 37 | bert_params = bert.loader.params_from_pretrained_ckpt(self.bert_ckpt_dir) 38 | l_bert = bert.BertModelLayer.from_params(bert_params, name="bert") 39 | 40 | input_ids = keras.layers.Input(shape=(max_seq_len,), dtype='int32', name="input_ids") 41 | token_type_ids = keras.layers.Input(shape=(max_seq_len,), dtype='int32', name="token_type_ids") 42 | output = l_bert([input_ids, token_type_ids]) 43 | 44 | model = keras.Model(inputs=[input_ids, token_type_ids], outputs=output) 45 | 46 | 
return model, l_bert, (input_ids, token_type_ids) 47 | 48 | def test_keras_weights(self): 49 | max_seq_len = 18 50 | model, l_bert, inputs = self.create_bert_model(18) 51 | 52 | model.build(input_shape=[(None, max_seq_len), 53 | (None, max_seq_len)]) 54 | 55 | model.summary() 56 | 57 | for ndx, var in enumerate(l_bert.trainable_variables): 58 | print("{:3d}".format(ndx), var.name, var.shape) 59 | 60 | #for ndx, var in enumerate(model.trainable_variables): 61 | # print("{:3d}".format(ndx), var.name, var.shape) 62 | 63 | def test___compare_weights(self): 64 | 65 | tf.compat.v1.reset_default_graph() 66 | 67 | max_seq_len = 18 68 | model, l_bert, inputs = self.create_bert_model(18) 69 | model.build(input_shape=[(None, max_seq_len), 70 | (None, max_seq_len)]) 71 | 72 | stock_vars = tf.train.list_variables(self.bert_ckpt_file) 73 | stock_vars = {name: list(shape) for name, shape in stock_vars} 74 | 75 | keras_vars = model.trainable_variables 76 | keras_vars = {var.name.split(":")[0]: var.shape.as_list() for var in keras_vars} 77 | 78 | matched_vars = set() 79 | unmatched_vars = set() 80 | shape_errors = set() 81 | 82 | for name in stock_vars: 83 | bert_name = name 84 | keras_name = bert.loader.map_from_stock_variale_name(bert_name) 85 | if keras_name in keras_vars: 86 | if keras_vars[keras_name] == stock_vars[bert_name]: 87 | matched_vars.add(bert_name) 88 | else: 89 | shape_errors.add(bert_name) 90 | else: 91 | unmatched_vars.add(bert_name) 92 | 93 | print("bert -> keras:") 94 | print(" matched count:", len(matched_vars)) 95 | print(" unmatched count:", len(unmatched_vars)) 96 | print(" shape error count:", len(shape_errors)) 97 | 98 | print("unmatched:\n", "\n ".join(unmatched_vars)) 99 | 100 | self.assertEqual(197, len(matched_vars)) 101 | self.assertEqual(9, len(unmatched_vars)) 102 | self.assertEqual(0, len(shape_errors)) 103 | 104 | matched_vars = set() 105 | unmatched_vars = set() 106 | shape_errors = set() 107 | 108 | for name in keras_vars: 109 | keras_name = name 110 | bert_name = bert.loader.map_to_stock_variable_name(keras_name) 111 | if bert_name in stock_vars: 112 | if stock_vars[bert_name] == keras_vars[keras_name]: 113 | matched_vars.add(keras_name) 114 | else: 115 | shape_errors.add(keras_name) 116 | else: 117 | unmatched_vars.add(keras_name) 118 | 119 | print("keras -> bert:") 120 | print(" matched count:", len(matched_vars)) 121 | print(" unmatched count:", len(unmatched_vars)) 122 | print(" shape error count:", len(shape_errors)) 123 | 124 | print("unmatched:\n", "\n ".join(unmatched_vars)) 125 | self.assertEqual(197, len(matched_vars)) 126 | self.assertEqual(0, len(unmatched_vars)) 127 | self.assertEqual(0, len(shape_errors)) 128 | 129 | 130 | 131 | def predict_on_keras_model(self, input_ids, input_mask, token_type_ids): 132 | max_seq_len = input_ids.shape[-1] 133 | model, l_bert, k_inputs = self.create_bert_model(max_seq_len) 134 | model.build(input_shape=[(None, max_seq_len), 135 | (None, max_seq_len)]) 136 | bert.load_stock_weights(l_bert, self.bert_ckpt_file) 137 | k_res = model.predict([input_ids, token_type_ids]) 138 | return k_res 139 | 140 | def predict_on_stock_model(self, input_ids, input_mask, token_type_ids): 141 | from tests.ext.modeling import BertModel, BertConfig, get_assignment_map_from_checkpoint 142 | 143 | tf.compat.v1.reset_default_graph() 144 | 145 | tf_placeholder = tf.compat.v1.placeholder 146 | 147 | max_seq_len = input_ids.shape[-1] 148 | pl_input_ids = tf.compat.v1.placeholder(tf.int32, shape=(1, max_seq_len)) 149 | pl_mask = 
tf.compat.v1.placeholder(tf.int32, shape=(1, max_seq_len)) 150 | pl_token_type_ids = tf.compat.v1.placeholder(tf.int32, shape=(1, max_seq_len)) 151 | 152 | bert_config = BertConfig.from_json_file(self.bert_config_file) 153 | tokenizer = FullTokenizer(vocab_file=os.path.join(self.bert_ckpt_dir, "vocab.txt")) 154 | 155 | s_model = BertModel(config=bert_config, 156 | is_training=False, 157 | input_ids=pl_input_ids, 158 | input_mask=pl_mask, 159 | token_type_ids=pl_token_type_ids, 160 | use_one_hot_embeddings=False) 161 | 162 | tvars = tf.compat.v1.trainable_variables() 163 | (assignment_map, initialized_var_names) = get_assignment_map_from_checkpoint(tvars, self.bert_ckpt_file) 164 | tf.compat.v1.train.init_from_checkpoint(self.bert_ckpt_file, assignment_map) 165 | 166 | with tf.compat.v1.Session() as sess: 167 | sess.run(tf.compat.v1.global_variables_initializer()) 168 | 169 | s_res = sess.run( 170 | s_model.get_sequence_output(), 171 | feed_dict={pl_input_ids: input_ids, 172 | pl_token_type_ids: token_type_ids, 173 | pl_mask: input_mask, 174 | }) 175 | return s_res 176 | 177 | def test_direct_keras_to_stock_compare(self): 178 | from tests.ext.modeling import BertModel, BertConfig, get_assignment_map_from_checkpoint 179 | 180 | bert_config = BertConfig.from_json_file(self.bert_config_file) 181 | tokenizer = FullTokenizer(vocab_file=os.path.join(self.bert_ckpt_dir, "vocab.txt")) 182 | 183 | # prepare input 184 | max_seq_len = 6 185 | input_str = "Hello, Bert!" 186 | input_tokens = tokenizer.tokenize(input_str) 187 | input_tokens = ["[CLS]"] + input_tokens + ["[SEP]"] 188 | input_ids = tokenizer.convert_tokens_to_ids(input_tokens) 189 | input_ids = input_ids + [0]*(max_seq_len - len(input_tokens)) 190 | input_mask = [1]*len(input_tokens) + [0]*(max_seq_len - len(input_tokens)) 191 | token_type_ids = [0]*len(input_tokens) + [0]*(max_seq_len - len(input_tokens)) 192 | 193 | input_ids = np.array([input_ids], dtype=np.int32) 194 | input_mask = np.array([input_mask], dtype=np.int32) 195 | token_type_ids = np.array([token_type_ids], dtype=np.int32) 196 | 197 | print(" tokens:", input_tokens) 198 | print("input_ids:{}/{}:{}".format(len(input_tokens), max_seq_len, input_ids), input_ids.shape, token_type_ids) 199 | 200 | s_res = self.predict_on_stock_model(input_ids, input_mask, token_type_ids) 201 | k_res = self.predict_on_keras_model(input_ids, input_mask, token_type_ids) 202 | 203 | np.set_printoptions(precision=9, threshold=20, linewidth=200, sign="+", floatmode="fixed") 204 | print("s_res", s_res.shape) 205 | print("k_res", k_res.shape) 206 | 207 | print("s_res:\n {}".format(s_res[0, :2, :10]), s_res.dtype) 208 | print("k_res:\n {}".format(k_res[0, :2, :10]), k_res.dtype) 209 | 210 | adiff = np.abs(s_res-k_res).flatten() 211 | print("diff:", np.max(adiff), np.argmax(adiff)) 212 | self.assertTrue(np.allclose(s_res, k_res, atol=1e-6)) 213 | 214 | 215 | -------------------------------------------------------------------------------- /tests/nonci/test_load_pretrained_weights.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 10.Oct.2019 at 16:26 4 | # 5 | 6 | from __future__ import absolute_import, division, print_function 7 | 8 | import unittest 9 | 10 | import os 11 | 12 | import tensorflow as tf 13 | from tensorflow import keras 14 | 15 | import bert 16 | 17 | 18 | class TestLoadPreTrainedWeights(unittest.TestCase): 19 | 20 | def build_model(self, bert_params): 21 | l_bert = 
bert.BertModelLayer.from_params(bert_params, name="bert") 22 | 23 | l_input_ids = keras.layers.Input(shape=(128,), dtype='int32', name="input_ids") 24 | l_token_type_ids = keras.layers.Input(shape=(128,), dtype='int32', name="token_type_ids") 25 | output = l_bert([l_input_ids, l_token_type_ids]) 26 | output = keras.layers.Lambda(lambda x: x[:, 0, :])(output) 27 | output = keras.layers.Dense(2)(output) 28 | model = keras.Model(inputs=[l_input_ids, l_token_type_ids], outputs=output) 29 | 30 | model.build(input_shape=(None, 128)) 31 | model.compile(optimizer=keras.optimizers.Adam(), 32 | loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), 33 | metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")]) 34 | 35 | for weight in l_bert.weights: 36 | print(weight.name) 37 | 38 | return model, l_bert 39 | 40 | def test_bert_google_weights(self): 41 | bert_model_name = "uncased_L-12_H-768_A-12" 42 | bert_dir = bert.fetch_google_bert_model(bert_model_name, ".models") 43 | bert_ckpt = os.path.join(bert_dir, "bert_model.ckpt") 44 | 45 | bert_params = bert.params_from_pretrained_ckpt(bert_dir) 46 | model, l_bert = self.build_model(bert_params) 47 | 48 | skipped_weight_value_tuples = bert.load_bert_weights(l_bert, bert_ckpt) 49 | self.assertEqual(0, len(skipped_weight_value_tuples)) 50 | model.summary() 51 | 52 | def test_albert_chinese_weights(self): 53 | albert_model_name = "albert_base" 54 | albert_dir = bert.fetch_brightmart_albert_model(albert_model_name, ".models") 55 | albert_ckpt = os.path.join(albert_dir, "albert_model.ckpt") 56 | 57 | albert_params = bert.params_from_pretrained_ckpt(albert_dir) 58 | model, l_bert = self.build_model(albert_params) 59 | 60 | skipped_weight_value_tuples = bert.load_albert_weights(l_bert, albert_ckpt) 61 | self.assertEqual(0, len(skipped_weight_value_tuples)) 62 | model.summary() 63 | 64 | def test_albert_google_weights(self): 65 | albert_model_name = "albert_base" 66 | albert_dir = bert.fetch_tfhub_albert_model(albert_model_name, ".models") 67 | 68 | albert_params = bert.albert_params(albert_model_name) 69 | model, l_bert = self.build_model(albert_params) 70 | 71 | skipped_weight_value_tuples = bert.load_albert_weights(l_bert, albert_dir) 72 | self.assertEqual(0, len(skipped_weight_value_tuples)) 73 | model.summary() 74 | 75 | def test_albert_google_weights_non_tfhub(self): 76 | albert_model_name = "albert_base_v2" 77 | albert_dir = bert.fetch_google_albert_model(albert_model_name, ".models") 78 | model_ckpt = os.path.join(albert_dir, "model.ckpt-best") 79 | 80 | albert_params = bert.albert_params(albert_dir) 81 | model, l_bert = self.build_model(albert_params) 82 | 83 | skipped_weight_value_tuples = bert.load_albert_weights(l_bert, model_ckpt) 84 | self.assertEqual(0, len(skipped_weight_value_tuples)) 85 | model.summary() 86 | -------------------------------------------------------------------------------- /tests/nonci/test_multi_lang.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 05.06.2019 at 9:01 PM 4 | # 5 | 6 | from __future__ import division, absolute_import, print_function 7 | 8 | import os 9 | 10 | import unittest 11 | 12 | import tensorflow as tf 13 | import bert 14 | 15 | 16 | class TestMultiLang(unittest.TestCase): 17 | def setUp(self) -> None: 18 | self.bert_name = "multilingual_L-12_H-768_A-12" 19 | self.bert_ckpt_dir = bert.fetch_google_bert_model(self.bert_name, fetch_dir=".models") 20 | self.bert_ckpt_file = os.path.join(self.bert_ckpt_dir, 
"bert_model.ckpt") 21 | self.bert_config_file = os.path.join(self.bert_ckpt_dir, "bert_config.json") 22 | 23 | def test_multi(self): 24 | print(self.bert_ckpt_dir) 25 | bert_params = bert.loader.params_from_pretrained_ckpt(self.bert_ckpt_dir) 26 | bert_params.adapter_size = 32 27 | l_bert = bert.BertModelLayer.from_params(bert_params, name="bert") 28 | 29 | max_seq_len=128 30 | l_input_ids = tf.keras.layers.Input(shape=(max_seq_len,), dtype='int32', name="input_ids") 31 | l_token_type_ids = tf.keras.layers.Input(shape=(max_seq_len,), dtype='int32', name="token_type_ids") 32 | output = l_bert([l_input_ids, l_token_type_ids]) 33 | 34 | model = tf.keras.Model(inputs=[l_input_ids, l_token_type_ids], outputs=output) 35 | model.build(input_shape=[(None, max_seq_len), 36 | (None, max_seq_len)]) 37 | 38 | bert.load_stock_weights(l_bert, self.bert_ckpt_file) 39 | 40 | model.summary() 41 | 42 | -------------------------------------------------------------------------------- /tests/nonci/test_stock_weights.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 25.Jul.2019 at 12:23 4 | # 5 | 6 | from __future__ import absolute_import, division, print_function 7 | 8 | 9 | import unittest 10 | import math 11 | import os 12 | 13 | import tensorflow as tf 14 | from tensorflow import keras 15 | 16 | import bert 17 | 18 | #tf.enable_eager_execution() 19 | #tf.disable_eager_execution() 20 | 21 | 22 | def flatten_layers(root_layer): 23 | if isinstance(root_layer, keras.layers.Layer): 24 | yield root_layer 25 | for layer in root_layer._layers: 26 | for sub_layer in flatten_layers(layer): 27 | yield sub_layer 28 | 29 | 30 | def freeze_bert_layers(l_bert): 31 | """ 32 | Freezes all but LayerNorm and adapter layers - see arXiv:1902.00751. 
33 | """ 34 | for layer in flatten_layers(l_bert): 35 | if layer.name in ["LayerNorm", "adapter-down", "adapter-up"]: 36 | layer.trainable = True 37 | elif len(layer._layers) == 0: 38 | layer.trainable = False 39 | l_bert.embeddings_layer.trainable = False 40 | 41 | 42 | def create_learning_rate_scheduler(max_learn_rate=5e-5, 43 | end_learn_rate=1e-7, 44 | warmup_epoch_count=10, 45 | total_epoch_count=90): 46 | 47 | def lr_scheduler(epoch): 48 | if epoch < warmup_epoch_count: 49 | res = (max_learn_rate/warmup_epoch_count) * (epoch + 1) 50 | else: 51 | res = max_learn_rate*math.exp(math.log(end_learn_rate/max_learn_rate)*(epoch-warmup_epoch_count+1)/(total_epoch_count-warmup_epoch_count+1)) 52 | return float(res) 53 | learning_rate_scheduler = tf.keras.callbacks.LearningRateScheduler(lr_scheduler, verbose=1) 54 | 55 | return learning_rate_scheduler 56 | 57 | 58 | class TestWeightsLoading(unittest.TestCase): 59 | #bert_ckpt_dir = ".models/uncased_L-12_H-768_A-12/" 60 | #bert_ckpt_file = bert_ckpt_dir + "bert_model.ckpt" 61 | #bert_config_file = bert_ckpt_dir + "bert_config.json" 62 | 63 | def setUp(self) -> None: 64 | self.bert_name = "uncased_L-12_H-768_A-12" 65 | self.bert_ckpt_dir = bert.fetch_google_bert_model(self.bert_name, fetch_dir=".models") 66 | self.bert_ckpt_file = os.path.join(self.bert_ckpt_dir, "bert_model.ckpt") 67 | self.bert_config_file = os.path.join(self.bert_ckpt_dir, "bert_config.json") 68 | 69 | def test_load_pretrained(self): 70 | print("Eager Execution:", tf.executing_eagerly()) 71 | 72 | bert_params = bert.loader.params_from_pretrained_ckpt(self.bert_ckpt_dir) 73 | bert_params.adapter_size = 32 74 | l_bert = bert.BertModelLayer.from_params(bert_params, name="bert") 75 | 76 | model = keras.models.Sequential([ 77 | keras.layers.InputLayer(input_shape=(128,)), 78 | l_bert, 79 | keras.layers.Lambda(lambda x: x[:, 0, :]), 80 | keras.layers.Dense(2) 81 | ]) 82 | 83 | # we need to freeze before build/compile - otherwise keras counts the params twice 84 | if bert_params.adapter_size is not None: 85 | freeze_bert_layers(l_bert) 86 | 87 | model.build(input_shape=(None, 128)) 88 | model.compile(optimizer=keras.optimizers.Adam(), 89 | loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), 90 | metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")]) 91 | 92 | bert.load_stock_weights(l_bert, self.bert_ckpt_file) 93 | 94 | model.summary() 95 | 96 | 97 | -------------------------------------------------------------------------------- /tests/nonci/test_transformer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 21.Mar.2019 at 13:30 4 | # 5 | 6 | from __future__ import absolute_import, division, print_function 7 | 8 | import random 9 | 10 | import unittest 11 | 12 | import numpy as np 13 | import tensorflow as tf 14 | 15 | from tensorflow import keras 16 | 17 | 18 | from bert.transformer import TransformerEncoderLayer 19 | 20 | 21 | class MaskFlatten(keras.layers.Flatten): 22 | 23 | def __init__(self, **kwargs): 24 | self.supports_masking = True 25 | super(MaskFlatten, self).__init__(**kwargs) 26 | 27 | def compute_mask(self, _, mask=None): 28 | return mask 29 | 30 | 31 | def parity_ds_generator(batch_size=32, max_len=10, max_int=4, modulus=2): 32 | """ 33 | Generates a parity calculation dataset (seq -> sum(seq) mod 2), 34 | where seq is a sequence of length less than max_len 35 | of integers in [1..max_int). 
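(Note: random.randint is inclusive, so the drawn values actually range over 1..max_int. For illustration only - with batch_size=2, max_len=4, max_int=4 a yielded pair could look like data=[[3., 1., 0., 0.], [2., 2., 1., 0.]] and tag=[0, 1], i.e. shapes [batch_size, max_len] and [batch_size].)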
36 | """ 37 | while True: 38 | data = np.zeros((batch_size, max_len)) 39 | tag = np.zeros(batch_size, dtype='int32') 40 | for i in range(batch_size): 41 | datum_len = random.randint(1, max_len - 1) 42 | total = 0 43 | for j in range(datum_len): 44 | data[i, j] = random.randint(1, max_int) 45 | total += data[i, j] 46 | tag[i] = total % modulus 47 | yield data, tag # ([batch_size, max_len], [batch_size]) 48 | 49 | 50 | class TransformerTest(unittest.TestCase): 51 | 52 | def test_simple(self): 53 | max_seq_len = 10 54 | model = keras.Sequential([ 55 | keras.layers.Embedding(input_dim=5, output_dim=15, mask_zero=True), # [B, 10, 15] 56 | TransformerEncoderLayer( 57 | hidden_size=15, 58 | num_heads=5, 59 | num_layers=2, 60 | intermediate_size=8, 61 | hidden_dropout=0.1), # [B, 10, 15] 62 | keras.layers.TimeDistributed( 63 | keras.layers.Dense(units=2, activation="softmax")), # [B, 10, 2] 64 | keras.layers.Lambda(lambda x: x[:, 0, ...]) # [B, 2] 65 | ]) 66 | 67 | model.build(input_shape=(None, max_seq_len)) 68 | 69 | model.compile(optimizer=keras.optimizers.Adam(lr=0.003), 70 | loss=keras.losses.sparse_categorical_crossentropy, 71 | metrics=[keras.metrics.sparse_categorical_accuracy] 72 | ) 73 | model.summary(line_length=120) 74 | 75 | model.fit_generator(generator=parity_ds_generator(64, max_seq_len), 76 | steps_per_epoch=100, 77 | epochs=20, 78 | validation_data=parity_ds_generator(12, max_seq_len - 4), 79 | validation_steps=10, 80 | callbacks=[ 81 | keras.callbacks.EarlyStopping(monitor='val_sparse_categorical_accuracy', patience=5), 82 | ], 83 | ) 84 | 85 | 86 | 87 | 88 | -------------------------------------------------------------------------------- /tests/test_adapter_finetune.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 02.Sep.2019 at 23:57 4 | # 5 | 6 | from __future__ import absolute_import, division, print_function 7 | 8 | 9 | import os 10 | 11 | import numpy as np 12 | import tensorflow as tf 13 | from tensorflow import keras 14 | 15 | import params_flow as pf 16 | 17 | import bert 18 | 19 | from .test_common import MiniBertFactory, AbstractBertTest 20 | 21 | 22 | class TestAdapterFineTuning(AbstractBertTest): 23 | """ 24 | Demonstrates a fine-tuning workflow using adapter-BERT, with 25 | the fine-tuned and the frozen pre-trained weights stored in 26 | separate checkpoint files.
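Roughly (see test_finetuning_workflow below): build the model with bert_params.adapter_size set, call apply_adapter_freeze(), load the stock BERT weights, fit on a tiny dataset, save only model.trainable_weights via tf.train.Checkpoint, then rebuild the model, reload the stock weights, restore the fine-tuned checkpoint and check that the restored predictions match the ones before saving.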
27 | """ 28 | 29 | def setUp(self) -> None: 30 | tf.compat.v1.reset_default_graph() 31 | tf.compat.v1.enable_eager_execution() 32 | print("Eager Execution:", tf.executing_eagerly()) 33 | 34 | # build a dummy bert 35 | self.ckpt_path = MiniBertFactory.create_mini_bert_weights() 36 | self.ckpt_dir = os.path.dirname(self.ckpt_path) 37 | self.tokenizer = bert.bert_tokenization.FullTokenizer(vocab_file=os.path.join(self.ckpt_dir, "vocab.txt"), do_lower_case=True) 38 | 39 | def test_coverage_improve(self): 40 | bert_params = bert.params_from_pretrained_ckpt(self.ckpt_dir) 41 | model, l_bert = self.build_model(bert_params, 1) 42 | for weight in model.weights: 43 | l_bert_prefix = bert.loader.bert_prefix(l_bert) 44 | 45 | stock_name = bert.loader.map_to_stock_variable_name(weight.name, l_bert_prefix) 46 | 47 | if stock_name is None: 48 | print("No BERT stock weight for", weight.name) 49 | continue 50 | 51 | keras_name = bert.loader.map_from_stock_variale_name(stock_name, l_bert_prefix) 52 | self.assertEqual(weight.name.split(":")[0], keras_name) 53 | 54 | @staticmethod 55 | def build_model(bert_params, max_seq_len): 56 | # enable adapter-BERT 57 | bert_params.adapter_size = 2 58 | l_bert = bert.BertModelLayer.from_params(bert_params) 59 | model = keras.models.Sequential([ 60 | l_bert, 61 | keras.layers.Lambda(lambda seq: seq[:, 0, :]), 62 | keras.layers.Dense(3, name="test_cls") 63 | ]) 64 | model.compile(optimizer=keras.optimizers.Adam(), 65 | loss=keras.losses.SparseCategoricalCrossentropy(), 66 | metrics=[keras.metrics.SparseCategoricalAccuracy()]) 67 | 68 | # build for a given max_seq_len 69 | model.build(input_shape=(None, max_seq_len)) 70 | return model, l_bert 71 | 72 | def test_regularization(self): 73 | # create a BERT layer with config from the checkpoint 74 | bert_params = bert.params_from_pretrained_ckpt(self.ckpt_dir) 75 | 76 | max_seq_len = 12 77 | 78 | model, l_bert = self.build_model(bert_params, max_seq_len=max_seq_len) 79 | l_bert.apply_adapter_freeze() 80 | model.summary() 81 | 82 | kernel_regularizer = keras.regularizers.l2(0.01) 83 | bias_regularizer = keras.regularizers.l2(0.01) 84 | 85 | pf.utils.add_dense_layer_loss(model, 86 | kernel_regularizer=kernel_regularizer, 87 | bias_regularizer=bias_regularizer) 88 | # prepare the data 89 | inputs, targets = ["hello world", "goodbye"], [1, 2] 90 | tokens = [self.tokenizer.tokenize(toks) for toks in inputs] 91 | tokens = [self.tokenizer.convert_tokens_to_ids(toks) for toks in tokens] 92 | tokens = [toks + [0]*(max_seq_len - len(toks)) for toks in tokens] 93 | x = np.array(tokens) 94 | y = np.array(targets) 95 | # fine tune 96 | model.fit(x, y, epochs=3) 97 | 98 | def test_finetuning_workflow(self): 99 | # create a BERT layer with config from the checkpoint 100 | bert_params = bert.params_from_pretrained_ckpt(self.ckpt_dir) 101 | 102 | max_seq_len = 12 103 | 104 | model, l_bert = self.build_model(bert_params, max_seq_len=max_seq_len) 105 | model.summary() 106 | 107 | # freeze non-adapter weights 108 | l_bert.apply_adapter_freeze() 109 | model.summary() 110 | 111 | # load the BERT weights from the pre-trained model 112 | bert.load_stock_weights(l_bert, self.ckpt_path) 113 | 114 | # prepare the data 115 | inputs, targets = ["hello world", "goodbye"], [1, 2] 116 | tokens = [self.tokenizer.tokenize(toks) for toks in inputs] 117 | tokens = [self.tokenizer.convert_tokens_to_ids(toks) for toks in tokens] 118 | tokens = [toks + [0]*(max_seq_len - len(toks)) for toks in tokens] 119 | x = np.array(tokens) 120 | y = np.array(targets) 121 | 122 
| # fine tune 123 | model.fit(x, y, epochs=3) 124 | 125 | # preserve the logits for comparison before and after restoring the fine-tuned model 126 | logits = model.predict(x) 127 | 128 | # now store the adapter weights only 129 | 130 | # old-fashioned way - using a tf.compat.v1 Saver 131 | # finetuned_weights = {w.name: w.value() for w in model.trainable_weights} 132 | # saver = tf.compat.v1.train.Saver(finetuned_weights) 133 | # fine_path = saver.save(tf.compat.v1.keras.backend.get_session(), fine_ckpt) 134 | 135 | fine_ckpt = os.path.join(self.ckpt_dir, "fine-tuned.ckpt") 136 | finetuned_weights = {w.name: w for w in model.trainable_weights} 137 | checkpoint = tf.train.Checkpoint(**finetuned_weights) 138 | fine_path = checkpoint.save(file_prefix=fine_ckpt) 139 | print("fine tuned ckpt:", fine_path) 140 | 141 | # build new model 142 | tf.compat.v1.keras.backend.clear_session() 143 | model, l_bert = self.build_model(bert_params, max_seq_len=max_seq_len) 144 | l_bert.apply_adapter_freeze() 145 | 146 | # load the BERT weights from the pre-trained checkpoint 147 | bert.load_stock_weights(l_bert, self.ckpt_path) 148 | 149 | # load the fine tuned classifier model weights 150 | finetuned_weights = {w.name: w for w in model.trainable_weights} 151 | checkpoint = tf.train.Checkpoint(**finetuned_weights) 152 | load_status = checkpoint.restore(fine_path) 153 | load_status.assert_consumed().run_restore_ops() 154 | 155 | logits_restored = model.predict(x) 156 | 157 | # check the predictions of the restored model 158 | self.assertTrue(np.allclose(logits_restored, logits, 1e-6)) 159 | -------------------------------------------------------------------------------- /tests/test_adapter_freeze.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 09.08.2019 at 10:38 PM 4 | # 5 | 6 | from __future__ import division, absolute_import, print_function 7 | 8 | 9 | import unittest 10 | 11 | import os 12 | import tempfile 13 | 14 | import numpy as np 15 | import tensorflow as tf 16 | from tensorflow import keras 17 | 18 | import bert 19 | 20 | from .test_common import AbstractBertTest, MiniBertFactory 21 | 22 | 23 | class AdapterFreezeTest(AbstractBertTest): 24 | 25 | def setUp(self) -> None: 26 | tf.compat.v1.reset_default_graph() 27 | tf.compat.v1.enable_eager_execution() 28 | print("Eager Execution:", tf.executing_eagerly()) 29 | 30 | def test_adapter_freezing(self): 31 | bert_params = bert.BertModelLayer.Params(hidden_size=32, 32 | vocab_size=67, 33 | max_position_embeddings=64, 34 | num_layers=1, 35 | num_heads=1, 36 | intermediate_size=4, 37 | use_token_type=False) 38 | 39 | def to_model(bert_params): 40 | l_bert = bert.BertModelLayer.from_params(bert_params) 41 | 42 | token_ids = keras.layers.Input(shape=(21,)) 43 | seq_out = l_bert(token_ids) 44 | model = keras.Model(inputs=[token_ids], outputs=seq_out) 45 | 46 | model.build(input_shape=(None, 21)) 47 | l_bert.apply_adapter_freeze() 48 | 49 | return model 50 | 51 | model = to_model(bert_params) 52 | model.summary() 53 | print("trainable weights:", len(model.trainable_weights)) 54 | self.assertEqual(20, len(model.trainable_weights)) 55 | for weight in model.trainable_weights: 56 | print(weight.name, weight.shape) 57 | 58 | bert_params.adapter_size = 16 59 | 60 | model = to_model(bert_params) 61 | model.summary() 62 | print("trainable weights:", len(model.trainable_weights)) 63 | self.assertEqual(14, len(model.trainable_weights)) 64 | for weight in model.trainable_weights: 65 | print(weight.name,
weight.shape) 66 | 67 | def test_bert_freeze(self): 68 | model_dir = tempfile.TemporaryDirectory().name 69 | os.makedirs(model_dir) 70 | save_path = MiniBertFactory.create_mini_bert_weights(model_dir) 71 | tokenizer = bert.bert_tokenization.FullTokenizer(vocab_file=os.path.join(model_dir, "vocab.txt"), do_lower_case=True) 72 | 73 | # prepare input 74 | max_seq_len = 24 75 | input_str_batch = ["hello, bert!", "how are you doing!"] 76 | 77 | input_ids, token_type_ids = self.prepare_input_batch(input_str_batch, tokenizer, max_seq_len) 78 | 79 | bert_ckpt_file = os.path.join(model_dir, "bert_model.ckpt") 80 | 81 | bert_params = bert.params_from_pretrained_ckpt(model_dir) 82 | bert_params.adapter_size = 4 83 | l_bert = bert.BertModelLayer.from_params(bert_params) 84 | 85 | model = keras.models.Sequential([ 86 | l_bert, 87 | ]) 88 | 89 | model.build(input_shape=(None, max_seq_len)) 90 | 91 | model.summary() 92 | l_bert.apply_adapter_freeze() 93 | model.summary() 94 | 95 | bert.load_stock_weights(l_bert, bert_ckpt_file) 96 | #l_bert.embeddings_layer.trainable = False 97 | 98 | model.summary() 99 | 100 | orig_weight_values = [] 101 | for weight in l_bert.weights: 102 | orig_weight_values.append(weight.numpy()) 103 | 104 | model.compile(optimizer=keras.optimizers.Adam(), 105 | loss=keras.losses.mean_squared_error, 106 | run_eagerly=True) 107 | 108 | trainable_count = len(l_bert.trainable_weights) 109 | 110 | orig_pred = model.predict(input_ids) 111 | model.fit(x=input_ids, y=np.zeros_like(orig_pred), 112 | batch_size=2, 113 | epochs=4) 114 | 115 | trained_count = 0 116 | for ndx, weight in enumerate(l_bert.weights): 117 | weight_equal = np.array_equal(weight.numpy(), orig_weight_values[ndx]) 118 | print("{}: {}".format(weight_equal, weight.name)) 119 | if not weight_equal: 120 | trained_count += 1 121 | 122 | print(" trained weights:", trained_count) 123 | print("trainable weights:", trainable_count) 124 | self.assertEqual(trained_count, trainable_count) 125 | 126 | model.summary() 127 | 128 | def test_adapter_albert_freeze(self): 129 | model_dir = tempfile.TemporaryDirectory().name 130 | os.makedirs(model_dir) 131 | # for tokenizer only 132 | save_path = MiniBertFactory.create_mini_bert_weights(model_dir) 133 | tokenizer = bert.bert_tokenization.FullTokenizer(vocab_file=os.path.join(model_dir, "vocab.txt"), do_lower_case=True) 134 | 135 | # prepare input 136 | max_seq_len = 28 137 | input_str_batch = ["hello, albert!", "how are you doing!"] 138 | input_ids, token_type_ids = self.prepare_input_batch(input_str_batch, tokenizer, max_seq_len, 139 | extra_token_count=3) 140 | 141 | bert_params = bert.BertModelLayer.Params( 142 | attention_dropout=0.1, 143 | intermediate_activation="gelu", 144 | hidden_dropout=0.1, 145 | hidden_size=8, 146 | initializer_range=0.02, 147 | intermediate_size=32, 148 | max_position_embeddings=32, 149 | num_heads=2, 150 | num_layers=2, 151 | token_type_vocab_size=2, 152 | vocab_size=len(tokenizer.vocab), 153 | 154 | adapter_size=2, 155 | 156 | embedding_size=4, 157 | extra_tokens_vocab_size=3, 158 | shared_layer=True, 159 | ) 160 | l_bert = bert.BertModelLayer.from_params(bert_params) 161 | 162 | model = keras.models.Sequential([ 163 | l_bert, 164 | ]) 165 | 166 | model.build(input_shape=(None, max_seq_len)) 167 | 168 | model.summary() 169 | l_bert.apply_adapter_freeze() 170 | model.summary() 171 | 172 | orig_weight_values = [] 173 | for weight in l_bert.weights: 174 | orig_weight_values.append(weight.numpy()) 175 | 176 | model.compile(optimizer=keras.optimizers.Adam(), 177 
| loss=keras.losses.mean_squared_error, 178 | run_eagerly=True) 179 | 180 | trainable_count = len(l_bert.trainable_weights) 181 | 182 | orig_pred = model.predict(input_ids) 183 | model.fit(x=input_ids, y=np.zeros_like(orig_pred), 184 | batch_size=2, 185 | epochs=4) 186 | 187 | trained_count = 0 188 | for ndx, weight in enumerate(l_bert.weights): 189 | weight_equal = np.array_equal(weight.numpy(), orig_weight_values[ndx]) 190 | print("trained:[{}]: {}".format(not weight_equal, weight.name)) 191 | if not weight_equal: 192 | trained_count += 1 193 | 194 | print(" trained weights:", trained_count) 195 | print("trainable weights:", trainable_count) 196 | self.assertEqual(trained_count, trainable_count) 197 | 198 | model.summary() 199 | 200 | 201 | 202 | 203 | 204 | 205 | -------------------------------------------------------------------------------- /tests/test_albert_create.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 10.Oct.2019 at 15:41 4 | # 5 | 6 | from __future__ import absolute_import, division, print_function 7 | import unittest 8 | 9 | import os 10 | import tempfile 11 | 12 | import numpy as np 13 | import tensorflow as tf 14 | from tensorflow import keras 15 | 16 | import bert 17 | 18 | from .test_common import AbstractBertTest, MiniBertFactory 19 | 20 | 21 | class AlbertTest(AbstractBertTest): 22 | 23 | def setUp(self) -> None: 24 | tf.compat.v1.reset_default_graph() 25 | tf.compat.v1.enable_eager_execution() 26 | print("Eager Execution:", tf.executing_eagerly()) 27 | 28 | def test_albert(self): 29 | bert_params = bert.BertModelLayer.Params(hidden_size=32, 30 | vocab_size=67, 31 | max_position_embeddings=64, 32 | num_layers=1, 33 | num_heads=1, 34 | intermediate_size=4, 35 | use_token_type=False, 36 | 37 | embedding_size=16, # using ALBERT instead of BERT 38 | project_embeddings_with_bias=True, 39 | shared_layer=True, 40 | extra_tokens_vocab_size=3, 41 | ) 42 | 43 | 44 | def to_model(bert_params): 45 | l_bert = bert.BertModelLayer.from_params(bert_params) 46 | 47 | token_ids = keras.layers.Input(shape=(21,)) 48 | seq_out = l_bert(token_ids) 49 | model = keras.Model(inputs=[token_ids], outputs=seq_out) 50 | 51 | model.build(input_shape=(None, 21)) 52 | l_bert.apply_adapter_freeze() 53 | 54 | return model 55 | 56 | model = to_model(bert_params) 57 | model.summary() 58 | 59 | print("trainable_weights:", len(model.trainable_weights)) 60 | for weight in model.trainable_weights: 61 | print(weight.name, weight.shape) 62 | self.assertEqual(23, len(model.trainable_weights)) 63 | 64 | # adapter-ALBERT :-) 65 | 66 | bert_params.adapter_size = 16 67 | 68 | model = to_model(bert_params) 69 | model.summary() 70 | 71 | print("trainable_weights:", len(model.trainable_weights)) 72 | for weight in model.trainable_weights: 73 | print(weight.name, weight.shape) 74 | self.assertEqual(15, len(model.trainable_weights)) 75 | 76 | print("non_trainable_weights:", len(model.non_trainable_weights)) 77 | for weight in model.non_trainable_weights: 78 | print(weight.name, weight.shape) 79 | self.assertEqual(16, len(model.non_trainable_weights)) 80 | 81 | def test_albert_load_base_google_weights(self): # for coverage mainly 82 | albert_model_name = "albert_base" 83 | albert_dir = bert.fetch_tfhub_albert_model(albert_model_name, ".models") 84 | model_params = bert.albert_params(albert_model_name) 85 | 86 | l_bert = bert.BertModelLayer.from_params(model_params, name="albert") 87 | 88 | model = keras.models.Sequential([ 89 | 
keras.layers.InputLayer(input_shape=(8,), dtype=tf.int32, name="input_ids"), 90 | l_bert, 91 | keras.layers.Lambda(lambda x: x[:, 0, :]), 92 | keras.layers.Dense(2), 93 | ]) 94 | model.build(input_shape=(None, 8)) 95 | model.compile(optimizer=keras.optimizers.Adam(), 96 | loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), 97 | metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")]) 98 | 99 | bert.load_albert_weights(l_bert, albert_dir) 100 | 101 | model.summary() 102 | 103 | def test_albert_params(self): 104 | albert_model_name = "albert_base" 105 | albert_dir = bert.fetch_tfhub_albert_model(albert_model_name, ".models") 106 | dir_params = bert.albert_params(albert_dir) 107 | dir_params.attention_dropout = 0.1 # diff between README and assets/albert_config.json 108 | dir_params.hidden_dropout = 0.1 109 | name_params = bert.albert_params(albert_model_name) 110 | self.assertEqual(name_params, dir_params) 111 | 112 | # coverage 113 | model_params = dir_params 114 | model_params.vocab_size = model_params.vocab_size 115 | model_params.adapter_size = 1 116 | l_bert = bert.BertModelLayer.from_params(model_params, name="albert") 117 | l_bert(tf.zeros((1, 128))) 118 | bert.load_albert_weights(l_bert, albert_dir) 119 | 120 | def test_albert_zh_fetch_and_load(self): 121 | albert_model_name = "albert_tiny" 122 | albert_dir = bert.fetch_brightmart_albert_model(albert_model_name, ".models") 123 | 124 | model_params = bert.params_from_pretrained_ckpt(albert_dir) 125 | model_params.vocab_size = model_params.vocab_size + 2 126 | model_params.adapter_size = 1 127 | l_bert = bert.BertModelLayer.from_params(model_params, name="albert") 128 | l_bert(tf.zeros((1, 128))) 129 | res = bert.load_albert_weights(l_bert, albert_dir) 130 | self.assertTrue(len(res) > 0) 131 | 132 | def test_coverage(self): 133 | try: 134 | bert.fetch_google_bert_model("not-existent_bert_model", ".models") 135 | except: 136 | pass -------------------------------------------------------------------------------- /tests/test_attention.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 30.Jul.2019 at 16:41 4 | # 5 | 6 | from __future__ import absolute_import, division, print_function 7 | 8 | 9 | import unittest 10 | 11 | import random 12 | 13 | import bert 14 | 15 | import numpy as np 16 | import tensorflow as tf 17 | from tensorflow import keras 18 | 19 | 20 | class TestAttention(unittest.TestCase): 21 | 22 | def test_attention(self): 23 | am = bert.AttentionLayer.create_attention_mask(from_shape=[2, 3, 5], # B,S,.. 
24 | input_mask=[[2], [1]] # B,seq_len 25 | ) 26 | print(am) # [batch_size, from_seq_len, seq_len] 27 | 28 | def test_compute_shape(self): 29 | l_att = bert.AttentionLayer(num_heads=2, size_per_head=2) 30 | l_att.compute_output_shape(input_shape=(16, 8, 2)) -------------------------------------------------------------------------------- /tests/test_common.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 25.Jul.2019 at 13:30 4 | # 5 | 6 | from __future__ import absolute_import, division, print_function 7 | 8 | 9 | import os 10 | import string 11 | import unittest 12 | import tempfile 13 | 14 | import tensorflow as tf 15 | import numpy as np 16 | 17 | from tensorflow.python import keras 18 | 19 | import bert 20 | from bert import bert_tokenization 21 | 22 | 23 | class MiniBertFactory: 24 | 25 | @staticmethod 26 | def create_mini_bert_weights(model_dir=None): 27 | model_dir = model_dir if model_dir is not None else tempfile.TemporaryDirectory().name 28 | os.makedirs(model_dir, exist_ok=True) 29 | 30 | from bert.loader import StockBertConfig 31 | 32 | bert_tokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"] 33 | bert_config = StockBertConfig( 34 | attention_probs_dropout_prob = 0.1, 35 | hidden_act = "gelu", 36 | hidden_dropout_prob = 0.1, 37 | hidden_size = 8, 38 | initializer_range = 0.02, 39 | intermediate_size = 32, 40 | max_position_embeddings = 32, 41 | num_attention_heads = 2, 42 | num_hidden_layers = 2, 43 | type_vocab_size = 2, 44 | vocab_size = len(string.ascii_lowercase)*2 + len(bert_tokens), 45 | ) 46 | 47 | print("creating mini BERT at:", model_dir) 48 | 49 | bert_config_file = os.path.join(model_dir, "bert_config.json") 50 | bert_vocab_file = os.path.join(model_dir, "vocab.txt") 51 | 52 | with open(bert_config_file, "w") as f: 53 | f.write(bert_config.to_json_string()) 54 | with open(bert_vocab_file, "w") as f: 55 | f.write("\n".join(list(string.ascii_lowercase) + bert_tokens) + "\n")  # separating newline, so "[MASK]" and "##a" don't get merged onto one vocab line 56 | f.write("\n".join(["##"+tok for tok in list(string.ascii_lowercase)])) 57 | 58 | with tf.Graph().as_default(): 59 | _ = MiniBertFactory.create_stock_bert_graph(bert_config_file, 16) 60 | saver = tf.compat.v1.train.Saver(max_to_keep=1, save_relative_paths=True) 61 | 62 | with tf.compat.v1.Session() as sess: 63 | sess.run(tf.compat.v1.global_variables_initializer()) 64 | ckpt_path = os.path.join(model_dir, "bert_model.ckpt") 65 | save_path = saver.save(sess, ckpt_path, write_meta_graph=True) 66 | print("saving to:", save_path) 67 | 68 | bert_tokenization.validate_case_matches_checkpoint(True, save_path) 69 | 70 | return save_path 71 | 72 | @staticmethod 73 | def create_stock_bert_graph(bert_config_file, max_seq_len): 74 | from tests.ext.modeling import BertModel, BertConfig 75 | 76 | tf_placeholder = tf.compat.v1.placeholder 77 | 78 | pl_input_ids = tf_placeholder(tf.int32, shape=(1, max_seq_len)) 79 | pl_mask = tf_placeholder(tf.int32, shape=(1, max_seq_len)) 80 | pl_token_type_ids = tf_placeholder(tf.int32, shape=(1, max_seq_len)) 81 | 82 | bert_config = BertConfig.from_json_file(bert_config_file) 83 | s_model = BertModel(config=bert_config, 84 | is_training=False, 85 | input_ids=pl_input_ids, 86 | input_mask=pl_mask, 87 | token_type_ids=pl_token_type_ids, 88 | use_one_hot_embeddings=False) 89 | 90 | return s_model, pl_input_ids, pl_mask, pl_token_type_ids 91 | 92 | 93 | class AbstractBertTest(unittest.TestCase): 94 | 95 | @staticmethod 96 | def create_mini_bert_weights(): 97 | model_dir =
tempfile.TemporaryDirectory().name 98 | # model_dir = "/tmp/mini_bert/"; 99 | os.makedirs(model_dir, exist_ok=True) 100 | save_path = MiniBertFactory.create_mini_bert_weights(model_dir) 101 | print("mini_bert save_path", save_path) 102 | print("\n\t".join([""] + os.listdir(model_dir))) 103 | return model_dir 104 | 105 | def prepare_input_batch(self, input_str_batch, tokenizer, max_seq_len, extra_token_count=0): 106 | input_ids_batch = [] 107 | token_type_ids_batch = [] 108 | 109 | def extra_token_gen(): 110 | token = 0 111 | while True: 112 | yield - ((token % extra_token_count) + 1) 113 | token += 1 114 | 115 | extra_token = extra_token_gen() 116 | 117 | for input_str in input_str_batch: 118 | input_tokens = tokenizer.tokenize(input_str) 119 | input_tokens = ["[CLS]"] + input_tokens + ["[SEP]"] 120 | 121 | print("input_tokens len:", len(input_tokens)) 122 | 123 | input_ids = tokenizer.convert_tokens_to_ids(input_tokens) 124 | if extra_token_count > 0: 125 | input_ids = [next(extra_token)] + input_ids + [next(extra_token)] 126 | input_ids = input_ids + [0]*(max_seq_len - len(input_ids)) 127 | token_type_ids = [0]*len(input_ids) 128 | 129 | input_ids_batch.append(input_ids) 130 | token_type_ids_batch.append(token_type_ids) 131 | 132 | input_ids = np.array(input_ids_batch, dtype=np.int32) 133 | token_type_ids = np.array(token_type_ids_batch, dtype=np.int32) 134 | 135 | return input_ids, token_type_ids -------------------------------------------------------------------------------- /tests/test_compare_activations.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 23.May.2019 at 17:10 4 | # 5 | 6 | from __future__ import absolute_import, division, print_function 7 | 8 | import os 9 | import string 10 | import unittest 11 | import tempfile 12 | 13 | import tensorflow as tf 14 | import numpy as np 15 | 16 | from tensorflow import keras 17 | 18 | from bert import bert_tokenization 19 | 20 | from .test_common import AbstractBertTest, MiniBertFactory 21 | 22 | 23 | class CompareBertActivationsTest(AbstractBertTest): 24 | 25 | def setUp(self): 26 | tf.compat.v1.reset_default_graph() 27 | keras.backend.clear_session() 28 | tf.compat.v1.disable_eager_execution() 29 | print("Eager Execution:", tf.executing_eagerly()) 30 | 31 | @staticmethod 32 | def load_stock_model(model_dir, max_seq_len): 33 | from tests.ext.modeling import BertModel, BertConfig, get_assignment_map_from_checkpoint 34 | 35 | tf.compat.v1.reset_default_graph() # to scope naming for checkpoint loading (if executed more than once) 36 | 37 | bert_config_file = os.path.join(model_dir, "bert_config.json") 38 | bert_ckpt_file = os.path.join(model_dir, "bert_model.ckpt") 39 | 40 | pl_input_ids = tf.compat.v1.placeholder(tf.int32, shape=(1, max_seq_len)) 41 | pl_mask = tf.compat.v1.placeholder(tf.int32, shape=(1, max_seq_len)) 42 | pl_token_type_ids = tf.compat.v1.placeholder(tf.int32, shape=(1, max_seq_len)) 43 | 44 | bert_config = BertConfig.from_json_file(bert_config_file) 45 | 46 | s_model = BertModel(config=bert_config, 47 | is_training=False, 48 | input_ids=pl_input_ids, 49 | input_mask=pl_mask, 50 | token_type_ids=pl_token_type_ids, 51 | use_one_hot_embeddings=False) 52 | 53 | tvars = tf.compat.v1.trainable_variables() 54 | (assignment_map, initialized_var_names) = get_assignment_map_from_checkpoint(tvars, bert_ckpt_file) 55 | tf.compat.v1.train.init_from_checkpoint(bert_ckpt_file, assignment_map) 56 | 57 | return s_model, pl_input_ids, pl_token_type_ids, 
pl_mask 58 | 59 | @staticmethod 60 | def predict_on_stock_model(model_dir, input_ids, input_mask, token_type_ids): 61 | max_seq_len = input_ids.shape[-1] 62 | (s_model, 63 | pl_input_ids, pl_token_type_ids, pl_mask) = CompareBertActivationsTest.load_stock_model(model_dir, max_seq_len) 64 | 65 | with tf.compat.v1.Session() as sess: 66 | sess.run(tf.compat.v1.global_variables_initializer()) 67 | 68 | s_res = sess.run( 69 | s_model.get_sequence_output(), 70 | feed_dict={pl_input_ids: input_ids, 71 | pl_token_type_ids: token_type_ids, 72 | pl_mask: input_mask, 73 | }) 74 | return s_res 75 | 76 | @staticmethod 77 | def load_keras_model(model_dir, max_seq_len): 78 | from tensorflow.python import keras 79 | from bert import BertModelLayer 80 | from bert.loader import StockBertConfig, load_stock_weights, params_from_pretrained_ckpt 81 | 82 | bert_config_file = os.path.join(model_dir, "bert_config.json") 83 | bert_ckpt_file = os.path.join(model_dir, "bert_model.ckpt") 84 | 85 | l_bert = BertModelLayer.from_params(params_from_pretrained_ckpt(model_dir)) 86 | 87 | l_input_ids = keras.layers.Input(shape=(max_seq_len,), dtype='int32', name="input_ids") 88 | l_token_type_ids = keras.layers.Input(shape=(max_seq_len,), dtype='int32', name="token_type_ids") 89 | 90 | output = l_bert([l_input_ids, l_token_type_ids]) 91 | 92 | model = keras.Model(inputs=[l_input_ids, l_token_type_ids], outputs=output) 93 | model.build(input_shape=[(None, max_seq_len), 94 | (None, max_seq_len)]) 95 | 96 | load_stock_weights(l_bert, bert_ckpt_file) 97 | return model 98 | 99 | @staticmethod 100 | def predict_on_keras_model(model_dir, input_ids, input_mask, token_type_ids): 101 | max_seq_len = input_ids.shape[-1] 102 | 103 | model = CompareBertActivationsTest.load_keras_model(model_dir, max_seq_len) 104 | 105 | k_res = model.predict([input_ids, token_type_ids]) 106 | return k_res 107 | 108 | def test_compare(self): 109 | 110 | model_dir = tempfile.TemporaryDirectory().name 111 | os.makedirs(model_dir) 112 | save_path = MiniBertFactory.create_mini_bert_weights(model_dir) 113 | tokenizer = bert_tokenization.FullTokenizer(vocab_file=os.path.join(model_dir, "vocab.txt"), do_lower_case=True) 114 | 115 | # prepare input 116 | max_seq_len = 16 117 | input_str = "hello, bert!" 
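# the ids below are padded with 0 up to max_seq_len; the same batch is then fed to both the stock TF1 BertModel and the keras BertModelLayer, and their sequence outputs are expected to match to within atol=1e-6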
118 | input_tokens = tokenizer.tokenize(input_str) 119 | input_tokens = ["[CLS]"] + input_tokens + ["[SEP]"] 120 | input_ids = tokenizer.convert_tokens_to_ids(input_tokens) 121 | input_ids = input_ids + [0]*(max_seq_len - len(input_tokens)) 122 | input_mask = [0]*len(input_tokens) + [0]*(max_seq_len - len(input_tokens)) # FIXME: input_mask broken - change to [1]*len(input_tokens) 123 | token_type_ids = [0]*len(input_tokens) + [0]*(max_seq_len - len(input_tokens)) 124 | 125 | input_ids = np.array([input_ids], dtype=np.int32) 126 | input_mask = np.array([input_mask], dtype=np.int32) 127 | token_type_ids = np.array([token_type_ids], dtype=np.int32) 128 | 129 | print(" tokens:", input_tokens) 130 | print("input_ids:{}/{}:{}".format(len(input_tokens), max_seq_len, input_ids), input_ids.shape, token_type_ids) 131 | 132 | bert_1_seq_out = CompareBertActivationsTest.predict_on_stock_model(model_dir, input_ids, input_mask, token_type_ids) 133 | bert_2_seq_out = CompareBertActivationsTest.predict_on_keras_model(model_dir, input_ids, input_mask, token_type_ids) 134 | 135 | np.set_printoptions(precision=9, threshold=20, linewidth=200, sign="+", floatmode="fixed") 136 | 137 | print("stock bert res", bert_1_seq_out.shape) 138 | print("keras bert res", bert_2_seq_out.shape) 139 | 140 | print("stock bert res:\n {}".format(bert_1_seq_out[0, :2, :10]), bert_1_seq_out.dtype) 141 | print("keras bert_res:\n {}".format(bert_2_seq_out[0, :2, :10]), bert_2_seq_out.dtype) 142 | 143 | abs_diff = np.abs(bert_1_seq_out - bert_2_seq_out).flatten() 144 | print("abs diff:", np.max(abs_diff), np.argmax(abs_diff)) 145 | self.assertTrue(np.allclose(bert_1_seq_out, bert_2_seq_out, atol=1e-6)) 146 | 147 | def test_finetune(self): 148 | 149 | 150 | model_dir = tempfile.TemporaryDirectory().name 151 | os.makedirs(model_dir) 152 | save_path = MiniBertFactory.create_mini_bert_weights(model_dir) 153 | tokenizer = bert_tokenization.FullTokenizer(vocab_file=os.path.join(model_dir, "vocab.txt"), do_lower_case=True) 154 | 155 | # prepare input 156 | max_seq_len = 24 157 | input_str_batch = ["hello, bert!", "how are you doing!"] 158 | 159 | input_ids_batch = [] 160 | token_type_ids_batch = [] 161 | for input_str in input_str_batch: 162 | input_tokens = tokenizer.tokenize(input_str) 163 | input_tokens = ["[CLS]"] + input_tokens + ["[SEP]"] 164 | 165 | print("input_tokens len:", len(input_tokens)) 166 | 167 | input_ids = tokenizer.convert_tokens_to_ids(input_tokens) 168 | input_ids = input_ids + [0]*(max_seq_len - len(input_tokens)) 169 | token_type_ids = [0]*len(input_tokens) + [0]*(max_seq_len - len(input_tokens)) 170 | 171 | input_ids_batch.append(input_ids) 172 | token_type_ids_batch.append(token_type_ids) 173 | 174 | input_ids = np.array(input_ids_batch, dtype=np.int32) 175 | token_type_ids = np.array(token_type_ids_batch, dtype=np.int32) 176 | 177 | print(" tokens:", input_tokens) 178 | print("input_ids:{}/{}:{}".format(len(input_tokens), max_seq_len, input_ids), input_ids.shape, token_type_ids) 179 | 180 | model = CompareBertActivationsTest.load_keras_model(model_dir, max_seq_len) 181 | model.compile(optimizer=keras.optimizers.Adam(), 182 | loss=keras.losses.mean_squared_error) 183 | 184 | pres = model.predict([input_ids, token_type_ids]) # just for fetching the shape of the output 185 | print("pres:", pres.shape) 186 | 187 | model.fit(x=(input_ids, token_type_ids), 188 | y=np.zeros_like(pres), 189 | batch_size=2, 190 | epochs=2) 191 | -------------------------------------------------------------------------------- /tests/test_eager.py:
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 25.Jul.2019 at 13:24 4 | # 5 | 6 | from __future__ import absolute_import, division, print_function 7 | 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | from tensorflow import keras 12 | import params_flow as pf 13 | 14 | from bert import loader, BertModelLayer 15 | 16 | from .test_common import AbstractBertTest 17 | 18 | 19 | class LoaderTest(AbstractBertTest): 20 | 21 | def setUp(self) -> None: 22 | tf.compat.v1.reset_default_graph() 23 | tf.compat.v1.enable_eager_execution() 24 | print("Eager Execution:", tf.executing_eagerly()) 25 | 26 | def test_coverage_improve(self): 27 | for act in ["relu", "gelu", "linear", None]: 28 | BertModelLayer.get_activation(act) 29 | try: 30 | BertModelLayer.get_activation("None") 31 | except ValueError: 32 | pass 33 | 34 | def test_eager_loading(self): 35 | print("Eager Execution:", tf.executing_eagerly()) 36 | 37 | # a temporal mini bert model_dir 38 | model_dir = self.create_mini_bert_weights() 39 | 40 | bert_params = loader.params_from_pretrained_ckpt(model_dir) 41 | bert_params.adapter_size = 32 42 | bert = BertModelLayer.from_params(bert_params, name="bert") 43 | 44 | model = keras.models.Sequential([ 45 | keras.layers.InputLayer(input_shape=(32,)), 46 | bert, 47 | keras.layers.Lambda(lambda x: x[:, 0, :]), 48 | keras.layers.Dense(2) 49 | ]) 50 | 51 | model.build(input_shape=(None, 128)) 52 | model.compile(optimizer=keras.optimizers.Adam(), 53 | loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), 54 | metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")], 55 | run_eagerly=True) 56 | 57 | loader.load_stock_weights(bert, model_dir) 58 | 59 | model.summary() 60 | 61 | def test_concat(self): 62 | model_dir = self.create_mini_bert_weights() 63 | 64 | bert_params = loader.params_from_pretrained_ckpt(model_dir) 65 | bert_params.adapter_size = 32 66 | bert = BertModelLayer.from_params(bert_params, name="bert") 67 | 68 | max_seq_len = 4 69 | 70 | model = keras.models.Sequential([ 71 | keras.layers.InputLayer(input_shape=(max_seq_len,)), 72 | bert, 73 | keras.layers.TimeDistributed(keras.layers.Dense(bert_params.hidden_size)), 74 | keras.layers.TimeDistributed(keras.layers.LayerNormalization()), 75 | keras.layers.TimeDistributed(keras.layers.Activation("tanh")), 76 | 77 | pf.Concat([ 78 | keras.layers.Lambda(lambda x: tf.math.reduce_max(x, axis=1)), # GlobalMaxPooling1D 79 | keras.layers.Lambda(lambda x: tf.math.reduce_mean(x, axis=1)), # GlobalAvgPooling1 80 | ]), 81 | 82 | keras.layers.Dense(units=bert_params.hidden_size), 83 | keras.layers.Activation("tanh"), 84 | 85 | keras.layers.Dense(units=2) 86 | ]) 87 | 88 | model.build(input_shape=(None, max_seq_len)) 89 | model.summary() 90 | 91 | model.compile(optimizer=keras.optimizers.Adam(), 92 | loss=[keras.losses.SparseCategoricalCrossentropy(from_logits=True)], 93 | metrics=[keras.metrics.SparseCategoricalAccuracy()], 94 | run_eagerly = True) 95 | 96 | loader.load_stock_weights(bert, model_dir) 97 | 98 | model.summary() 99 | -------------------------------------------------------------------------------- /tests/test_extend_segments.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 02.Sep.2019 at 11:57 4 | # 5 | 6 | from __future__ import absolute_import, division, print_function 7 | 8 | import unittest 9 | 10 | import os 11 | import re 12 | import tempfile 13 | 14 
| import bert 15 | 16 | import numpy as np 17 | import tensorflow as tf 18 | from tensorflow import keras 19 | 20 | from .test_common import AbstractBertTest, MiniBertFactory 21 | 22 | #tf.enable_eager_execution() 23 | #tf.disable_eager_execution() 24 | 25 | 26 | class TestExtendSegmentVocab(AbstractBertTest): 27 | 28 | def setUp(self) -> None: 29 | tf.compat.v1.reset_default_graph() 30 | tf.compat.v1.enable_eager_execution() 31 | print("Eager Execution:", tf.executing_eagerly()) 32 | 33 | def test_extend_pretrained_segments(self): 34 | 35 | model_dir = tempfile.TemporaryDirectory().name 36 | os.makedirs(model_dir) 37 | save_path = MiniBertFactory.create_mini_bert_weights(model_dir) 38 | tokenizer = bert.bert_tokenization.FullTokenizer(vocab_file=os.path.join(model_dir, "vocab.txt"), do_lower_case=True) 39 | 40 | ckpt_dir = os.path.dirname(save_path) 41 | bert_params = bert.params_from_pretrained_ckpt(ckpt_dir) 42 | 43 | self.assertEqual(bert_params.token_type_vocab_size, 2) 44 | bert_params.token_type_vocab_size = 4 45 | 46 | l_bert = bert.BertModelLayer.from_params(bert_params) 47 | 48 | # we dummy call the layer once in order to instantiate the weights 49 | l_bert([np.array([[1, 1, 0]]), 50 | np.array([[1, 0, 0]])])#, mask=[[True, True, False]]) 51 | 52 | # 53 | # - load the weights from a pre-trained model, 54 | # - expect a mismatch for the token_type embeddings 55 | # - use the segment/token type id=0 embedding for the missing token types 56 | # 57 | mismatched = bert.load_stock_weights(l_bert, save_path) 58 | 59 | self.assertEqual(1, len(mismatched), "token_type embeddings should have mismatched shape") 60 | 61 | for weight, value in mismatched: 62 | if re.match("(.*)embeddings/token_type_embeddings/embeddings:0", weight.name): 63 | seg0_emb = value[:1, :] 64 | new_segment_embeddings = np.repeat(seg0_emb, (weight.shape[0]-value.shape[0]), axis=0) 65 | new_value = np.concatenate([value, new_segment_embeddings], axis=0) 66 | keras.backend.batch_set_value([(weight, new_value)]) 67 | 68 | tte = l_bert.embeddings_layer.token_type_embeddings_layer.weights[0] 69 | 70 | if not tf.executing_eagerly(): 71 | with tf.keras.backend.get_session() as sess: 72 | tte, = sess.run((tte, )) 73 | 74 | self.assertTrue(np.allclose(seg0_emb, tte[0], 1e-6)) 75 | self.assertFalse(np.allclose(seg0_emb, tte[1], 1e-6)) 76 | self.assertTrue(np.allclose(seg0_emb, tte[2], 1e-6)) 77 | self.assertTrue(np.allclose(seg0_emb, tte[3], 1e-6)) 78 | 79 | bert_params.token_type_vocab_size = 4 80 | print("token_type_vocab_size", bert_params.token_type_vocab_size) 81 | print(l_bert.embeddings_layer.trainable_weights[1]) 82 | 83 | 84 | -------------------------------------------------------------------------------- /tests/test_extend_tokens.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # created by kpe on 04.11.2019 at 2:07 PM 4 | # 5 | 6 | from __future__ import division, absolute_import, print_function 7 | 8 | import unittest 9 | 10 | import os 11 | import tempfile 12 | 13 | import numpy as np 14 | import tensorflow as tf 15 | 16 | import bert 17 | 18 | from .test_common import AbstractBertTest, MiniBertFactory 19 | 20 | 21 | class TestExtendSegmentVocab(AbstractBertTest): 22 | 23 | def setUp(self) -> None: 24 | tf.compat.v1.reset_default_graph() 25 | tf.compat.v1.enable_eager_execution() 26 | print("Eager Execution:", tf.executing_eagerly()) 27 | 28 | def test_extend_pretrained_tokens(self): 29 | model_dir = tempfile.TemporaryDirectory().name 30 | 
os.makedirs(model_dir) 31 | save_path = MiniBertFactory.create_mini_bert_weights(model_dir) 32 | tokenizer = bert.bert_tokenization.FullTokenizer(vocab_file=os.path.join(model_dir, "vocab.txt"), do_lower_case=True) 33 | 34 | ckpt_dir = os.path.dirname(save_path) 35 | bert_params = bert.params_from_pretrained_ckpt(ckpt_dir) 36 | 37 | self.assertEqual(bert_params.token_type_vocab_size, 2) 38 | bert_params.extra_tokens_vocab_size = 3 39 | 40 | l_bert = bert.BertModelLayer.from_params(bert_params) 41 | 42 | # we dummy call the layer once in order to instantiate the weights 43 | l_bert([np.array([[1, 1, 0]]), np.array([[1, 0, 0]])], mask=[[True, True, False]]) 44 | 45 | mismatched = bert.load_stock_weights(l_bert, save_path) 46 | self.assertEqual(0, len(mismatched), "extending the vocab with extra_tokens should not cause any shape mismatches") 47 | 48 | l_bert([np.array([[1, -3, 0]]), np.array([[1, 0, 0]])], mask=[[True, True, False]]) 49 | --------------------------------------------------------------------------------