├── requirements.txt
├── keras_albert_model
│   ├── __init__.py
│   ├── albert_test.py
│   └── albert.py
├── LICENSE
├── README.md
├── setup.py
└── .gitignore

/requirements.txt:
--------------------------------------------------------------------------------
numpy
keras-bert
keras-adaptive-softmax
--------------------------------------------------------------------------------
/keras_albert_model/__init__.py:
--------------------------------------------------------------------------------
from .albert import *

__version__ = '0.1.0'
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 TinkerMob

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ALBERT

Unofficial Keras implementation of [ALBERT](https://arxiv.org/pdf/1909.11942.pdf).
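
Compared with BERT, ALBERT reduces the parameter count in two ways that are
mirrored in this implementation: token embeddings are factorized into a small
embedding matrix plus a projection up to the hidden size, and a single
transformer block is shared across all layers. Pre-training also replaces
BERT's next-sentence prediction with a sentence-order prediction (SOP) head.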

## Install

```bash
python setup.py install
```

or

```bash
# Install latest version:
pip install git+https://github.com/TinkerMob/keras_albert_model.git

# Install specific version:
pip install git+https://github.com/TinkerMob/keras_albert_model.git@v0.1.0
```

Tested with the following dependency versions:

* keras==2.3.0
* tensorflow==2.0.0

## Build model

```python
from keras_albert_model import build_albert

model = build_albert(token_num=30000, training=True)
model.summary()
```

## Load checkpoint

You can load the pretrained models provided by [brightmart/albert_zh](https://github.com/brightmart/albert_zh):

```python
from keras_albert_model import load_brightmart_albert_zh_checkpoint

model = load_brightmart_albert_zh_checkpoint('path_to_checkpoint_folder')
model.summary()
```

## Select output layers

Pass `output_layers` to take features from intermediate transformer blocks; a
list of indices yields the concatenation of the selected outputs:

```python
from keras_albert_model import build_albert

model = build_albert(token_num=30000, training=False, output_layers=[-1, -2, -3, -4])
model.summary()
```
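
## Extract features

With `training=False` and no `output_layers`, `build_albert` returns the input
tensors and the last feature layer instead of a finished model, so wrap them
in a `Model` yourself. Below is a minimal sketch with placeholder inputs; the
random token ids are for illustration only:

```python
import numpy as np
from keras_bert.backend import keras
from keras_albert_model import build_albert

inputs, output = build_albert(token_num=30000, training=False)
model = keras.models.Model(inputs, output)

# Token ids and segment ids, both of shape (batch_size, seq_len).
tokens = np.random.randint(1, 30000, size=(2, 512))
segments = np.zeros((2, 512))

features = model.predict([tokens, segments])
print(features.shape)  # (2, 512, 768)
```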
--------------------------------------------------------------------------------
/keras_albert_model/albert_test.py:
--------------------------------------------------------------------------------
import unittest
from .albert import keras
from .albert import get_custom_objects, build_albert


class TestALBERT(unittest.TestCase):

    def test_build_train(self):
        # The training model round-trips through save/load as long as the
        # custom layers are registered via `get_custom_objects`.
        model = build_albert(333)
        model.compile('adam', 'sparse_categorical_crossentropy')
        model.save('train.h5')
        model = keras.models.load_model('train.h5',
                                        custom_objects=get_custom_objects())
        model.summary()

    def test_build_infer(self):
        # In inference mode `build_albert` returns (inputs, output) tensors.
        model = keras.models.Model(*build_albert(345, training=False))
        model.compile('adam', 'sparse_categorical_crossentropy')
        model.save('infer.h5')
        model = keras.models.load_model('infer.h5',
                                        custom_objects=get_custom_objects())
        model.summary()

    def test_build_select_output_layer(self):
        model = build_albert(346, output_layers=-10, training=False)
        model.compile('adam', 'sparse_categorical_crossentropy')
        model.summary()

    def test_build_output_layers(self):
        model = build_albert(346,
                             output_layers=[-1, -2, -3, -4],
                             training=False)
        model.compile('adam', 'sparse_categorical_crossentropy')
        model.summary()
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import os
import re
from setuptools import setup, find_packages

current_path = os.path.abspath(os.path.dirname(__file__))


def read_file(*parts):
    with open(os.path.join(current_path, *parts)) as reader:
        return reader.read()


def get_requirements(*parts):
    with open(os.path.join(current_path, *parts)) as reader:
        return list(map(lambda x: x.strip(), reader.readlines()))


def find_version(*file_paths):
    # The single source of truth for the version is `__version__` in
    # `keras_albert_model/__init__.py`.
    version_file = read_file(*file_paths)
    version_match = re.search(
        r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M)
    if version_match:
        return version_match.group(1)
    raise RuntimeError('Unable to find version string.')


setup(
    name='keras_albert_model',
    version=find_version('keras_albert_model', '__init__.py'),
    packages=find_packages(),
    url='https://github.com/TinkerMob/keras_albert_model',
    license='MIT',
    author='keras_albert_model',
    author_email='TinkerMob@users.noreply.github.com',
    description='ALBERT with Keras',
    long_description=read_file('README.md'),
    long_description_content_type='text/markdown',
    install_requires=get_requirements('requirements.txt'),
    classifiers=(
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ),
)
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------
/keras_albert_model/albert.py:
--------------------------------------------------------------------------------
import os
import json

import tensorflow as tf
import numpy as np

from keras_bert.backend import keras
from keras_bert.activations.gelu_fallback import gelu
from keras_bert import get_custom_objects as get_bert_custom_objects
from keras_bert.layers import Masked, Extract
from keras_pos_embd import PositionEmbedding
from keras_layer_normalization import LayerNormalization
from keras_multi_head import MultiHeadAttention
from keras_position_wise_feed_forward import FeedForward
from keras_adaptive_softmax import AdaptiveEmbedding, AdaptiveSoftmax


__all__ = [
    'get_custom_objects', 'build_albert',
    'load_brightmart_albert_zh_checkpoint',
]
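
# Keras can only deserialize layers it knows about, so a saved ALBERT model
# must be restored with these custom classes registered, e.g.
# `keras.models.load_model(path, custom_objects=get_custom_objects())`.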
def get_custom_objects():
    custom_objects = get_bert_custom_objects()
    custom_objects['AdaptiveEmbedding'] = AdaptiveEmbedding
    custom_objects['AdaptiveSoftmax'] = AdaptiveSoftmax
    return custom_objects


def build_albert(token_num,
                 pos_num=512,
                 seq_len=512,
                 embed_dim=128,
                 hidden_dim=768,
                 transformer_num=12,
                 head_num=12,
                 feed_forward_dim=3072,
                 dropout_rate=0.1,
                 attention_activation=None,
                 feed_forward_activation='gelu',
                 training=True,
                 trainable=None,
                 output_layers=None):
    """Get the ALBERT model.

    See: https://arxiv.org/pdf/1909.11942.pdf

    :param token_num: Number of tokens.
    :param pos_num: Maximum position.
    :param seq_len: Maximum length of the input sequence or None.
    :param embed_dim: Dimension of the token embeddings
                      (before projection to `hidden_dim`).
    :param hidden_dim: Dimension of the hidden layers.
    :param transformer_num: Number of transformer blocks.
    :param head_num: Number of heads in multi-head attention
                     in each transformer.
    :param feed_forward_dim: Dimension of the feed forward layer
                             in each transformer.
    :param dropout_rate: Dropout rate.
    :param attention_activation: Activation for attention layers.
    :param feed_forward_activation: Activation for feed-forward layers.
    :param training: A built model with MLM and SOP outputs will be returned
                     if it is `True`, otherwise the input layers and the last
                     feature extraction layer will be returned.
    :param trainable: Whether the model is trainable. A collection of layer
                      name prefixes may also be given, in which case only the
                      matching layers are trainable.
    :param output_layers: An index or a list of indices of transformer blocks
                          whose outputs should be returned instead of the last
                          feature layer (only used when `training` is
                          `False`); a list yields the concatenation of the
                          selected outputs.
    """
    if attention_activation == 'gelu':
        attention_activation = gelu
    if feed_forward_activation == 'gelu':
        feed_forward_activation = gelu
    if trainable is None:
        trainable = training

    def _trainable(_layer):
        # `trainable` may be a collection of name prefixes, in which case
        # only the layers whose names match one of them are trainable.
        if isinstance(trainable, (list, tuple, set)):
            for prefix in trainable:
                if _layer.name.startswith(prefix):
                    return True
            return False
        return trainable

    # Build inputs
    input_token = keras.layers.Input(shape=(seq_len,), name='Input-Token')
    input_segment = keras.layers.Input(shape=(seq_len,), name='Input-Segment')
    inputs = [input_token, input_segment]

    # Build embeddings. `AdaptiveEmbedding` realizes ALBERT's factorized
    # embedding parameterization: an `embed_dim`-sized lookup table followed
    # by a projection to `hidden_dim`.
    embed_token, embed_weights, embed_projection = AdaptiveEmbedding(
        input_dim=token_num,
        output_dim=hidden_dim,
        embed_dim=embed_dim,
        mask_zero=True,
        trainable=trainable,
        return_embeddings=True,
        return_projections=True,
        name='Embed-Token',
    )(input_token)
    embed_segment = keras.layers.Embedding(
        input_dim=2,
        output_dim=hidden_dim,
        trainable=trainable,
        name='Embed-Segment',
    )(input_segment)
    embed_layer = keras.layers.Add(name='Embed-Token-Segment')(
        [embed_token, embed_segment])
    embed_layer = PositionEmbedding(
        input_dim=pos_num,
        output_dim=hidden_dim,
        mode=PositionEmbedding.MODE_ADD,
        trainable=trainable,
        name='Embedding-Position',
    )(embed_layer)

    if dropout_rate > 0.0:
        dropout_layer = keras.layers.Dropout(
            rate=dropout_rate,
            name='Embedding-Dropout',
        )(embed_layer)
    else:
        dropout_layer = embed_layer
    embed_layer = LayerNormalization(
        trainable=trainable,
        name='Embedding-Norm',
    )(dropout_layer)
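
    # Build the shared transformer. ALBERT shares all transformer parameters
    # across layers: the attention block, the feed-forward block and their
    # layer normalizations are instantiated once and reused in every one of
    # the `transformer_num` iterations below, which is where most of the
    # parameter savings over BERT come from.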
    attention_layer = MultiHeadAttention(
        head_num=head_num,
        activation=attention_activation,
        name='Attention',
    )
    attention_normal = LayerNormalization(name='Attention-Normal')
    feed_forward_layer = FeedForward(
        units=feed_forward_dim,
        activation=feed_forward_activation,
        name='Feed-Forward',
    )
    feed_forward_normal = LayerNormalization(name='Feed-Forward-Normal')

    transformed = embed_layer
    transformed_layers = []
    for i in range(transformer_num):
        attention_input = transformed
        transformed = attention_layer(transformed)
        if dropout_rate > 0.0:
            transformed = keras.layers.Dropout(
                rate=dropout_rate,
                name='Attention-Dropout-{}'.format(i + 1),
            )(transformed)
        transformed = keras.layers.Add(
            name='Attention-Add-{}'.format(i + 1),
        )([attention_input, transformed])
        transformed = attention_normal(transformed)

        feed_forward_input = transformed
        transformed = feed_forward_layer(transformed)
        if dropout_rate > 0.0:
            transformed = keras.layers.Dropout(
                rate=dropout_rate,
                name='Feed-Forward-Dropout-{}'.format(i + 1),
            )(transformed)
        transformed = keras.layers.Add(
            name='Feed-Forward-Add-{}'.format(i + 1),
        )([feed_forward_input, transformed])
        transformed = feed_forward_normal(transformed)
        transformed_layers.append(transformed)
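
    # Pre-training heads: a masked-language-model head whose softmax is tied
    # to the embedding and projection weights (`bind_embeddings` /
    # `bind_projections`), and a sentence-order-prediction (SOP) head on the
    # features of the first ([CLS]) token.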
    if training:
        # Build tasks
        mlm_dense_layer = keras.layers.Dense(
            units=hidden_dim,
            activation=feed_forward_activation,
            name='MLM-Dense',
        )(transformed)
        mlm_norm_layer = LayerNormalization(name='MLM-Norm')(mlm_dense_layer)
        mlm_pred_layer = AdaptiveSoftmax(
            input_dim=hidden_dim,
            output_dim=token_num,
            embed_dim=embed_dim,
            bind_embeddings=True,
            bind_projections=True,
            name='MLM-Sim',
        )([mlm_norm_layer, embed_weights, embed_projection])
        masked_layer = Masked(name='MLM')([mlm_pred_layer, inputs[-1]])
        extract_layer = Extract(index=0, name='Extract')(transformed)
        nsp_dense_layer = keras.layers.Dense(
            units=hidden_dim,
            activation='tanh',
            name='SOP-Dense',
        )(extract_layer)
        nsp_pred_layer = keras.layers.Dense(
            units=2,
            activation='softmax',
            name='SOP',
        )(nsp_dense_layer)
        model = keras.models.Model(
            inputs=inputs,
            outputs=[masked_layer, nsp_pred_layer])
        for layer in model.layers:
            layer.trainable = _trainable(layer)
        return model
    if output_layers is not None:
        if isinstance(output_layers, list):
            output_layers = [
                transformed_layers[index] for index in output_layers]
            output = keras.layers.Concatenate(
                name='Output',
            )(output_layers)
        else:
            output = transformed_layers[output_layers]
        model = keras.models.Model(inputs=inputs, outputs=output)
        return model
    # Building a throwaway model makes it easy to flip the `trainable` flag
    # on every layer before returning the raw input/output tensors.
    model = keras.models.Model(inputs=inputs, outputs=transformed)
    for layer in model.layers:
        layer.trainable = _trainable(layer)
    return inputs, transformed


def load_brightmart_albert_zh_checkpoint(checkpoint_path, **kwargs):
    """Load a checkpoint from https://github.com/brightmart/albert_zh

    :param checkpoint_path: Path to the checkpoint folder.
    :param kwargs: Arguments for the ALBERT model.
    :return: The built model with the pretrained weights loaded.
    """
    config = {}
    for file_name in os.listdir(checkpoint_path):
        if file_name.startswith('albert_config'):
            with open(os.path.join(checkpoint_path, file_name)) as reader:
                config = json.load(reader)
            break

    def _set_if_not_existed(key, value):
        if key not in kwargs:
            kwargs[key] = value

    # Explicit keyword arguments take precedence over the values found in
    # `albert_config*.json`.
    _set_if_not_existed('training', True)
    training = kwargs['training']
    _set_if_not_existed('token_num', config['vocab_size'])
    _set_if_not_existed('pos_num', config['max_position_embeddings'])
    _set_if_not_existed('seq_len', config['max_position_embeddings'])
    _set_if_not_existed('embed_dim', config['embedding_size'])
    _set_if_not_existed('hidden_dim', config['hidden_size'])
    _set_if_not_existed('transformer_num', config['num_hidden_layers'])
    _set_if_not_existed('head_num', config['num_attention_heads'])
    _set_if_not_existed('feed_forward_dim', config['intermediate_size'])
    _set_if_not_existed('dropout_rate', config['hidden_dropout_prob'])
    _set_if_not_existed('feed_forward_activation', config['hidden_act'])

    model = build_albert(**kwargs)
    if not training:
        inputs, outputs = model
        model = keras.models.Model(inputs, outputs)

    def _checkpoint_loader(checkpoint_file):
        def _loader(name):
            return tf.train.load_variable(checkpoint_file, name)
        return _loader

    loader = _checkpoint_loader(
        os.path.join(checkpoint_path, 'albert_model.ckpt'))

    model.get_layer(name='Embed-Token').set_weights([
        loader('bert/embeddings/word_embeddings'),
        loader('bert/embeddings/word_embeddings_2'),
    ])
    model.get_layer(name='Embed-Segment').set_weights([
        loader('bert/embeddings/token_type_embeddings'),
    ])
    model.get_layer(name='Embedding-Position').set_weights([
        loader('bert/embeddings/position_embeddings'),
    ])
    model.get_layer(name='Embedding-Norm').set_weights([
        loader('bert/embeddings/LayerNorm/gamma'),
        loader('bert/embeddings/LayerNorm/beta'),
    ])

    # The checkpoint stores a single shared set of encoder weights under
    # `bert/encoder/layer_shared/...`, matching the shared layers built above.
    model.get_layer(name='Attention').set_weights([
        loader('bert/encoder/layer_shared/attention/self/query/kernel'),
        loader('bert/encoder/layer_shared/attention/self/query/bias'),
        loader('bert/encoder/layer_shared/attention/self/key/kernel'),
        loader('bert/encoder/layer_shared/attention/self/key/bias'),
        loader('bert/encoder/layer_shared/attention/self/value/kernel'),
        loader('bert/encoder/layer_shared/attention/self/value/bias'),
        loader('bert/encoder/layer_shared/attention/output/dense/kernel'),
        loader('bert/encoder/layer_shared/attention/output/dense/bias'),
    ])
    model.get_layer(name='Attention-Normal').set_weights([
        loader('bert/encoder/layer_shared/attention/output/LayerNorm/gamma'),
        loader('bert/encoder/layer_shared/attention/output/LayerNorm/beta'),
    ])
    model.get_layer(name='Feed-Forward').set_weights([
        loader('bert/encoder/layer_shared/intermediate/dense/kernel'),
        loader('bert/encoder/layer_shared/intermediate/dense/bias'),
        loader('bert/encoder/layer_shared/output/dense/kernel'),
        loader('bert/encoder/layer_shared/output/dense/bias'),
    ])
    model.get_layer(name='Feed-Forward-Normal').set_weights([
        loader('bert/encoder/layer_shared/output/LayerNorm/gamma'),
        loader('bert/encoder/layer_shared/output/LayerNorm/beta'),
    ])
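
    # When a training model was built, also load the task heads. The MLM
    # softmax is bound to the embedding and projection matrices loaded above,
    # so only its output bias is read from the checkpoint.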
    if training:
        model.get_layer(name='MLM-Dense').set_weights([
            loader('cls/predictions/transform/dense/kernel'),
            loader('cls/predictions/transform/dense/bias'),
        ])
        model.get_layer(name='MLM-Norm').set_weights([
            loader('cls/predictions/transform/LayerNorm/gamma'),
            loader('cls/predictions/transform/LayerNorm/beta'),
        ])
        model.get_layer(name='MLM-Sim').set_weights([
            loader('cls/predictions/output_bias'),
        ])

        model.get_layer(name='SOP-Dense').set_weights([
            loader('bert/pooler/dense/kernel'),
            loader('bert/pooler/dense/bias'),
        ])
        model.get_layer(name='SOP').set_weights([
            np.transpose(loader('cls/seq_relationship/output_weights')),
            loader('cls/seq_relationship/output_bias'),
        ])

    return model
--------------------------------------------------------------------------------