├── utils ├── __init__.py └── tf_utils.py ├── infer_singlesent.py ├── requirements.txt ├── .gitignore ├── README.md ├── model.py ├── bert.py ├── optimization.py ├── tokenization.py ├── run_ner.py └── bert_modeling.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /infer_singlesent.py: -------------------------------------------------------------------------------- 1 | from bert import Ner 2 | 3 | model = Ner("model_sep20/") 4 | 5 | output = model.predict("Steve went to Paris") 6 | 7 | print(output) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow==2.2.0 2 | # metrics 3 | seqeval==0.0.5 4 | # tokeniztion 5 | nltk==3.4.5 6 | # for rest api 7 | Flask==1.1.1 8 | Flask-Cors==3.0.8 9 | # progress bar 10 | fastprogress==0.1.21 11 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venve/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # vscode 107 | .vscode/ 108 | .DS_Store 109 | # weights 110 | bert-*/ 111 | *.h5 112 | *.json 113 | 114 | *.zip 115 | 116 | # output files 117 | model*/ 118 | model_*/ 119 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BERT NER 2 | 3 | Use google BERT to do CoNLL-2003 NER ! 
4 | 
5 | Train the model using Python and TensorFlow 2.0.
6 | 
7 | 
8 | # Requirements
9 | 
10 | - `python3`
11 | - `pip3 install -r requirements.txt`
12 | 
13 | ### Download a pretrained model from the TensorFlow official models
14 | - [bert-base-cased](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/tf_20/cased_L-12_H-768_A-12.tar.gz): unzip into `bert-base-cased`
15 | 
16 | The code for the pre-trained BERT model comes from [tensorflow-official-models](https://github.com/tensorflow/models/tree/master/official/nlp).
17 | 
18 | # Run
19 | 
20 | ## Single GPU
21 | 
22 | To train and then evaluate on the validation (dev) set:
23 | 
24 | `python run_ner.py --data_dir=data/ --bert_model=bert-base-cased --output_dir=model_sep20 --max_seq_length=128 --do_train --num_train_epochs 3 --do_eval --eval_on dev`
25 | 
26 | To evaluate a trained model on the test set:
27 | 
28 | `python run_ner.py --data_dir=data/ --bert_model=bert-base-cased --output_dir=model_sep20 --max_seq_length=128 --num_train_epochs 3 --do_eval --eval_on test`
29 | 
30 | # Results
31 | 
32 | ## BERT-BASE
33 | 
34 | ### Validation Data
35 | ```
36 |              precision    recall  f1-score   support
37 | 
38 |        MISC     0.8883    0.9143    0.9011       922
39 |         PER     0.9693    0.9783    0.9738      1842
40 |         LOC     0.9713    0.9575    0.9644      1837
41 |         ORG     0.9148    0.9292    0.9219      1341
42 | 
43 |   micro avg     0.9440    0.9509    0.9474      5942
44 |   macro avg     0.9451    0.9509    0.9479      5942
45 | ```
46 | ### Test Data
47 | ```
48 |              precision    recall  f1-score   support
49 | 
50 |         LOC     0.9325    0.9353    0.9339      1668
51 |         PER     0.9546    0.9629    0.9587      1617
52 |         ORG     0.8892    0.9031    0.8961      1661
53 |        MISC     0.7770    0.8291    0.8022       702
54 | 
55 |   micro avg     0.9054    0.9205    0.9129      5648
56 |   macro avg     0.9068    0.9205    0.9135      5648
57 | ```
58 | 
59 | 
60 | 
61 | # Inference
62 | 
63 | Refer to `infer_singlesent.py`:
64 | 
65 | ```python
66 | from bert import Ner
67 | 
68 | model = Ner("model_sep20/")
69 | 
70 | output = model.predict("Steve went to Paris")
71 | 
72 | print(output)
73 | '''
74 | [
75 |     {
76 |         "confidence": 0.99796665,
77 |         "tag": "B-PER",
78 |         "word": "Steve"
79 |     },
80 |     {
81 |         "confidence": 0.99980587,
82 |         "tag": "O",
83 |         "word": "went"
84 |     },
85 |     {
86 |         "confidence": 0.99981683,
87 |         "tag": "O",
88 |         "word": "to"
89 |     },
90 |     {
91 |         "confidence": 0.9993082,
92 |         "tag": "B-LOC",
93 |         "word": "Paris"
94 |     }
95 | ]
96 | '''
97 | ```
98 | 
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | 
3 | import os
4 | 
5 | import tensorflow as tf
6 | 
7 | from bert_modeling import BertConfig, BertModel
8 | from utils import tf_utils
9 | 
10 | 
11 | class BertNer(tf.keras.Model):
12 | 
13 |     def __init__(self, bert_model, float_type, num_labels, max_seq_length, final_layer_initializer=None):
14 |         '''
15 |         bert_model : string or dict
16 |             string: bert pretrained model directory with bert_config.json and bert_model.ckpt
17 |             dict: bert model config; pretrained weights are not restored
18 |         float_type : tf.float32
19 |         num_labels : number of tags in the NER task
20 |         max_seq_length : maximum sequence length in tokens
21 |         final_layer_initializer : default: tf.keras.initializers.TruncatedNormal
22 |         '''
23 |         super(BertNer, self).__init__()
24 | 
25 |         input_word_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
26 |         input_mask = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
27 |         input_type_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
28 |         # 
bert_model is str for do_train and bert_model is dict for do_eval 29 | if type(bert_model) == str: 30 | bert_config = BertConfig.from_json_file(os.path.join(bert_model,"bert_config.json")) 31 | elif type(bert_model) == dict: 32 | bert_config = BertConfig.from_dict(bert_model) 33 | 34 | bert_layer = BertModel(config=bert_config,float_type=float_type) 35 | _, sequence_output = bert_layer(input_word_ids, input_mask,input_type_ids) 36 | 37 | self.bert = tf.keras.Model(inputs=[input_word_ids, input_mask, input_type_ids], outputs=[sequence_output]) 38 | 39 | if type(bert_model) == str: 40 | init_checkpoint = os.path.join(bert_model,"bert_model.ckpt") 41 | checkpoint = tf.train.Checkpoint(model=self.bert) 42 | checkpoint.restore(init_checkpoint).assert_existing_objects_matched() 43 | 44 | self.dropout = tf.keras.layers.Dropout( 45 | rate=bert_config.hidden_dropout_prob) 46 | 47 | if final_layer_initializer is not None: 48 | initializer = final_layer_initializer 49 | else: 50 | initializer = tf.keras.initializers.TruncatedNormal(stddev=bert_config.initializer_range) 51 | 52 | self.classifier = tf.keras.layers.Dense( 53 | num_labels, kernel_initializer=initializer, activation='softmax', name='output', dtype=float_type) 54 | 55 | 56 | def call(self, input_word_ids,input_mask=None,input_type_ids=None,valid_ids=None, **kwargs): 57 | sequence_output = self.bert([input_word_ids, input_mask, input_type_ids],**kwargs) 58 | 59 | valid_output = [] 60 | for i in range(sequence_output.shape[0]): 61 | r = 0 62 | temp = [] 63 | for j in range(sequence_output.shape[1]): 64 | if valid_ids[i][j] == 1: 65 | temp = temp + [sequence_output[i][j]] 66 | else: 67 | r += 1 68 | temp = temp + r * [tf.zeros_like(sequence_output[i][j])] 69 | valid_output = valid_output + temp 70 | valid_output = tf.reshape(tf.stack(valid_output),sequence_output.shape) 71 | sequence_output = self.dropout( 72 | valid_output, training=kwargs.get('training', False)) 73 | logits = self.classifier(sequence_output) 74 | return logits 75 | 76 | -------------------------------------------------------------------------------- /bert.py: -------------------------------------------------------------------------------- 1 | """BERT NER Inference.""" 2 | 3 | from __future__ import absolute_import, division, print_function 4 | 5 | import json 6 | import os 7 | 8 | import tensorflow as tf 9 | from nltk import word_tokenize 10 | 11 | from model import BertNer 12 | from tokenization import FullTokenizer 13 | 14 | 15 | class Ner: 16 | 17 | def __init__(self,model_dir: str): 18 | self.model , self.tokenizer, self.model_config = self.load_model(model_dir) 19 | self.label_map = self.model_config["label_map"] 20 | self.max_seq_length = self.model_config["max_seq_length"] 21 | self.label_map = {int(k):v for k,v in self.label_map.items()} 22 | 23 | def load_model(self, model_dir: str, model_config: str = "model_config.json"): 24 | model_config = os.path.join(model_dir,model_config) 25 | model_config = json.load(open(model_config)) 26 | bert_config = json.load(open(os.path.join(model_dir,"bert_config.json"))) 27 | model = BertNer(bert_config, tf.float32, model_config['num_labels'], model_config['max_seq_length']) 28 | ids = tf.ones((1,128),dtype=tf.int64) 29 | _ = model(ids,ids,ids,ids, training=False) 30 | model.load_weights(os.path.join(model_dir,"model.h5")) 31 | voacb = os.path.join(model_dir, "vocab.txt") 32 | tokenizer = FullTokenizer(vocab_file=voacb, do_lower_case=model_config["do_lower"]) 33 | return model, tokenizer, model_config 34 | 35 | def tokenize(self, 
text: str): 36 | """ tokenize input""" 37 | words = word_tokenize(text) 38 | tokens = [] 39 | valid_positions = [] 40 | for i,word in enumerate(words): 41 | token = self.tokenizer.tokenize(word) 42 | tokens.extend(token) 43 | for i in range(len(token)): 44 | if i == 0: 45 | valid_positions.append(1) 46 | else: 47 | valid_positions.append(0) 48 | return tokens, valid_positions 49 | 50 | def preprocess(self, text: str): 51 | """ preprocess """ 52 | tokens, valid_positions = self.tokenize(text) 53 | ## insert "[CLS]" 54 | tokens.insert(0,"[CLS]") 55 | valid_positions.insert(0,1) 56 | ## insert "[SEP]" 57 | tokens.append("[SEP]") 58 | valid_positions.append(1) 59 | segment_ids = [] 60 | for i in range(len(tokens)): 61 | segment_ids.append(0) 62 | input_ids = self.tokenizer.convert_tokens_to_ids(tokens) 63 | input_mask = [1] * len(input_ids) 64 | while len(input_ids) < self.max_seq_length: 65 | input_ids.append(0) 66 | input_mask.append(0) 67 | segment_ids.append(0) 68 | valid_positions.append(0) 69 | return input_ids,input_mask,segment_ids,valid_positions 70 | 71 | def predict(self, text: str): 72 | input_ids,input_mask,segment_ids,valid_ids = self.preprocess(text) 73 | input_ids = tf.Variable([input_ids],dtype=tf.int64) 74 | input_mask = tf.Variable([input_mask],dtype=tf.int64) 75 | segment_ids = tf.Variable([segment_ids],dtype=tf.int64) 76 | valid_ids = tf.Variable([valid_ids],dtype=tf.int64) 77 | logits = self.model(input_ids, segment_ids, input_mask,valid_ids) 78 | logits_label = tf.argmax(logits,axis=2) 79 | logits_label = logits_label.numpy().tolist()[0] 80 | 81 | logits_confidence = [values[label].numpy() for values,label in zip(logits[0],logits_label)] 82 | 83 | logits = [] 84 | pos = 0 85 | for index,mask in enumerate(valid_ids[0]): 86 | if index == 0: 87 | continue 88 | if mask == 1: 89 | logits.append((logits_label[index-pos],logits_confidence[index-pos])) 90 | else: 91 | pos += 1 92 | logits.pop() 93 | 94 | labels = [(self.label_map[label],confidence) for label,confidence in logits] 95 | words = word_tokenize(text) 96 | assert len(labels) == len(words) 97 | output = [{"word":word,"tag":label,"confidence":confidence} for word,(label,confidence) in zip(words,labels)] 98 | return output 99 | -------------------------------------------------------------------------------- /utils/tf_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Common TF utilities.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import math 20 | 21 | import six 22 | import tensorflow as tf 23 | 24 | 25 | def gelu(x): 26 | """Gaussian Error Linear Unit. 27 | This is a smoother version of the RELU. 
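    Concretely, this implementation computes the tanh approximation
    0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).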
28 | Original paper: https://arxiv.org/abs/1606.08415 29 | Args: 30 | x: float Tensor to perform activation. 31 | Returns: 32 | `x` with the GELU activation applied. 33 | """ 34 | cdf = 0.5 * (1.0 + tf.tanh( 35 | (math.sqrt(2 / math.pi) * (x + 0.044715 * tf.pow(x, 3))))) 36 | return x * cdf 37 | 38 | 39 | def swish(features): 40 | """Computes the Swish activation function. 41 | The tf.nn.swish operation uses a custom gradient to reduce memory usage. 42 | Since saving custom gradients in SavedModel is currently not supported, and 43 | one would not be able to use an exported TF-Hub module for fine-tuning, we 44 | provide this wrapper that can allow to select whether to use the native 45 | TensorFlow swish operation, or whether to use a customized operation that 46 | has uses default TensorFlow gradient computation. 47 | Args: 48 | features: A `Tensor` representing preactivation values. 49 | Returns: 50 | The activation value. 51 | """ 52 | features = tf.convert_to_tensor(features) 53 | return features * tf.nn.sigmoid(features) 54 | 55 | 56 | def pack_inputs(inputs): 57 | """Pack a list of `inputs` tensors to a tuple. 58 | 59 | Args: 60 | inputs: a list of tensors. 61 | 62 | Returns: 63 | a tuple of tensors. if any input is None, replace it with a special constant 64 | tensor. 65 | """ 66 | inputs = tf.nest.flatten(inputs) 67 | outputs = [] 68 | for x in inputs: 69 | if x is None: 70 | outputs.append(tf.constant(0, shape=[], dtype=tf.int32)) 71 | else: 72 | outputs.append(x) 73 | return tuple(outputs) 74 | 75 | 76 | def unpack_inputs(inputs): 77 | """unpack a tuple of `inputs` tensors to a tuple. 78 | 79 | Args: 80 | inputs: a list of tensors. 81 | 82 | Returns: 83 | a tuple of tensors. if any input is a special constant tensor, replace it 84 | with None. 85 | """ 86 | inputs = tf.nest.flatten(inputs) 87 | outputs = [] 88 | for x in inputs: 89 | if is_special_none_tensor(x): 90 | outputs.append(None) 91 | else: 92 | outputs.append(x) 93 | x = tuple(outputs) 94 | 95 | # To trick the very pointless 'unbalanced-tuple-unpacking' pylint check 96 | # from triggering. 97 | if len(x) == 1: 98 | return x[0] 99 | return tuple(outputs) 100 | 101 | 102 | def is_special_none_tensor(tensor): 103 | """Checks if a tensor is a special None Tensor.""" 104 | return tensor.shape.ndims == 0 and tensor.dtype == tf.int32 105 | 106 | 107 | # TODO(hongkuny): consider moving custom string-map lookup to keras api. 108 | def get_activation(identifier): 109 | """Maps a identifier to a Python function, e.g., "relu" => `tf.nn.relu`. 110 | 111 | It checks string first and if it is one of customized activation not in TF, 112 | the corresponding activation will be returned. For non-customized activation 113 | names and callable identifiers, always fallback to tf.keras.activations.get. 114 | 115 | Args: 116 | identifier: String name of the activation function or callable. 117 | 118 | Returns: 119 | A Python function corresponding to the activation function. 120 | """ 121 | if isinstance(identifier, six.string_types): 122 | name_to_fn = { 123 | "gelu": gelu, 124 | "custom_swish": swish, 125 | } 126 | identifier = str(identifier).lower() 127 | if identifier in name_to_fn: 128 | return tf.keras.activations.get(name_to_fn[identifier]) 129 | return tf.keras.activations.get(identifier) 130 | 131 | 132 | def get_shape_list(tensor, expected_rank=None, name=None): 133 | """Returns a list of the shape of tensor, preferring static dimensions. 134 | 135 | Args: 136 | tensor: A tf.Tensor object to find the shape of. 
137 | expected_rank: (optional) int. The expected rank of `tensor`. If this is 138 | specified and the `tensor` has a different rank, and exception will be 139 | thrown. 140 | name: Optional name of the tensor for the error message. 141 | 142 | Returns: 143 | A list of dimensions of the shape of tensor. All static dimensions will 144 | be returned as python integers, and dynamic dimensions will be returned 145 | as tf.Tensor scalars. 146 | """ 147 | if expected_rank is not None: 148 | assert_rank(tensor, expected_rank, name) 149 | 150 | shape = tensor.shape.as_list() 151 | 152 | non_static_indexes = [] 153 | for (index, dim) in enumerate(shape): 154 | if dim is None: 155 | non_static_indexes.append(index) 156 | 157 | if not non_static_indexes: 158 | return shape 159 | 160 | dyn_shape = tf.shape(tensor) 161 | for index in non_static_indexes: 162 | shape[index] = dyn_shape[index] 163 | return shape 164 | 165 | 166 | def assert_rank(tensor, expected_rank, name=None): 167 | """Raises an exception if the tensor rank is not of the expected rank. 168 | 169 | Args: 170 | tensor: A tf.Tensor to check the rank of. 171 | expected_rank: Python integer or list of integers, expected rank. 172 | name: Optional name of the tensor for the error message. 173 | 174 | Raises: 175 | ValueError: If the expected shape doesn't match the actual shape. 176 | """ 177 | expected_rank_dict = {} 178 | if isinstance(expected_rank, six.integer_types): 179 | expected_rank_dict[expected_rank] = True 180 | else: 181 | for x in expected_rank: 182 | expected_rank_dict[x] = True 183 | 184 | actual_rank = tensor.shape.ndims 185 | if actual_rank not in expected_rank_dict: 186 | raise ValueError( 187 | "For the tensor `%s`, the actual tensor rank `%d` (shape = %s) is not " 188 | "equal to the expected tensor rank `%s`" % 189 | (name, actual_rank, str(tensor.shape), str(expected_rank))) 190 | -------------------------------------------------------------------------------- /optimization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | 17 | 18 | 19 | """Functions and classes related to optimization (weight updates). 
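It provides a `WarmUp` learning-rate schedule wrapper, a `create_optimizer` helper that
combines polynomial decay with optional linear warmup, and an `AdamWeightDecay` optimizer
that applies decoupled L2 weight decay and global-norm gradient clipping.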
20 | The file is forked from: 21 | https://github.com/google-research/bert/blob/master/optimization.py 22 | """ 23 | 24 | from __future__ import absolute_import 25 | from __future__ import division 26 | from __future__ import print_function 27 | 28 | import re 29 | 30 | import tensorflow as tf 31 | 32 | 33 | class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule): 34 | """Applys a warmup schedule on a given learning rate decay schedule.""" 35 | 36 | def __init__( 37 | self, 38 | initial_learning_rate, 39 | decay_schedule_fn, 40 | warmup_steps, 41 | power=1.0, 42 | name=None): 43 | super(WarmUp, self).__init__() 44 | self.initial_learning_rate = initial_learning_rate 45 | self.warmup_steps = warmup_steps 46 | self.power = power 47 | self.decay_schedule_fn = decay_schedule_fn 48 | self.name = name 49 | 50 | def __call__(self, step): 51 | with tf.name_scope(self.name or 'WarmUp') as name: 52 | # Implements polynomial warmup. i.e., if global_step < warmup_steps, the 53 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 54 | global_step_float = tf.cast(step, tf.float32) 55 | warmup_steps_float = tf.cast(self.warmup_steps, tf.float32) 56 | warmup_percent_done = global_step_float / warmup_steps_float 57 | warmup_learning_rate = ( 58 | self.initial_learning_rate * 59 | tf.math.pow(warmup_percent_done, self.power)) 60 | return tf.cond(global_step_float < warmup_steps_float, 61 | lambda: warmup_learning_rate, 62 | lambda: self.decay_schedule_fn(step), 63 | name=name) 64 | 65 | def get_config(self): 66 | return { 67 | 'initial_learning_rate': self.initial_learning_rate, 68 | 'decay_schedule_fn': self.decay_schedule_fn, 69 | 'warmup_steps': self.warmup_steps, 70 | 'power': self.power, 71 | 'name': self.name 72 | } 73 | 74 | 75 | def create_optimizer(init_lr, num_train_steps, num_warmup_steps): 76 | """Creates an optimizer with learning rate schedule.""" 77 | # Implements linear decay of the learning rate. 78 | learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay( 79 | initial_learning_rate=init_lr, 80 | decay_steps=num_train_steps, 81 | end_learning_rate=0.0) 82 | if num_warmup_steps: 83 | learning_rate_fn = WarmUp(initial_learning_rate=init_lr, 84 | decay_schedule_fn=learning_rate_fn, 85 | warmup_steps=num_warmup_steps) 86 | optimizer = AdamWeightDecay( 87 | learning_rate=learning_rate_fn, 88 | weight_decay_rate=0.01, 89 | beta_1=0.9, 90 | beta_2=0.999, 91 | epsilon=1e-6, 92 | exclude_from_weight_decay=['layer_norm', 'bias']) 93 | return optimizer 94 | 95 | 96 | class AdamWeightDecay(tf.keras.optimizers.Adam): 97 | """Adam enables L2 weight decay and clip_by_global_norm on gradients. 98 | 99 | Just adding the square of the weights to the loss function is *not* the 100 | correct way of using L2 regularization/weight decay with Adam, since that will 101 | interact with the m and v parameters in strange ways. 102 | 103 | Instead we want ot decay the weights in a manner that doesn't interact with 104 | the m/v parameters. This is equivalent to adding the square of the weights to 105 | the loss with plain (non-momentum) SGD. 
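  A minimal construction sketch, mirroring the arguments passed by `create_optimizer`
  above (with a constant learning rate in place of the decay schedule):

      optimizer = AdamWeightDecay(
          learning_rate=5e-5,
          weight_decay_rate=0.01,
          epsilon=1e-6,
          exclude_from_weight_decay=['layer_norm', 'bias'])

  Variables whose names match a pattern in `exclude_from_weight_decay` receive the plain
  Adam update without the extra decay term.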
106 | """ 107 | 108 | def __init__(self, 109 | learning_rate=0.001, 110 | beta_1=0.9, 111 | beta_2=0.999, 112 | epsilon=1e-7, 113 | amsgrad=False, 114 | weight_decay_rate=0.0, 115 | include_in_weight_decay=None, 116 | exclude_from_weight_decay=None, 117 | name='AdamWeightDecay', 118 | **kwargs): 119 | super(AdamWeightDecay, self).__init__( 120 | learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs) 121 | self.weight_decay_rate = weight_decay_rate 122 | self._include_in_weight_decay = include_in_weight_decay 123 | self._exclude_from_weight_decay = exclude_from_weight_decay 124 | 125 | @classmethod 126 | def from_config(cls, config): 127 | """Creates an optimizer from its config with WarmUp custom object.""" 128 | custom_objects = {'WarmUp': WarmUp} 129 | return super(AdamWeightDecay, cls).from_config( 130 | config, custom_objects=custom_objects) 131 | 132 | def _prepare_local(self, var_device, var_dtype, apply_state): 133 | super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype, 134 | apply_state) 135 | apply_state['weight_decay_rate'] = tf.constant( 136 | self.weight_decay_rate, name='adam_weight_decay_rate') 137 | 138 | def _decay_weights_op(self, var, learning_rate, apply_state): 139 | do_decay = self._do_use_weight_decay(var.name) 140 | if do_decay: 141 | return var.assign_sub( 142 | learning_rate * var * 143 | apply_state['weight_decay_rate'], 144 | use_locking=self._use_locking) 145 | return tf.no_op() 146 | 147 | def apply_gradients(self, grads_and_vars, name=None): 148 | grads, tvars = list(zip(*grads_and_vars)) 149 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 150 | return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars)) 151 | 152 | def _get_lr(self, var_device, var_dtype, apply_state): 153 | """Retrieves the learning rate with the given state.""" 154 | if apply_state is None: 155 | return self._decayed_lr_t[var_dtype], {} 156 | 157 | apply_state = apply_state or {} 158 | coefficients = apply_state.get((var_device, var_dtype)) 159 | if coefficients is None: 160 | coefficients = self._fallback_apply_state(var_device, var_dtype) 161 | apply_state[(var_device, var_dtype)] = coefficients 162 | 163 | return coefficients['lr_t'], dict(apply_state=apply_state) 164 | 165 | def _resource_apply_dense(self, grad, var, apply_state=None): 166 | lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state) 167 | decay = self._decay_weights_op(var, lr_t, apply_state) 168 | with tf.control_dependencies([decay]): 169 | return super(AdamWeightDecay, self)._resource_apply_dense( 170 | grad, var, **kwargs) 171 | 172 | def _resource_apply_sparse(self, grad, var, indices, apply_state=None): 173 | lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state) 174 | decay = self._decay_weights_op(var, lr_t, apply_state) 175 | with tf.control_dependencies([decay]): 176 | return super(AdamWeightDecay, self)._resource_apply_sparse( 177 | grad, var, indices, **kwargs) 178 | 179 | def get_config(self): 180 | config = super(AdamWeightDecay, self).get_config() 181 | config.update({ 182 | 'weight_decay_rate': self.weight_decay_rate, 183 | }) 184 | return config 185 | 186 | def _do_use_weight_decay(self, param_name): 187 | """Whether to use L2 weight decay for `param_name`.""" 188 | if self.weight_decay_rate == 0: 189 | return False 190 | 191 | if self._include_in_weight_decay: 192 | for r in self._include_in_weight_decay: 193 | if re.search(r, param_name) is not None: 194 | return True 195 | 196 | if self._exclude_from_weight_decay: 197 | 
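      # An exclude match (e.g. 'layer_norm' or 'bias', as passed by create_optimizer)
      # disables decay, unless an include pattern already matched above.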
for r in self._exclude_from_weight_decay: 198 | if re.search(r, param_name) is not None: 199 | return False 200 | return True 201 | -------------------------------------------------------------------------------- /tokenization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tokenization classes implementation. 16 | 17 | The file is forked from: 18 | https://github.com/google-research/bert/blob/master/tokenization.py. 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import collections 26 | import re 27 | import unicodedata 28 | 29 | import six 30 | import tensorflow as tf 31 | 32 | 33 | def validate_case_matches_checkpoint(do_lower_case, init_checkpoint): 34 | """Checks whether the casing config is consistent with the checkpoint name.""" 35 | 36 | # The casing has to be passed in by the user and there is no explicit check 37 | # as to whether it matches the checkpoint. The casing information probably 38 | # should have been stored in the bert_config.json file, but it's not, so 39 | # we have to heuristically detect it to validate. 40 | 41 | if not init_checkpoint: 42 | return 43 | 44 | m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint) 45 | if m is None: 46 | return 47 | 48 | model_name = m.group(1) 49 | 50 | lower_models = [ 51 | "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12", 52 | "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12" 53 | ] 54 | 55 | cased_models = [ 56 | "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16", 57 | "multi_cased_L-12_H-768_A-12" 58 | ] 59 | 60 | is_bad_config = False 61 | if model_name in lower_models and not do_lower_case: 62 | is_bad_config = True 63 | actual_flag = "False" 64 | case_name = "lowercased" 65 | opposite_flag = "True" 66 | 67 | if model_name in cased_models and do_lower_case: 68 | is_bad_config = True 69 | actual_flag = "True" 70 | case_name = "cased" 71 | opposite_flag = "False" 72 | 73 | if is_bad_config: 74 | raise ValueError( 75 | "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. " 76 | "However, `%s` seems to be a %s model, so you " 77 | "should pass in `--do_lower_case=%s` so that the fine-tuning matches " 78 | "how the model was pre-training. If this error is wrong, please " 79 | "just comment out this check." 
% 80 | (actual_flag, init_checkpoint, model_name, case_name, opposite_flag)) 81 | 82 | 83 | def convert_to_unicode(text): 84 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 85 | if six.PY3: 86 | if isinstance(text, str): 87 | return text 88 | elif isinstance(text, bytes): 89 | return text.decode("utf-8", "ignore") 90 | else: 91 | raise ValueError("Unsupported string type: %s" % (type(text))) 92 | elif six.PY2: 93 | if isinstance(text, str): 94 | return text.decode("utf-8", "ignore") 95 | elif isinstance(text, unicode): 96 | return text 97 | else: 98 | raise ValueError("Unsupported string type: %s" % (type(text))) 99 | else: 100 | raise ValueError("Not running on Python2 or Python 3?") 101 | 102 | 103 | def printable_text(text): 104 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 105 | 106 | # These functions want `str` for both Python2 and Python3, but in one case 107 | # it's a Unicode string and in the other it's a byte string. 108 | if six.PY3: 109 | if isinstance(text, str): 110 | return text 111 | elif isinstance(text, bytes): 112 | return text.decode("utf-8", "ignore") 113 | else: 114 | raise ValueError("Unsupported string type: %s" % (type(text))) 115 | elif six.PY2: 116 | if isinstance(text, str): 117 | return text 118 | elif isinstance(text, unicode): 119 | return text.encode("utf-8") 120 | else: 121 | raise ValueError("Unsupported string type: %s" % (type(text))) 122 | else: 123 | raise ValueError("Not running on Python2 or Python 3?") 124 | 125 | 126 | def load_vocab(vocab_file): 127 | """Loads a vocabulary file into a dictionary.""" 128 | vocab = collections.OrderedDict() 129 | index = 0 130 | with tf.io.gfile.GFile(vocab_file, "r") as reader: 131 | while True: 132 | token = convert_to_unicode(reader.readline()) 133 | if not token: 134 | break 135 | token = token.strip() 136 | vocab[token] = index 137 | index += 1 138 | return vocab 139 | 140 | 141 | def convert_by_vocab(vocab, items): 142 | """Converts a sequence of [tokens|ids] using the vocab.""" 143 | output = [] 144 | for item in items: 145 | output.append(vocab[item]) 146 | return output 147 | 148 | 149 | def convert_tokens_to_ids(vocab, tokens): 150 | return convert_by_vocab(vocab, tokens) 151 | 152 | 153 | def convert_ids_to_tokens(inv_vocab, ids): 154 | return convert_by_vocab(inv_vocab, ids) 155 | 156 | 157 | def whitespace_tokenize(text): 158 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 159 | text = text.strip() 160 | if not text: 161 | return [] 162 | tokens = text.split() 163 | return tokens 164 | 165 | 166 | class FullTokenizer(object): 167 | """Runs end-to-end tokenziation.""" 168 | 169 | def __init__(self, vocab_file, do_lower_case=True): 170 | self.vocab = load_vocab(vocab_file) 171 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 172 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 173 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 174 | 175 | def tokenize(self, text): 176 | split_tokens = [] 177 | for token in self.basic_tokenizer.tokenize(text): 178 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 179 | split_tokens.append(sub_token) 180 | 181 | return split_tokens 182 | 183 | def convert_tokens_to_ids(self, tokens): 184 | return convert_by_vocab(self.vocab, tokens) 185 | 186 | def convert_ids_to_tokens(self, ids): 187 | return convert_by_vocab(self.inv_vocab, ids) 188 | 189 | 190 | class BasicTokenizer(object): 191 | """Runs basic tokenization (punctuation 
splitting, lower casing, etc.).""" 192 | 193 | def __init__(self, do_lower_case=True): 194 | """Constructs a BasicTokenizer. 195 | 196 | Args: 197 | do_lower_case: Whether to lower case the input. 198 | """ 199 | self.do_lower_case = do_lower_case 200 | 201 | def tokenize(self, text): 202 | """Tokenizes a piece of text.""" 203 | text = convert_to_unicode(text) 204 | text = self._clean_text(text) 205 | 206 | # This was added on November 1st, 2018 for the multilingual and Chinese 207 | # models. This is also applied to the English models now, but it doesn't 208 | # matter since the English models were not trained on any Chinese data 209 | # and generally don't have any Chinese data in them (there are Chinese 210 | # characters in the vocabulary because Wikipedia does have some Chinese 211 | # words in the English Wikipedia.). 212 | text = self._tokenize_chinese_chars(text) 213 | 214 | orig_tokens = whitespace_tokenize(text) 215 | split_tokens = [] 216 | for token in orig_tokens: 217 | if self.do_lower_case: 218 | token = token.lower() 219 | token = self._run_strip_accents(token) 220 | split_tokens.extend(self._run_split_on_punc(token)) 221 | 222 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 223 | return output_tokens 224 | 225 | def _run_strip_accents(self, text): 226 | """Strips accents from a piece of text.""" 227 | text = unicodedata.normalize("NFD", text) 228 | output = [] 229 | for char in text: 230 | cat = unicodedata.category(char) 231 | if cat == "Mn": 232 | continue 233 | output.append(char) 234 | return "".join(output) 235 | 236 | def _run_split_on_punc(self, text): 237 | """Splits punctuation on a piece of text.""" 238 | chars = list(text) 239 | i = 0 240 | start_new_word = True 241 | output = [] 242 | while i < len(chars): 243 | char = chars[i] 244 | if _is_punctuation(char): 245 | output.append([char]) 246 | start_new_word = True 247 | else: 248 | if start_new_word: 249 | output.append([]) 250 | start_new_word = False 251 | output[-1].append(char) 252 | i += 1 253 | 254 | return ["".join(x) for x in output] 255 | 256 | def _tokenize_chinese_chars(self, text): 257 | """Adds whitespace around any CJK character.""" 258 | output = [] 259 | for char in text: 260 | cp = ord(char) 261 | if self._is_chinese_char(cp): 262 | output.append(" ") 263 | output.append(char) 264 | output.append(" ") 265 | else: 266 | output.append(char) 267 | return "".join(output) 268 | 269 | def _is_chinese_char(self, cp): 270 | """Checks whether CP is the codepoint of a CJK character.""" 271 | # This defines a "chinese character" as anything in the CJK Unicode block: 272 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 273 | # 274 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 275 | # despite its name. The modern Korean Hangul alphabet is a different block, 276 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 277 | # space-separated words, so they are not treated specially and handled 278 | # like the all of the other languages. 
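    # The ranges below correspond, in order, to: CJK Unified Ideographs, Extensions A, B,
    # C, D and E, the CJK Compatibility Ideographs, and the CJK Compatibility Ideographs
    # Supplement.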
279 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 280 | (cp >= 0x3400 and cp <= 0x4DBF) or # 281 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 282 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 283 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 284 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 285 | (cp >= 0xF900 and cp <= 0xFAFF) or # 286 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 287 | return True 288 | 289 | return False 290 | 291 | def _clean_text(self, text): 292 | """Performs invalid character removal and whitespace cleanup on text.""" 293 | output = [] 294 | for char in text: 295 | cp = ord(char) 296 | if cp == 0 or cp == 0xfffd or _is_control(char): 297 | continue 298 | if _is_whitespace(char): 299 | output.append(" ") 300 | else: 301 | output.append(char) 302 | return "".join(output) 303 | 304 | 305 | class WordpieceTokenizer(object): 306 | """Runs WordPiece tokenziation.""" 307 | 308 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 309 | self.vocab = vocab 310 | self.unk_token = unk_token 311 | self.max_input_chars_per_word = max_input_chars_per_word 312 | 313 | def tokenize(self, text): 314 | """Tokenizes a piece of text into its word pieces. 315 | 316 | This uses a greedy longest-match-first algorithm to perform tokenization 317 | using the given vocabulary. 318 | 319 | For example: 320 | input = "unaffable" 321 | output = ["un", "##aff", "##able"] 322 | 323 | Args: 324 | text: A single token or whitespace separated tokens. This should have 325 | already been passed through `BasicTokenizer. 326 | 327 | Returns: 328 | A list of wordpiece tokens. 329 | """ 330 | 331 | text = convert_to_unicode(text) 332 | 333 | output_tokens = [] 334 | for token in whitespace_tokenize(text): 335 | chars = list(token) 336 | if len(chars) > self.max_input_chars_per_word: 337 | output_tokens.append(self.unk_token) 338 | continue 339 | 340 | is_bad = False 341 | start = 0 342 | sub_tokens = [] 343 | while start < len(chars): 344 | end = len(chars) 345 | cur_substr = None 346 | while start < end: 347 | substr = "".join(chars[start:end]) 348 | if start > 0: 349 | substr = "##" + substr 350 | if substr in self.vocab: 351 | cur_substr = substr 352 | break 353 | end -= 1 354 | if cur_substr is None: 355 | is_bad = True 356 | break 357 | sub_tokens.append(cur_substr) 358 | start = end 359 | 360 | if is_bad: 361 | output_tokens.append(self.unk_token) 362 | else: 363 | output_tokens.extend(sub_tokens) 364 | return output_tokens 365 | 366 | 367 | def _is_whitespace(char): 368 | """Checks whether `chars` is a whitespace character.""" 369 | # \t, \n, and \r are technically contorl characters but we treat them 370 | # as whitespace since they are generally considered as such. 371 | if char == " " or char == "\t" or char == "\n" or char == "\r": 372 | return True 373 | cat = unicodedata.category(char) 374 | if cat == "Zs": 375 | return True 376 | return False 377 | 378 | 379 | def _is_control(char): 380 | """Checks whether `chars` is a control character.""" 381 | # These are technically control characters but we count them as whitespace 382 | # characters. 383 | if char == "\t" or char == "\n" or char == "\r": 384 | return False 385 | cat = unicodedata.category(char) 386 | if cat in ("Cc", "Cf"): 387 | return True 388 | return False 389 | 390 | 391 | def _is_punctuation(char): 392 | """Checks whether `chars` is a punctuation character.""" 393 | cp = ord(char) 394 | # We treat all non-letter/number ASCII as punctuation. 
395 | # Characters such as "^", "$", and "`" are not in the Unicode 396 | # Punctuation class but we treat them as punctuation anyways, for 397 | # consistency. 398 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 399 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 400 | return True 401 | cat = unicodedata.category(char) 402 | if cat.startswith("P"): 403 | return True 404 | return False 405 | -------------------------------------------------------------------------------- /run_ner.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import argparse 4 | import csv 5 | import json 6 | import logging 7 | import math 8 | import os 9 | import random 10 | import shutil 11 | import sys 12 | 13 | 14 | 15 | import numpy as np 16 | import tensorflow as tf 17 | from fastprogress import master_bar, progress_bar 18 | from seqeval.metrics import classification_report 19 | 20 | from model import BertNer 21 | from optimization import AdamWeightDecay, WarmUp 22 | from tokenization import FullTokenizer 23 | 24 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 25 | datefmt='%m/%d/%Y %H:%M:%S', 26 | level=logging.INFO) 27 | logger = logging.getLogger(__name__) 28 | 29 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) 30 | 31 | 32 | class InputExample(object): 33 | """A single training/test example for simple sequence classification.""" 34 | 35 | def __init__(self, guid, text_a, text_b=None, label=None): 36 | """Constructs a InputExample. 37 | 38 | Args: 39 | guid: Unique id for the example. 40 | text_a: string. The untokenized text of the first sequence. For single 41 | sequence tasks, only this sequence must be specified. 42 | text_b: (Optional) string. The untokenized text of the second sequence. 43 | Only must be specified for sequence pair tasks. 44 | label: (Optional) string. The label of the example. This should be 45 | specified for train and dev examples, but not for test examples. 
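        For this NER task, `text_a` is the space-joined sentence and `label` is the list of
        per-token tags, e.g. text_a="Steve went to Paris" with label=["B-PER", "O", "O", "B-LOC"]
        (see `NerProcessor._create_examples`).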
46 | """ 47 | self.guid = guid 48 | self.text_a = text_a 49 | self.text_b = text_b 50 | self.label = label 51 | 52 | 53 | class InputFeatures(object): 54 | """A single set of features of data.""" 55 | 56 | def __init__(self, input_ids, input_mask, segment_ids, label_id, valid_ids=None, label_mask=None): 57 | self.input_ids = input_ids 58 | self.input_mask = input_mask 59 | self.segment_ids = segment_ids 60 | self.label_id = label_id 61 | self.valid_ids = valid_ids 62 | self.label_mask = label_mask 63 | 64 | 65 | def readfile(filename): 66 | ''' 67 | read file 68 | ''' 69 | f = open(filename) 70 | data = [] 71 | sentence = [] 72 | label = [] 73 | for line in f: 74 | if len(line) == 0 or line.startswith('-DOCSTART') or line[0] == "\n": 75 | if len(sentence) > 0: 76 | data.append((sentence, label)) 77 | sentence = [] 78 | label = [] 79 | continue 80 | splits = line.split(' ') 81 | sentence.append(splits[0]) 82 | label.append(splits[-1][:-1]) 83 | 84 | if len(sentence) > 0: 85 | data.append((sentence, label)) 86 | sentence = [] 87 | label = [] 88 | 89 | return data 90 | 91 | 92 | class DataProcessor(object): 93 | """Base class for data converters for sequence classification data sets.""" 94 | 95 | def get_train_examples(self, data_dir): 96 | """Gets a collection of `InputExample`s for the train set.""" 97 | raise NotImplementedError() 98 | 99 | def get_dev_examples(self, data_dir): 100 | """Gets a collection of `InputExample`s for the dev set.""" 101 | raise NotImplementedError() 102 | 103 | def get_labels(self): 104 | """Gets the list of labels for this data set.""" 105 | raise NotImplementedError() 106 | 107 | @classmethod 108 | def _read_tsv(cls, input_file, quotechar=None): 109 | """Reads a tab separated value file.""" 110 | return readfile(input_file) 111 | 112 | 113 | class NerProcessor(DataProcessor): 114 | """Processor for the CoNLL-2003 data set.""" 115 | 116 | def get_train_examples(self, data_dir): 117 | """See base class.""" 118 | return self._create_examples( 119 | self._read_tsv(os.path.join(data_dir, "train.txt")), "train") 120 | 121 | def get_dev_examples(self, data_dir): 122 | """See base class.""" 123 | return self._create_examples( 124 | self._read_tsv(os.path.join(data_dir, "valid.txt")), "dev") 125 | 126 | def get_test_examples(self, data_dir): 127 | """See base class.""" 128 | return self._create_examples( 129 | self._read_tsv(os.path.join(data_dir, "test.txt")), "test") 130 | 131 | def get_labels(self): 132 | return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "[CLS]", "[SEP]"] 133 | 134 | def _create_examples(self, lines, set_type): 135 | examples = [] 136 | for i, (sentence, label) in enumerate(lines): 137 | guid = "%s-%s" % (set_type, i) 138 | text_a = ' '.join(sentence) 139 | text_b = None 140 | label = label 141 | examples.append(InputExample( 142 | guid=guid, text_a=text_a, text_b=text_b, label=label)) 143 | return examples 144 | 145 | 146 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer): 147 | """Loads a data file into a list of `InputBatch`s.""" 148 | 149 | label_map = {label: i for i, label in enumerate(label_list, 1)} 150 | 151 | features = [] 152 | for (ex_index, example) in enumerate(examples): 153 | textlist = example.text_a.split(' ') 154 | labellist = example.label 155 | tokens = [] 156 | labels = [] 157 | valid_ids = [] 158 | label_mask = [] 159 | for i, word in enumerate(textlist): 160 | token = tokenizer.tokenize(word) 161 | tokens.extend(token) 162 | label_1 = labellist[i] 163 | for 
m in range(len(token)): 164 | if m == 0: 165 | labels.append(label_1) 166 | valid_ids.append(1) 167 | label_mask.append(True) 168 | else: 169 | valid_ids.append(0) 170 | if len(tokens) >= max_seq_length - 1: 171 | tokens = tokens[0:(max_seq_length - 2)] 172 | labels = labels[0:(max_seq_length - 2)] 173 | valid_ids = valid_ids[0:(max_seq_length - 2)] 174 | label_mask = label_mask[0:(max_seq_length - 2)] 175 | ntokens = [] 176 | segment_ids = [] 177 | label_ids = [] 178 | ntokens.append("[CLS]") 179 | segment_ids.append(0) 180 | valid_ids.insert(0, 1) 181 | label_mask.insert(0, True) 182 | label_ids.append(label_map["[CLS]"]) 183 | for i, token in enumerate(tokens): 184 | ntokens.append(token) 185 | segment_ids.append(0) 186 | if len(labels) > i: 187 | label_ids.append(label_map[labels[i]]) 188 | ntokens.append("[SEP]") 189 | segment_ids.append(0) 190 | valid_ids.append(1) 191 | label_mask.append(True) 192 | label_ids.append(label_map["[SEP]"]) 193 | input_ids = tokenizer.convert_tokens_to_ids(ntokens) 194 | input_mask = [1] * len(input_ids) 195 | label_mask = [True] * len(label_ids) 196 | while len(input_ids) < max_seq_length: 197 | input_ids.append(0) 198 | input_mask.append(0) 199 | segment_ids.append(0) 200 | label_ids.append(0) 201 | valid_ids.append(1) 202 | label_mask.append(False) 203 | while len(label_ids) < max_seq_length: 204 | label_ids.append(0) 205 | label_mask.append(False) 206 | 207 | 208 | assert len(input_ids) == max_seq_length 209 | assert len(input_mask) == max_seq_length 210 | assert len(segment_ids) == max_seq_length 211 | assert len(label_ids) == max_seq_length 212 | assert len(valid_ids) == max_seq_length 213 | assert len(label_mask) == max_seq_length 214 | 215 | if ex_index < 5: 216 | logger.info("*** Example ***") 217 | logger.info("guid: %s" % (example.guid)) 218 | logger.info("tokens: %s" % " ".join( 219 | [str(x) for x in tokens])) 220 | logger.info("input_ids: %s" % 221 | " ".join([str(x) for x in input_ids])) 222 | logger.info("input_mask: %s" % 223 | " ".join([str(x) for x in input_mask])) 224 | logger.info( 225 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 226 | 227 | features.append( 228 | InputFeatures(input_ids=input_ids, 229 | input_mask=input_mask, 230 | segment_ids=segment_ids, 231 | label_id=label_ids, 232 | valid_ids=valid_ids, 233 | label_mask=label_mask)) 234 | return features 235 | 236 | 237 | def main(): 238 | parser = argparse.ArgumentParser() 239 | 240 | # Required parameters 241 | parser.add_argument("--data_dir", 242 | default=None, 243 | type=str, 244 | required=True, 245 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.") 246 | parser.add_argument("--bert_model", default=None, type=str, required=True, 247 | help="Bert pre-trained model selected in the list: bert-base-cased,bert-large-cased") 248 | parser.add_argument("--output_dir", 249 | default=None, 250 | type=str, 251 | required=True, 252 | help="The output directory where the model predictions and checkpoints will be written.") 253 | 254 | # Other parameters 255 | parser.add_argument("--max_seq_length", 256 | default=128, 257 | type=int, 258 | help="The maximum total input sequence length after WordPiece tokenization. 
\n" 259 | "Sequences longer than this will be truncated, and sequences shorter \n" 260 | "than this will be padded.") 261 | parser.add_argument("--do_train", 262 | action='store_true', 263 | help="Whether to run training.") 264 | parser.add_argument("--do_eval", 265 | action='store_true', 266 | help="Whether to run eval on the dev/test set.") 267 | parser.add_argument("--eval_on", 268 | default="dev", 269 | type=str, 270 | help="Evaluation set, dev: Development, test: Test") 271 | parser.add_argument("--do_lower_case", 272 | action='store_true', 273 | help="Set this flag if you are using an uncased model.") 274 | parser.add_argument("--train_batch_size", 275 | default=32, 276 | type=int, 277 | help="Total batch size for training.") 278 | parser.add_argument("--eval_batch_size", 279 | default=64, 280 | type=int, 281 | help="Total batch size for eval.") 282 | parser.add_argument("--learning_rate", 283 | default=5e-5, 284 | type=float, 285 | help="The initial learning rate for Adam.") 286 | parser.add_argument("--num_train_epochs", 287 | default=3, 288 | type=int, 289 | help="Total number of training epochs to perform.") 290 | parser.add_argument("--warmup_proportion", 291 | default=0.1, 292 | type=float, 293 | help="Proportion of training to perform linear learning rate warmup for. " 294 | "E.g., 0.1 = 10%% of training.") 295 | parser.add_argument("--weight_decay", default=0.01, type=float, 296 | help="Weight deay if we apply some.") 297 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 298 | help="Epsilon for Adam optimizer.") 299 | parser.add_argument('--seed', 300 | type=int, 301 | default=42, 302 | help="random seed for initialization") 303 | 304 | 305 | args = parser.parse_args() 306 | 307 | processor = NerProcessor() 308 | label_list = processor.get_labels() 309 | 310 | num_labels = len(label_list) + 1 311 | 312 | 313 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train: 314 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 315 | if not os.path.exists(args.output_dir): 316 | os.makedirs(args.output_dir) 317 | 318 | 319 | if args.do_train: 320 | tokenizer = FullTokenizer(os.path.join(args.bert_model, "vocab.txt"), args.do_lower_case) 321 | 322 | train_examples = None 323 | optimizer = None 324 | num_train_optimization_steps = 0 325 | ner = None 326 | if args.do_train: 327 | train_examples = processor.get_train_examples(args.data_dir) 328 | num_train_optimization_steps = int( 329 | len(train_examples) / args.train_batch_size) * args.num_train_epochs 330 | warmup_steps = int(args.warmup_proportion * 331 | num_train_optimization_steps) 332 | learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(initial_learning_rate=args.learning_rate, 333 | decay_steps=num_train_optimization_steps,end_learning_rate=0.0) 334 | if warmup_steps: 335 | learning_rate_fn = WarmUp(initial_learning_rate=args.learning_rate, 336 | decay_schedule_fn=learning_rate_fn, 337 | warmup_steps=warmup_steps) 338 | optimizer = AdamWeightDecay( 339 | learning_rate=learning_rate_fn, 340 | weight_decay_rate=args.weight_decay, 341 | beta_1=0.9, 342 | beta_2=0.999, 343 | epsilon=args.adam_epsilon, 344 | exclude_from_weight_decay=['layer_norm', 'bias']) 345 | 346 | ner = BertNer(args.bert_model, tf.float32, num_labels, args.max_seq_length) 347 | # loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE) 348 | loss_fct = 
tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE) 349 | loss11_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE) 350 | l_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) 351 | 352 | 353 | label_map = {i: label for i, label in enumerate(label_list, 1)} 354 | if args.do_train: 355 | train_features = convert_examples_to_features( 356 | train_examples, label_list, args.max_seq_length, tokenizer) 357 | 358 | logger.info("***** Running training *****") 359 | logger.info(" Num examples = %d", len(train_examples)) 360 | logger.info(" Batch size = %d", args.train_batch_size) 361 | logger.info(" Num steps = %d", num_train_optimization_steps) 362 | 363 | all_input_ids = tf.data.Dataset.from_tensor_slices( 364 | np.asarray([f.input_ids for f in train_features])) 365 | all_input_mask = tf.data.Dataset.from_tensor_slices( 366 | np.asarray([f.input_mask for f in train_features])) 367 | all_segment_ids = tf.data.Dataset.from_tensor_slices( 368 | np.asarray([f.segment_ids for f in train_features])) 369 | all_valid_ids = tf.data.Dataset.from_tensor_slices( 370 | np.asarray([f.valid_ids for f in train_features])) 371 | all_label_mask = tf.data.Dataset.from_tensor_slices( 372 | np.asarray([f.label_mask for f in train_features])) 373 | all_label_ids = tf.data.Dataset.from_tensor_slices( 374 | np.asarray([f.label_id for f in train_features])) 375 | 376 | 377 | # Dataset using tf.data 378 | train_data = tf.data.Dataset.zip( 379 | (all_input_ids, all_input_mask, all_segment_ids, all_valid_ids, all_label_ids,all_label_mask)) 380 | shuffled_train_data = train_data.shuffle(buffer_size=int(len(train_features) * 0.1), 381 | seed = args.seed, 382 | reshuffle_each_iteration=True) 383 | batched_train_data = shuffled_train_data.batch(args.train_batch_size) 384 | 385 | loss_metric = tf.keras.metrics.Mean() 386 | 387 | epoch_bar = master_bar(range(args.num_train_epochs)) 388 | pb_max_len = math.ceil( 389 | float(len(train_features))/float(args.train_batch_size)) 390 | 391 | def train_step(input_ids, input_mask, segment_ids, valid_ids, label_ids,label_mask): 392 | 393 | with tf.GradientTape() as tape: 394 | logits = ner(input_ids, input_mask,segment_ids, valid_ids, training=True) #batchsize, max_seq_length, num_labels 395 | label_ids_masked = tf.boolean_mask(label_ids,label_mask) 396 | logits_masked = tf.boolean_mask(logits,label_mask) 397 | loss = loss_fct(label_ids_masked, logits_masked) 398 | 399 | grads = tape.gradient(loss, ner.trainable_variables) 400 | optimizer.apply_gradients(list(zip(grads, ner.trainable_variables))) 401 | return loss 402 | 403 | for epoch in epoch_bar: 404 | for (input_ids, input_mask, segment_ids, valid_ids, label_ids,label_mask) in progress_bar(batched_train_data, total=pb_max_len, parent=epoch_bar): 405 | loss = train_step(input_ids, input_mask, segment_ids, valid_ids, label_ids,label_mask) 406 | loss_metric(loss) 407 | epoch_bar.child.comment = f'loss : {loss_metric.result()}' 408 | loss_metric.reset_states() 409 | 410 | # model weight save 411 | ner.save_weights(os.path.join(args.output_dir,"model.h5")) 412 | # copy vocab to output_dir 413 | shutil.copyfile(os.path.join(args.bert_model,"vocab.txt"),os.path.join(args.output_dir,"vocab.txt")) 414 | # copy bert config to output_dir 415 | shutil.copyfile(os.path.join(args.bert_model,"bert_config.json"),os.path.join(args.output_dir,"bert_config.json")) 416 | # save label_map and max_seq_length of trained model 417 | model_config 
= {"bert_model":args.bert_model,"do_lower":args.do_lower_case, 418 | "max_seq_length":args.max_seq_length,"num_labels":num_labels, 419 | "label_map":label_map} 420 | json.dump(model_config,open(os.path.join(args.output_dir,"model_config.json"),"w"),indent=4) 421 | 422 | 423 | if args.do_eval: 424 | # load tokenizer 425 | tokenizer = FullTokenizer(os.path.join(args.output_dir, "vocab.txt"), args.do_lower_case) 426 | # model build hack : fix 427 | config = json.load(open(os.path.join(args.output_dir,"bert_config.json"))) 428 | ner = BertNer(config, tf.float32, num_labels, args.max_seq_length) 429 | ids = tf.ones((1,128),dtype=tf.int64) 430 | _ = ner(ids,ids,ids,ids, training=False) 431 | ner.load_weights(os.path.join(args.output_dir,"model.h5")) 432 | 433 | # load test or development set based on argsK 434 | if args.eval_on == "dev": 435 | eval_examples = processor.get_dev_examples(args.data_dir) 436 | elif args.eval_on == "test": 437 | eval_examples = processor.get_test_examples(args.data_dir) 438 | 439 | eval_features = convert_examples_to_features( 440 | eval_examples, label_list, args.max_seq_length, tokenizer) 441 | logger.info("***** Running evaluation *****") 442 | logger.info(" Num examples = %d", len(eval_examples)) 443 | logger.info(" Batch size = %d", args.eval_batch_size) 444 | 445 | all_input_ids = tf.data.Dataset.from_tensor_slices( 446 | np.asarray([f.input_ids for f in eval_features])) 447 | all_input_mask = tf.data.Dataset.from_tensor_slices( 448 | np.asarray([f.input_mask for f in eval_features])) 449 | all_segment_ids = tf.data.Dataset.from_tensor_slices( 450 | np.asarray([f.segment_ids for f in eval_features])) 451 | all_valid_ids = tf.data.Dataset.from_tensor_slices( 452 | np.asarray([f.valid_ids for f in eval_features])) 453 | 454 | all_label_ids = tf.data.Dataset.from_tensor_slices( 455 | np.asarray([f.label_id for f in eval_features])) 456 | 457 | eval_data = tf.data.Dataset.zip( 458 | (all_input_ids, all_input_mask, all_segment_ids, all_valid_ids, all_label_ids)) 459 | batched_eval_data = eval_data.batch(args.eval_batch_size) 460 | 461 | loss_metric = tf.keras.metrics.Mean() 462 | epoch_bar = master_bar(range(1)) 463 | pb_max_len = math.ceil( 464 | float(len(eval_features))/float(args.eval_batch_size)) 465 | 466 | y_true = [] 467 | y_pred = [] 468 | label_map = {i : label for i, label in enumerate(label_list,1)} 469 | for epoch in epoch_bar: 470 | for (input_ids, input_mask, segment_ids, valid_ids, label_ids) in progress_bar(batched_eval_data, total=pb_max_len, parent=epoch_bar): 471 | logits = ner(input_ids, input_mask, 472 | segment_ids, valid_ids, training=False) 473 | logits = tf.argmax(logits,axis=2) 474 | for i, label in enumerate(label_ids): 475 | temp_1 = [] 476 | temp_2 = [] 477 | for j,m in enumerate(label): 478 | if j == 0: 479 | continue 480 | elif label_ids[i][j].numpy() == len(label_map): 481 | y_true.append(temp_1) 482 | y_pred.append(temp_2) 483 | break 484 | else: 485 | temp_1.append(label_map[label_ids[i][j].numpy()]) 486 | temp_2.append(label_map[logits[i][j].numpy()]) 487 | report = classification_report(y_true, y_pred,digits=4) 488 | 489 | if args.eval_on == "test": 490 | output_eval_file = os.path.join(args.output_dir, "eval_results_on_test.txt") 491 | with open(output_eval_file, "w") as writer: 492 | logger.info("***** Eval results *****") 493 | logger.info("\n%s", report) 494 | writer.write(report) 495 | else: 496 | output_eval_file = os.path.join(args.output_dir, "eval_results_on_valid.txt") 497 | with open(output_eval_file, "w") as writer: 
498 | logger.info("***** Eval results *****") 499 | logger.info("\n%s", report) 500 | writer.write(report) 501 | 502 | 503 | 504 | if __name__ == "__main__": 505 | main() 506 | -------------------------------------------------------------------------------- /bert_modeling.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """The main BERT model and related functions. 16 | This file is forked from : 17 | https://github.com/google-research/bert/blob/master/modeling.py 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import copy 25 | import json 26 | import math 27 | import six 28 | import tensorflow as tf 29 | 30 | from utils import tf_utils 31 | 32 | 33 | class BertConfig(object): 34 | """Configuration for `BertModel`.""" 35 | 36 | def __init__(self, 37 | vocab_size, 38 | hidden_size=768, 39 | num_hidden_layers=12, 40 | num_attention_heads=12, 41 | intermediate_size=3072, 42 | hidden_act="gelu", 43 | hidden_dropout_prob=0.1, 44 | attention_probs_dropout_prob=0.1, 45 | max_position_embeddings=512, 46 | type_vocab_size=16, 47 | initializer_range=0.02, 48 | backward_compatible=True): 49 | """Constructs BertConfig. 50 | 51 | Args: 52 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. 53 | hidden_size: Size of the encoder layers and the pooler layer. 54 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 55 | num_attention_heads: Number of attention heads for each attention layer in 56 | the Transformer encoder. 57 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 58 | layer in the Transformer encoder. 59 | hidden_act: The non-linear activation function (function or string) in the 60 | encoder and pooler. 61 | hidden_dropout_prob: The dropout probability for all fully connected 62 | layers in the embeddings, encoder, and pooler. 63 | attention_probs_dropout_prob: The dropout ratio for the attention 64 | probabilities. 65 | max_position_embeddings: The maximum sequence length that this model might 66 | ever be used with. Typically set this to something large just in case 67 | (e.g., 512 or 1024 or 2048). 68 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 69 | `BertModel`. 70 | initializer_range: The stdev of the truncated_normal_initializer for 71 | initializing all weight matrices. 72 | backward_compatible: Boolean, whether the variables shape are compatible 73 | with checkpoints converted from TF 1.x BERT. 
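    Example (a minimal sketch; the argument values below are illustrative only,
    not defaults of this repository):

    ```python
    config = BertConfig(vocab_size=30522, hidden_size=768, num_hidden_layers=12)
    same_config = BertConfig.from_dict(config.to_dict())
    ```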
74 | """ 75 | self.vocab_size = vocab_size 76 | self.hidden_size = hidden_size 77 | self.num_hidden_layers = num_hidden_layers 78 | self.num_attention_heads = num_attention_heads 79 | self.hidden_act = hidden_act 80 | self.intermediate_size = intermediate_size 81 | self.hidden_dropout_prob = hidden_dropout_prob 82 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 83 | self.max_position_embeddings = max_position_embeddings 84 | self.type_vocab_size = type_vocab_size 85 | self.initializer_range = initializer_range 86 | self.backward_compatible = backward_compatible 87 | 88 | @classmethod 89 | def from_dict(cls, json_object): 90 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 91 | config = BertConfig(vocab_size=None) 92 | for (key, value) in six.iteritems(json_object): 93 | config.__dict__[key] = value 94 | return config 95 | 96 | @classmethod 97 | def from_json_file(cls, json_file): 98 | """Constructs a `BertConfig` from a json file of parameters.""" 99 | with tf.io.gfile.GFile(json_file, "r") as reader: 100 | text = reader.read() 101 | return cls.from_dict(json.loads(text)) 102 | 103 | def to_dict(self): 104 | """Serializes this instance to a Python dictionary.""" 105 | output = copy.deepcopy(self.__dict__) 106 | return output 107 | 108 | def to_json_string(self): 109 | """Serializes this instance to a JSON string.""" 110 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 111 | 112 | 113 | def get_bert_model(input_word_ids, 114 | input_mask, 115 | input_type_ids, 116 | config=None, 117 | name=None, 118 | float_type=tf.float32): 119 | """Wraps the core BERT model as a keras.Model.""" 120 | bert_model_layer = BertModel(config=config, float_type=float_type, name=name) 121 | pooled_output, sequence_output = bert_model_layer(input_word_ids, input_mask, 122 | input_type_ids) 123 | bert_model = tf.keras.Model( 124 | inputs=[input_word_ids, input_mask, input_type_ids], 125 | outputs=[pooled_output, sequence_output]) 126 | return bert_model 127 | 128 | 129 | class BertModel(tf.keras.layers.Layer): 130 | """BERT model ("Bidirectional Encoder Representations from Transformers"). 131 | 132 | Example usage: 133 | 134 | ```python 135 | # Already been converted into WordPiece token ids 136 | input_word_ids = tf.constant([[31, 51, 99], [15, 5, 0]]) 137 | input_mask = tf.constant([[1, 1, 1], [1, 1, 0]]) 138 | input_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]]) 139 | 140 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512, 141 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 142 | 143 | pooled_output, sequence_output = modeling.BertModel(config=config)( 144 | input_word_ids=input_word_ids, 145 | input_mask=input_mask, 146 | input_type_ids=input_type_ids) 147 | ... 
148 | ``` 149 | """ 150 | 151 | def __init__(self, config, float_type=tf.float32, **kwargs): 152 | super(BertModel, self).__init__(**kwargs) 153 | self.config = ( 154 | BertConfig.from_dict(config) 155 | if isinstance(config, dict) else copy.deepcopy(config)) 156 | self.float_type = float_type 157 | 158 | def build(self, unused_input_shapes): 159 | """Implements build() for the layer.""" 160 | self.embedding_lookup = EmbeddingLookup( 161 | vocab_size=self.config.vocab_size, 162 | embedding_size=self.config.hidden_size, 163 | initializer_range=self.config.initializer_range, 164 | dtype=tf.float32, 165 | name="word_embeddings") 166 | self.embedding_postprocessor = EmbeddingPostprocessor( 167 | use_type_embeddings=True, 168 | token_type_vocab_size=self.config.type_vocab_size, 169 | use_position_embeddings=True, 170 | max_position_embeddings=self.config.max_position_embeddings, 171 | dropout_prob=self.config.hidden_dropout_prob, 172 | initializer_range=self.config.initializer_range, 173 | dtype=tf.float32, 174 | name="embedding_postprocessor") 175 | self.encoder = Transformer( 176 | num_hidden_layers=self.config.num_hidden_layers, 177 | hidden_size=self.config.hidden_size, 178 | num_attention_heads=self.config.num_attention_heads, 179 | intermediate_size=self.config.intermediate_size, 180 | intermediate_activation=self.config.hidden_act, 181 | hidden_dropout_prob=self.config.hidden_dropout_prob, 182 | attention_probs_dropout_prob=self.config.attention_probs_dropout_prob, 183 | initializer_range=self.config.initializer_range, 184 | backward_compatible=self.config.backward_compatible, 185 | float_type=self.float_type, 186 | name="encoder") 187 | self.pooler_transform = tf.keras.layers.Dense( 188 | units=self.config.hidden_size, 189 | activation="tanh", 190 | kernel_initializer=get_initializer(self.config.initializer_range), 191 | name="pooler_transform") 192 | super(BertModel, self).build(unused_input_shapes) 193 | 194 | def __call__(self, 195 | input_word_ids, 196 | input_mask=None, 197 | input_type_ids=None, 198 | **kwargs): 199 | inputs = tf_utils.pack_inputs([input_word_ids, input_mask, input_type_ids]) 200 | return super(BertModel, self).__call__(inputs, **kwargs) 201 | 202 | def call(self, inputs, mode="bert", **kwargs): 203 | """Implements call() for the layer. 204 | 205 | Args: 206 | inputs: packed input tensors. 207 | mode: string, `bert` or `encoder`. 208 | Returns: 209 | Output tensor of the last layer for BERT training (mode=`bert`) which 210 | is a float Tensor of shape [batch_size, seq_length, hidden_size] or 211 | a list of output tensors for encoder usage (mode=`encoder`). 
212 | """ 213 | unpacked_inputs = tf_utils.unpack_inputs(inputs) 214 | input_word_ids = unpacked_inputs[0] 215 | input_mask = unpacked_inputs[1] 216 | input_type_ids = unpacked_inputs[2] 217 | 218 | word_embeddings = self.embedding_lookup(input_word_ids) 219 | embedding_tensor = self.embedding_postprocessor( 220 | word_embeddings=word_embeddings, token_type_ids=input_type_ids) 221 | if self.float_type == tf.float16: 222 | embedding_tensor = tf.cast(embedding_tensor, tf.float16) 223 | attention_mask = None 224 | if input_mask is not None: 225 | attention_mask = create_attention_mask_from_input_mask( 226 | input_word_ids, input_mask) 227 | 228 | if mode == "encoder": 229 | return self.encoder( 230 | embedding_tensor, attention_mask, return_all_layers=True) 231 | 232 | sequence_output = self.encoder(embedding_tensor, attention_mask) 233 | first_token_tensor = tf.squeeze(sequence_output[:, 0:1, :], axis=1) 234 | pooled_output = self.pooler_transform(first_token_tensor) 235 | 236 | return (pooled_output, sequence_output) 237 | 238 | def get_config(self): 239 | config = {"config": self.config.to_dict()} 240 | base_config = super(BertModel, self).get_config() 241 | return dict(list(base_config.items()) + list(config.items())) 242 | 243 | 244 | class EmbeddingLookup(tf.keras.layers.Layer): 245 | """Looks up words embeddings for id tensor.""" 246 | 247 | def __init__(self, 248 | vocab_size, 249 | embedding_size=768, 250 | initializer_range=0.02, 251 | **kwargs): 252 | super(EmbeddingLookup, self).__init__(**kwargs) 253 | self.vocab_size = vocab_size 254 | self.embedding_size = embedding_size 255 | self.initializer_range = initializer_range 256 | 257 | def build(self, unused_input_shapes): 258 | """Implements build() for the layer.""" 259 | self.embeddings = self.add_weight( 260 | "embeddings", 261 | shape=[self.vocab_size, self.embedding_size], 262 | initializer=get_initializer(self.initializer_range), 263 | dtype=self.dtype) 264 | super(EmbeddingLookup, self).build(unused_input_shapes) 265 | 266 | def call(self, inputs): 267 | """Implements call() for the layer.""" 268 | input_shape = tf_utils.get_shape_list(inputs) 269 | flat_input = tf.reshape(inputs, [-1]) 270 | output = tf.gather(self.embeddings, flat_input) 271 | output = tf.reshape(output, input_shape + [self.embedding_size]) 272 | return output 273 | 274 | 275 | class EmbeddingPostprocessor(tf.keras.layers.Layer): 276 | """Performs various post-processing on a word embedding tensor.""" 277 | 278 | def __init__(self, 279 | use_type_embeddings=False, 280 | token_type_vocab_size=None, 281 | use_position_embeddings=True, 282 | max_position_embeddings=512, 283 | dropout_prob=0.0, 284 | initializer_range=0.02, 285 | initializer=None, 286 | **kwargs): 287 | super(EmbeddingPostprocessor, self).__init__(**kwargs) 288 | self.use_type_embeddings = use_type_embeddings 289 | self.token_type_vocab_size = token_type_vocab_size 290 | self.use_position_embeddings = use_position_embeddings 291 | self.max_position_embeddings = max_position_embeddings 292 | self.dropout_prob = dropout_prob 293 | self.initializer_range = initializer_range 294 | 295 | if not initializer: 296 | self.initializer = get_initializer(self.initializer_range) 297 | else: 298 | self.initializer = initializer 299 | 300 | if self.use_type_embeddings and not self.token_type_vocab_size: 301 | raise ValueError("If `use_type_embeddings` is True, then " 302 | "`token_type_vocab_size` must be specified.") 303 | 304 | def build(self, input_shapes): 305 | """Implements build() for the layer.""" 
306 |     (word_embeddings_shape, _) = input_shapes
307 |     width = word_embeddings_shape.as_list()[-1]
308 |     self.type_embeddings = None
309 |     if self.use_type_embeddings:
310 |       self.type_embeddings = self.add_weight(
311 |           "type_embeddings",
312 |           shape=[self.token_type_vocab_size, width],
313 |           initializer=get_initializer(self.initializer_range),
314 |           dtype=self.dtype)
315 | 
316 |     self.position_embeddings = None
317 |     if self.use_position_embeddings:
318 |       self.position_embeddings = self.add_weight(
319 |           "position_embeddings",
320 |           shape=[self.max_position_embeddings, width],
321 |           initializer=get_initializer(self.initializer_range),
322 |           dtype=self.dtype)
323 | 
324 |     self.output_layer_norm = tf.keras.layers.LayerNormalization(
325 |         name="layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)
326 |     self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_prob,
327 |                                                   dtype=tf.float32)
328 |     super(EmbeddingPostprocessor, self).build(input_shapes)
329 | 
330 |   def __call__(self, word_embeddings, token_type_ids=None, **kwargs):
331 |     inputs = tf_utils.pack_inputs([word_embeddings, token_type_ids])
332 |     return super(EmbeddingPostprocessor, self).__call__(inputs, **kwargs)
333 | 
334 |   def call(self, inputs, **kwargs):
335 |     """Implements call() for the layer."""
336 |     unpacked_inputs = tf_utils.unpack_inputs(inputs)
337 |     word_embeddings = unpacked_inputs[0]
338 |     token_type_ids = unpacked_inputs[1]
339 |     input_shape = tf_utils.get_shape_list(word_embeddings, expected_rank=3)
340 |     batch_size = input_shape[0]
341 |     seq_length = input_shape[1]
342 |     width = input_shape[2]
343 | 
344 |     output = word_embeddings
345 |     if self.use_type_embeddings:
346 |       flat_token_type_ids = tf.reshape(token_type_ids, [-1])
347 |       one_hot_ids = tf.one_hot(
348 |           flat_token_type_ids,
349 |           depth=self.token_type_vocab_size,
350 |           dtype=self.dtype)
351 |       token_type_embeddings = tf.matmul(one_hot_ids, self.type_embeddings)
352 |       token_type_embeddings = tf.reshape(token_type_embeddings,
353 |                                          [batch_size, seq_length, width])
354 |       output += token_type_embeddings
355 | 
356 |     if self.use_position_embeddings:
357 |       position_embeddings = tf.expand_dims(
358 |           tf.slice(self.position_embeddings, [0, 0], [seq_length, width]),
359 |           axis=0)
360 | 
361 |       output += position_embeddings
362 | 
363 |     output = self.output_layer_norm(output)
364 |     output = self.output_dropout(output, training=kwargs.get('training', False))
365 | 
366 |     return output
367 | 
368 | 
369 | class Attention(tf.keras.layers.Layer):
370 |   """Performs multi-headed attention from `from_tensor` to `to_tensor`.
371 | 
372 |   This is an implementation of multi-headed attention based on "Attention
373 |   Is All You Need". If `from_tensor` and `to_tensor` are the same, then
374 |   this is self-attention. Each timestep in `from_tensor` attends to the
375 |   corresponding sequence in `to_tensor`, and returns a fixed-width vector.
376 | 
377 |   This function first projects `from_tensor` into a "query" tensor and
378 |   `to_tensor` into "key" and "value" tensors. These are (effectively) a list
379 |   of tensors of length `num_attention_heads`, where each tensor is of shape
380 |   [batch_size, seq_length, size_per_head].
381 | 
382 |   Then, the query and key tensors are dot-producted and scaled. These are
383 |   softmaxed to obtain attention probabilities. The value tensors are then
384 |   interpolated by these probabilities, then concatenated back to a single
385 |   tensor and returned.
386 | 
387 |   In practice, the multi-headed attention is done with tf.einsum as follows:
388 |     Input_tensor: [BFD]
389 |     Wq, Wk, Wv: [DNH]
390 |     Q:[BFNH] = einsum('BFD,DNH->BFNH', Input_tensor, Wq)
391 |     K:[BTNH] = einsum('BTD,DNH->BTNH', Input_tensor, Wk)
392 |     V:[BTNH] = einsum('BTD,DNH->BTNH', Input_tensor, Wv)
393 |     attention_scores:[BNFT] = einsum('BTNH,BFNH->BNFT', K, Q) / sqrt(H)
394 |     attention_probs:[BNFT] = softmax(attention_scores)
395 |     context_layer:[BFNH] = einsum('BNFT,BTNH->BFNH', attention_probs, V)
396 |     Wout:[DNH]
397 |     Output:[BFD] = einsum('BFNH,DNH->BFD', context_layer, Wout)
398 |   """
399 | 
400 |   def __init__(self,
401 |                num_attention_heads=12,
402 |                size_per_head=64,
403 |                attention_probs_dropout_prob=0.0,
404 |                initializer_range=0.02,
405 |                backward_compatible=False,
406 |                **kwargs):
407 |     super(Attention, self).__init__(**kwargs)
408 |     self.num_attention_heads = num_attention_heads
409 |     self.size_per_head = size_per_head
410 |     self.attention_probs_dropout_prob = attention_probs_dropout_prob
411 |     self.initializer_range = initializer_range
412 |     self.backward_compatible = backward_compatible
413 | 
414 |   def build(self, unused_input_shapes):
415 |     """Implements build() for the layer."""
416 |     self.query_dense = self._projection_dense_layer("query")
417 |     self.key_dense = self._projection_dense_layer("key")
418 |     self.value_dense = self._projection_dense_layer("value")
419 |     self.attention_probs_dropout = tf.keras.layers.Dropout(
420 |         rate=self.attention_probs_dropout_prob)
421 |     super(Attention, self).build(unused_input_shapes)
422 | 
423 |   def reshape_to_matrix(self, input_tensor):
424 |     """Reshape N > 2 rank tensor to rank 2 tensor for performance."""
425 |     ndims = input_tensor.shape.ndims
426 |     if ndims < 2:
427 |       raise ValueError("Input tensor must have at least rank 2. "
428 |                        "Shape = %s" % (input_tensor.shape))
429 |     if ndims == 2:
430 |       return input_tensor
431 | 
432 |     width = input_tensor.shape[-1]
433 |     output_tensor = tf.reshape(input_tensor, [-1, width])
434 |     return output_tensor
435 | 
436 |   def __call__(self, from_tensor, to_tensor, attention_mask=None, **kwargs):
437 |     inputs = tf_utils.pack_inputs([from_tensor, to_tensor, attention_mask])
438 |     return super(Attention, self).__call__(inputs, **kwargs)
439 | 
440 |   def call(self, inputs, **kwargs):
441 |     """Implements call() for the layer."""
442 |     (from_tensor, to_tensor, attention_mask) = tf_utils.unpack_inputs(inputs)
443 | 
444 |     # Scalar dimensions referenced here:
445 |     # B = batch size (number of sequences)
446 |     # F = `from_tensor` sequence length
447 |     # T = `to_tensor` sequence length
448 |     # N = `num_attention_heads`
449 |     # H = `size_per_head`
450 |     # `query_tensor` = [B, F, N, H]
451 |     query_tensor = self.query_dense(from_tensor)
452 | 
453 |     # `key_tensor` = [B, T, N, H]
454 |     key_tensor = self.key_dense(to_tensor)
455 | 
456 |     # `value_tensor` = [B, T, N, H]
457 |     value_tensor = self.value_dense(to_tensor)
458 | 
459 |     # Take the dot product between "query" and "key" to get the raw
460 |     # attention scores.
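    # `attention_scores` = [B, N, F, T], scaled below by 1 / sqrt(size_per_head)
    # as in scaled dot-product attention ("Attention Is All You Need").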
461 | attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor) 462 | attention_scores = tf.multiply(attention_scores, 463 | 1.0 / math.sqrt(float(self.size_per_head))) 464 | 465 | if attention_mask is not None: 466 | # `attention_mask` = [B, 1, F, T] 467 | attention_mask = tf.expand_dims(attention_mask, axis=[1]) 468 | 469 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 470 | # masked positions, this operation will create a tensor which is 0.0 for 471 | # positions we want to attend and -10000.0 for masked positions. 472 | adder = (1.0 - tf.cast(attention_mask, attention_scores.dtype)) * -10000.0 473 | 474 | # Since we are adding it to the raw scores before the softmax, this is 475 | # effectively the same as removing these entirely. 476 | attention_scores += adder 477 | 478 | # Normalize the attention scores to probabilities. 479 | # `attention_probs` = [B, N, F, T] 480 | attention_probs = tf.nn.softmax(attention_scores) 481 | 482 | # This is actually dropping out entire tokens to attend to, which might 483 | # seem a bit unusual, but is taken from the original Transformer paper. 484 | attention_probs = self.attention_probs_dropout(attention_probs,training=kwargs.get('training', False)) 485 | 486 | # `context_layer` = [B, F, N, H] 487 | context_tensor = tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_tensor) 488 | 489 | return context_tensor 490 | 491 | def _projection_dense_layer(self, name): 492 | """A helper to define a projection layer.""" 493 | return Dense3D( 494 | num_attention_heads=self.num_attention_heads, 495 | size_per_head=self.size_per_head, 496 | kernel_initializer=get_initializer(self.initializer_range), 497 | output_projection=False, 498 | backward_compatible=self.backward_compatible, 499 | name=name) 500 | 501 | 502 | class Dense3D(tf.keras.layers.Layer): 503 | """A Dense Layer using 3D kernel with tf.einsum implementation. 504 | 505 | Attributes: 506 | num_attention_heads: An integer, number of attention heads for each 507 | multihead attention layer. 508 | size_per_head: An integer, hidden size per attention head. 509 | hidden_size: An integer, dimension of the hidden layer. 510 | kernel_initializer: An initializer for the kernel weight. 511 | bias_initializer: An initializer for the bias. 512 | activation: An activation function to use. If nothing is specified, no 513 | activation is applied. 514 | use_bias: A bool, whether the layer uses a bias. 515 | output_projection: A bool, whether the Dense3D layer is used for output 516 | linear projection. 517 | backward_compatible: A bool, whether the variables shape are compatible 518 | with checkpoints converted from TF 1.x. 
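  Example usage (a minimal sketch; the shapes and argument values below are
  purely illustrative):

  ```python
  layer = Dense3D(num_attention_heads=12, size_per_head=64)
  # Projects [batch, seq_length, hidden] -> [batch, seq_length, heads, size_per_head].
  out = layer(tf.ones([2, 8, 768]))  # -> shape [2, 8, 12, 64]
  ```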
519 | """ 520 | 521 | def __init__(self, 522 | num_attention_heads=12, 523 | size_per_head=72, 524 | kernel_initializer=None, 525 | bias_initializer="zeros", 526 | activation=None, 527 | use_bias=True, 528 | output_projection=False, 529 | backward_compatible=False, 530 | **kwargs): 531 | """Inits Dense3D.""" 532 | super(Dense3D, self).__init__(**kwargs) 533 | self.num_attention_heads = num_attention_heads 534 | self.size_per_head = size_per_head 535 | self.hidden_size = num_attention_heads * size_per_head 536 | self.kernel_initializer = kernel_initializer 537 | self.bias_initializer = bias_initializer 538 | self.activation = activation 539 | self.use_bias = use_bias 540 | self.output_projection = output_projection 541 | self.backward_compatible = backward_compatible 542 | 543 | @property 544 | def compatible_kernel_shape(self): 545 | if self.output_projection: 546 | return [self.hidden_size, self.hidden_size] 547 | return [self.last_dim, self.hidden_size] 548 | 549 | @property 550 | def compatible_bias_shape(self): 551 | return [self.hidden_size] 552 | 553 | @property 554 | def kernel_shape(self): 555 | if self.output_projection: 556 | return [self.num_attention_heads, self.size_per_head, self.hidden_size] 557 | return [self.last_dim, self.num_attention_heads, self.size_per_head] 558 | 559 | @property 560 | def bias_shape(self): 561 | if self.output_projection: 562 | return [self.hidden_size] 563 | return [self.num_attention_heads, self.size_per_head] 564 | 565 | def build(self, input_shape): 566 | """Implements build() for the layer.""" 567 | dtype = tf.as_dtype(self.dtype or tf.keras.backend.floatx()) 568 | if not (dtype.is_floating or dtype.is_complex): 569 | raise TypeError("Unable to build `Dense3D` layer with non-floating " 570 | "point (and non-complex) dtype %s" % (dtype,)) 571 | input_shape = tf.TensorShape(input_shape) 572 | if tf.compat.dimension_value(input_shape[-1]) is None: 573 | raise ValueError("The last dimension of the inputs to `Dense3D` " 574 | "should be defined. Found `None`.") 575 | self.last_dim = tf.compat.dimension_value(input_shape[-1]) 576 | self.input_spec = tf.keras.layers.InputSpec( 577 | min_ndim=3, axes={-1: self.last_dim}) 578 | # Determines variable shapes. 579 | if self.backward_compatible: 580 | kernel_shape = self.compatible_kernel_shape 581 | bias_shape = self.compatible_bias_shape 582 | else: 583 | kernel_shape = self.kernel_shape 584 | bias_shape = self.bias_shape 585 | 586 | self.kernel = self.add_weight( 587 | "kernel", 588 | shape=kernel_shape, 589 | initializer=self.kernel_initializer, 590 | dtype=self.dtype, 591 | trainable=True) 592 | if self.use_bias: 593 | self.bias = self.add_weight( 594 | "bias", 595 | shape=bias_shape, 596 | initializer=self.bias_initializer, 597 | dtype=self.dtype, 598 | trainable=True) 599 | else: 600 | self.bias = None 601 | super(Dense3D, self).build(input_shape) 602 | 603 | def call(self, inputs): 604 | """Implements ``call()`` for Dense3D. 605 | 606 | Args: 607 | inputs: A float tensor of shape [batch_size, sequence_length, hidden_size] 608 | when output_projection is False, otherwise a float tensor of shape 609 | [batch_size, sequence_length, num_heads, dim_per_head]. 610 | 611 | Returns: 612 | The projected tensor with shape [batch_size, sequence_length, num_heads, 613 | dim_per_head] when output_projection is False, otherwise [batch_size, 614 | sequence_length, hidden_size]. 
615 | """ 616 | if self.backward_compatible: 617 | kernel = tf.keras.backend.reshape(self.kernel, self.kernel_shape) 618 | bias = (tf.keras.backend.reshape(self.bias, self.bias_shape) 619 | if self.use_bias else None) 620 | else: 621 | kernel = self.kernel 622 | bias = self.bias 623 | 624 | if self.output_projection: 625 | ret = tf.einsum("abcd,cde->abe", inputs, kernel) 626 | else: 627 | ret = tf.einsum("abc,cde->abde", inputs, kernel) 628 | if self.use_bias: 629 | ret += bias 630 | if self.activation is not None: 631 | return self.activation(ret) 632 | return ret 633 | 634 | 635 | class Dense2DProjection(tf.keras.layers.Layer): 636 | """A 2D projection layer with tf.einsum implementation.""" 637 | 638 | def __init__(self, 639 | output_size, 640 | kernel_initializer=None, 641 | bias_initializer="zeros", 642 | activation=None, 643 | fp32_activation=False, 644 | **kwargs): 645 | super(Dense2DProjection, self).__init__(**kwargs) 646 | self.output_size = output_size 647 | self.kernel_initializer = kernel_initializer 648 | self.bias_initializer = bias_initializer 649 | self.activation = activation 650 | self.fp32_activation = fp32_activation 651 | 652 | def build(self, input_shape): 653 | """Implements build() for the layer.""" 654 | dtype = tf.as_dtype(self.dtype or tf.keras.backend.floatx()) 655 | if not (dtype.is_floating or dtype.is_complex): 656 | raise TypeError("Unable to build `Dense2DProjection` layer with " 657 | "non-floating point (and non-complex) " 658 | "dtype %s" % (dtype,)) 659 | input_shape = tf.TensorShape(input_shape) 660 | if tf.compat.dimension_value(input_shape[-1]) is None: 661 | raise ValueError("The last dimension of the inputs to " 662 | "`Dense2DProjection` should be defined. " 663 | "Found `None`.") 664 | last_dim = tf.compat.dimension_value(input_shape[-1]) 665 | self.input_spec = tf.keras.layers.InputSpec(min_ndim=3, axes={-1: last_dim}) 666 | self.kernel = self.add_weight( 667 | "kernel", 668 | shape=[last_dim, self.output_size], 669 | initializer=self.kernel_initializer, 670 | dtype=self.dtype, 671 | trainable=True) 672 | self.bias = self.add_weight( 673 | "bias", 674 | shape=[self.output_size], 675 | initializer=self.bias_initializer, 676 | dtype=self.dtype, 677 | trainable=True) 678 | super(Dense2DProjection, self).build(input_shape) 679 | 680 | def call(self, inputs): 681 | """Implements call() for Dense2DProjection. 682 | 683 | Args: 684 | inputs: float Tensor of shape [batch, from_seq_length, 685 | num_attention_heads, size_per_head]. 686 | 687 | Returns: 688 | A 3D Tensor. 689 | """ 690 | ret = tf.einsum("abc,cd->abd", inputs, self.kernel) 691 | ret += self.bias 692 | if self.activation is not None: 693 | if self.dtype == tf.float16 and self.fp32_activation: 694 | ret = tf.cast(ret, tf.float32) 695 | return self.activation(ret) 696 | return ret 697 | 698 | 699 | class TransformerBlock(tf.keras.layers.Layer): 700 | """Single transformer layer. 701 | 702 | It has two sub-layers. The first is a multi-head self-attention mechanism, and 703 | the second is a positionwise fully connected feed-forward network. 
704 | """ 705 | 706 | def __init__(self, 707 | hidden_size=768, 708 | num_attention_heads=12, 709 | intermediate_size=3072, 710 | intermediate_activation="gelu", 711 | hidden_dropout_prob=0.0, 712 | attention_probs_dropout_prob=0.0, 713 | initializer_range=0.02, 714 | backward_compatible=False, 715 | float_type=tf.float32, 716 | **kwargs): 717 | super(TransformerBlock, self).__init__(**kwargs) 718 | self.hidden_size = hidden_size 719 | self.num_attention_heads = num_attention_heads 720 | self.intermediate_size = intermediate_size 721 | self.intermediate_activation = tf_utils.get_activation( 722 | intermediate_activation) 723 | self.hidden_dropout_prob = hidden_dropout_prob 724 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 725 | self.initializer_range = initializer_range 726 | self.backward_compatible = backward_compatible 727 | self.float_type = float_type 728 | 729 | if self.hidden_size % self.num_attention_heads != 0: 730 | raise ValueError( 731 | "The hidden size (%d) is not a multiple of the number of attention " 732 | "heads (%d)" % (self.hidden_size, self.num_attention_heads)) 733 | self.attention_head_size = int(self.hidden_size / self.num_attention_heads) 734 | 735 | def build(self, unused_input_shapes): 736 | """Implements build() for the layer.""" 737 | self.attention_layer = Attention( 738 | num_attention_heads=self.num_attention_heads, 739 | size_per_head=self.attention_head_size, 740 | attention_probs_dropout_prob=self.attention_probs_dropout_prob, 741 | initializer_range=self.initializer_range, 742 | backward_compatible=self.backward_compatible, 743 | name="self_attention") 744 | self.attention_output_dense = Dense3D( 745 | num_attention_heads=self.num_attention_heads, 746 | size_per_head=int(self.hidden_size / self.num_attention_heads), 747 | kernel_initializer=get_initializer(self.initializer_range), 748 | output_projection=True, 749 | backward_compatible=self.backward_compatible, 750 | name="self_attention_output") 751 | self.attention_dropout = tf.keras.layers.Dropout( 752 | rate=self.hidden_dropout_prob) 753 | self.attention_layer_norm = ( 754 | tf.keras.layers.LayerNormalization( 755 | name="self_attention_layer_norm", axis=-1, epsilon=1e-12, 756 | # We do layer norm in float32 for numeric stability. 757 | dtype=tf.float32)) 758 | self.intermediate_dense = Dense2DProjection( 759 | output_size=self.intermediate_size, 760 | kernel_initializer=get_initializer(self.initializer_range), 761 | activation=self.intermediate_activation, 762 | # Uses float32 so that gelu activation is done in float32. 
763 | fp32_activation=True, 764 | name="intermediate") 765 | self.output_dense = Dense2DProjection( 766 | output_size=self.hidden_size, 767 | kernel_initializer=get_initializer(self.initializer_range), 768 | name="output") 769 | self.output_dropout = tf.keras.layers.Dropout(rate=self.hidden_dropout_prob) 770 | self.output_layer_norm = tf.keras.layers.LayerNormalization( 771 | name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32) 772 | super(TransformerBlock, self).build(unused_input_shapes) 773 | 774 | def common_layers(self): 775 | """Explicitly gets all layer objects inside a Transformer encoder block.""" 776 | return [ 777 | self.attention_layer, self.attention_output_dense, 778 | self.attention_dropout, self.attention_layer_norm, 779 | self.intermediate_dense, self.output_dense, self.output_dropout, 780 | self.output_layer_norm 781 | ] 782 | 783 | def __call__(self, input_tensor, attention_mask=None, **kwargs): 784 | inputs = tf_utils.pack_inputs([input_tensor, attention_mask]) 785 | return super(TransformerBlock, self).__call__(inputs, **kwargs) 786 | 787 | def call(self, inputs, **kwargs): 788 | """Implements call() for the layer.""" 789 | (input_tensor, attention_mask) = tf_utils.unpack_inputs(inputs) 790 | attention_output = self.attention_layer( 791 | from_tensor=input_tensor, 792 | to_tensor=input_tensor, 793 | attention_mask=attention_mask,**kwargs) 794 | attention_output = self.attention_output_dense(attention_output) 795 | attention_output = self.attention_dropout(attention_output,training=kwargs.get('training', False)) 796 | # Use float32 in keras layer norm and the gelu activation in the 797 | # intermediate dense layer for numeric stability 798 | attention_output = self.attention_layer_norm(input_tensor + 799 | attention_output) 800 | if self.float_type == tf.float16: 801 | attention_output = tf.cast(attention_output, tf.float16) 802 | intermediate_output = self.intermediate_dense(attention_output) 803 | if self.float_type == tf.float16: 804 | intermediate_output = tf.cast(intermediate_output, tf.float16) 805 | layer_output = self.output_dense(intermediate_output) 806 | layer_output = self.output_dropout(layer_output,training=kwargs.get('training', False)) 807 | # Use float32 in keras layer norm for numeric stability 808 | layer_output = self.output_layer_norm(layer_output + attention_output) 809 | if self.float_type == tf.float16: 810 | layer_output = tf.cast(layer_output, tf.float16) 811 | return layer_output 812 | 813 | 814 | class Transformer(tf.keras.layers.Layer): 815 | """Multi-headed, multi-layer Transformer from "Attention is All You Need". 816 | 817 | This is almost an exact implementation of the original Transformer encoder. 
818 | 819 | See the original paper: 820 | https://arxiv.org/abs/1706.03762 821 | 822 | Also see: 823 | https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py 824 | """ 825 | 826 | def __init__(self, 827 | num_hidden_layers=12, 828 | hidden_size=768, 829 | num_attention_heads=12, 830 | intermediate_size=3072, 831 | intermediate_activation="gelu", 832 | hidden_dropout_prob=0.0, 833 | attention_probs_dropout_prob=0.0, 834 | initializer_range=0.02, 835 | backward_compatible=False, 836 | float_type=tf.float32, 837 | **kwargs): 838 | super(Transformer, self).__init__(**kwargs) 839 | self.num_hidden_layers = num_hidden_layers 840 | self.hidden_size = hidden_size 841 | self.num_attention_heads = num_attention_heads 842 | self.intermediate_size = intermediate_size 843 | self.intermediate_activation = tf_utils.get_activation( 844 | intermediate_activation) 845 | self.hidden_dropout_prob = hidden_dropout_prob 846 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 847 | self.initializer_range = initializer_range 848 | self.backward_compatible = backward_compatible 849 | self.float_type = float_type 850 | 851 | def build(self, unused_input_shapes): 852 | """Implements build() for the layer.""" 853 | self.layers = [] 854 | for i in range(self.num_hidden_layers): 855 | self.layers.append( 856 | TransformerBlock( 857 | hidden_size=self.hidden_size, 858 | num_attention_heads=self.num_attention_heads, 859 | intermediate_size=self.intermediate_size, 860 | intermediate_activation=self.intermediate_activation, 861 | hidden_dropout_prob=self.hidden_dropout_prob, 862 | attention_probs_dropout_prob=self.attention_probs_dropout_prob, 863 | initializer_range=self.initializer_range, 864 | backward_compatible=self.backward_compatible, 865 | float_type=self.float_type, 866 | name=("layer_%d" % i))) 867 | super(Transformer, self).build(unused_input_shapes) 868 | 869 | def __call__(self, input_tensor, attention_mask=None, **kwargs): 870 | inputs = tf_utils.pack_inputs([input_tensor, attention_mask]) 871 | return super(Transformer, self).__call__(inputs=inputs, **kwargs) 872 | 873 | def call(self, inputs, return_all_layers=False, **kwargs): 874 | """Implements call() for the layer. 875 | 876 | Args: 877 | inputs: packed inputs. 878 | return_all_layers: bool, whether to return outputs of all layers inside 879 | encoders. 880 | Returns: 881 | Output tensor of the last layer or a list of output tensors. 882 | """ 883 | unpacked_inputs = tf_utils.unpack_inputs(inputs) 884 | input_tensor = unpacked_inputs[0] 885 | attention_mask = unpacked_inputs[1] 886 | output_tensor = input_tensor 887 | 888 | all_layer_outputs = [] 889 | for layer in self.layers: 890 | output_tensor = layer(output_tensor, attention_mask,**kwargs) 891 | all_layer_outputs.append(output_tensor) 892 | 893 | if return_all_layers: 894 | return all_layer_outputs 895 | 896 | return all_layer_outputs[-1] 897 | 898 | 899 | def get_initializer(initializer_range=0.02): 900 | """Creates a `tf.initializers.truncated_normal` with the given range. 901 | 902 | Args: 903 | initializer_range: float, initializer range for stddev. 904 | 905 | Returns: 906 | TruncatedNormal initializer with stddev = `initializer_range`. 907 | """ 908 | return tf.keras.initializers.TruncatedNormal(stddev=initializer_range) 909 | 910 | 911 | def create_attention_mask_from_input_mask(from_tensor, to_mask): 912 | """Create 3D attention mask from a 2D tensor mask. 
913 | 914 | Args: 915 | from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...]. 916 | to_mask: int32 Tensor of shape [batch_size, to_seq_length]. 917 | 918 | Returns: 919 | float Tensor of shape [batch_size, from_seq_length, to_seq_length]. 920 | """ 921 | from_shape = tf_utils.get_shape_list(from_tensor, expected_rank=[2, 3]) 922 | batch_size = from_shape[0] 923 | from_seq_length = from_shape[1] 924 | 925 | to_shape = tf_utils.get_shape_list(to_mask, expected_rank=2) 926 | to_seq_length = to_shape[1] 927 | 928 | to_mask = tf.cast( 929 | tf.reshape(to_mask, [batch_size, 1, to_seq_length]), 930 | dtype=from_tensor.dtype) 931 | 932 | # We don't assume that `from_tensor` is a mask (although it could be). We 933 | # don't actually care if we attend *from* padding tokens (only *to* padding) 934 | # tokens so we create a tensor of all ones. 935 | # 936 | # `broadcast_ones` = [batch_size, from_seq_length, 1] 937 | broadcast_ones = tf.ones( 938 | shape=[batch_size, from_seq_length, 1], dtype=from_tensor.dtype) 939 | 940 | # Here we broadcast along two dimensions to create the mask. 941 | mask = broadcast_ones * to_mask 942 | 943 | return mask 944 | --------------------------------------------------------------------------------