├── albert_config └── albert_config_large.json ├── README.md ├── bert_utils.py ├── optimization.py ├── tokenization.py ├── run_classifier.py ├── run_squad.py └── modeling.py /albert_config/albert_config_large.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.0, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.0, 6 | "hidden_size": 1024, 7 | "embedding_size": 128, 8 | "initializer_range": 0.02, 9 | "intermediate_size": 4096, 10 | "max_position_embeddings": 512, 11 | "num_attention_heads": 16, 12 | "num_hidden_layers": 24, 13 | 14 | "pooler_fc_size": 768, 15 | "pooler_num_attention_heads": 12, 16 | "pooler_num_fc_layers": 3, 17 | "pooler_size_per_head": 128, 18 | "pooler_type": "first_token_transform", 19 | "type_vocab_size": 2, 20 | "vocab_size": 47473 21 | } 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # KalBert 2 | Korean ALBERT (A Lite BERT for Self-supervised Learning of Language Representations) language model 3 | 4 | Training based on albert_zh (https://github.com/brightmart/albert_zh) 5 | 6 | 512 sequences, Large KalBert: 7 | https://drive.google.com/drive/folders/1a_yZIidugit3TxF__f8LSRPc8gfO2CV-?usp=sharing 8 | 9 | * Training data: ~6GB 10 | - Korean Wiki 11 | - KAIST Book corpus 12 | - Saejong corpus 13 | 14 | * Morph tokenizing without tag + BPE 15 | - (e.g. 이순신은 조선 중기의 무신이다. -> 이순신 은 조선 중기 의 무신 이 다 .) 16 | 17 | * Training steps: 191,000 (128 batch size) 18 | 19 | * KorQuAD v 1.0 Dev set 20 | - f1: 90.01, em: 81.26 21 | 22 | -------------------------------------------------------------------------------- /bert_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import collections 6 | import copy 7 | import json 8 | import math 9 | import re 10 | import six 11 | import tensorflow as tf 12 | 13 | def get_shape_list(tensor, expected_rank=None, name=None): 14 | """Returns a list of the shape of tensor, preferring static dimensions. 15 | 16 | Args: 17 | tensor: A tf.Tensor object to find the shape of. 18 | expected_rank: (optional) int. The expected rank of `tensor`. If this is 19 | specified and the `tensor` has a different rank, and exception will be 20 | thrown. 21 | name: Optional name of the tensor for the error message. 22 | 23 | Returns: 24 | A list of dimensions of the shape of tensor. All static dimensions will 25 | be returned as python integers, and dynamic dimensions will be returned 26 | as tf.Tensor scalars. 27 | """ 28 | if name is None: 29 | name = tensor.name 30 | 31 | if expected_rank is not None: 32 | assert_rank(tensor, expected_rank, name) 33 | 34 | shape = tensor.shape.as_list() 35 | 36 | non_static_indexes = [] 37 | for (index, dim) in enumerate(shape): 38 | if dim is None: 39 | non_static_indexes.append(index) 40 | 41 | if not non_static_indexes: 42 | return shape 43 | 44 | dyn_shape = tf.shape(tensor) 45 | for index in non_static_indexes: 46 | shape[index] = dyn_shape[index] 47 | return shape 48 | 49 | def reshape_to_matrix(input_tensor): 50 | """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix).""" 51 | ndims = input_tensor.shape.ndims 52 | if ndims < 2: 53 | raise ValueError("Input tensor must have at least rank 2. 
Shape = %s" % 54 | (input_tensor.shape)) 55 | if ndims == 2: 56 | return input_tensor 57 | 58 | width = input_tensor.shape[-1] 59 | output_tensor = tf.reshape(input_tensor, [-1, width]) 60 | return output_tensor 61 | 62 | def reshape_from_matrix(output_tensor, orig_shape_list): 63 | """Reshapes a rank 2 tensor back to its original rank >= 2 tensor.""" 64 | if len(orig_shape_list) == 2: 65 | return output_tensor 66 | 67 | output_shape = get_shape_list(output_tensor) 68 | 69 | orig_dims = orig_shape_list[0:-1] 70 | width = output_shape[-1] 71 | 72 | return tf.reshape(output_tensor, orig_dims + [width]) 73 | 74 | def assert_rank(tensor, expected_rank, name=None): 75 | """Raises an exception if the tensor rank is not of the expected rank. 76 | 77 | Args: 78 | tensor: A tf.Tensor to check the rank of. 79 | expected_rank: Python integer or list of integers, expected rank. 80 | name: Optional name of the tensor for the error message. 81 | 82 | Raises: 83 | ValueError: If the expected shape doesn't match the actual shape. 84 | """ 85 | if name is None: 86 | name = tensor.name 87 | 88 | expected_rank_dict = {} 89 | if isinstance(expected_rank, six.integer_types): 90 | expected_rank_dict[expected_rank] = True 91 | else: 92 | for x in expected_rank: 93 | expected_rank_dict[x] = True 94 | 95 | actual_rank = tensor.shape.ndims 96 | if actual_rank not in expected_rank_dict: 97 | scope_name = tf.get_variable_scope().name 98 | raise ValueError( 99 | "For the tensor `%s` in scope `%s`, the actual rank " 100 | "`%d` (shape = %s) is not equal to the expected rank `%s`" % 101 | (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank))) 102 | 103 | def gather_indexes(sequence_tensor, positions): 104 | """Gathers the vectors at the specific positions over a minibatch.""" 105 | sequence_shape = get_shape_list(sequence_tensor, expected_rank=3) 106 | batch_size = sequence_shape[0] 107 | seq_length = sequence_shape[1] 108 | width = sequence_shape[2] 109 | 110 | flat_offsets = tf.reshape( 111 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) 112 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) 113 | flat_sequence_tensor = tf.reshape(sequence_tensor, 114 | [batch_size * seq_length, width]) 115 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) 116 | return output_tensor 117 | 118 | # add sequence mask for: 119 | # 1. random shuffle lm modeling---xlnet with random shuffled input 120 | # 2. left2right and right2left language modeling 121 | # 3. 
conditional generation 122 | def generate_seq2seq_mask(attention_mask, mask_sequence, seq_type, **kargs): 123 | if seq_type == 'seq2seq': 124 | if mask_sequence is not None: 125 | seq_shape = get_shape_list(mask_sequence, expected_rank=2) 126 | seq_len = seq_shape[1] 127 | ones = tf.ones((1, seq_len, seq_len)) 128 | a_mask = tf.matrix_band_part(ones, -1, 0) 129 | s_ex12 = tf.expand_dims(tf.expand_dims(mask_sequence, 1), 2) 130 | s_ex13 = tf.expand_dims(tf.expand_dims(mask_sequence, 1), 3) 131 | a_mask = (1 - s_ex13) * (1 - s_ex12) + s_ex13 * a_mask 132 | # generate mask of batch x seq_len x seq_len 133 | a_mask = tf.reshape(a_mask, (-1, seq_len, seq_len)) 134 | out_mask = attention_mask * a_mask 135 | else: 136 | ones = tf.ones_like(attention_mask[:1]) 137 | mask = (tf.matrix_band_part(ones, -1, 0)) 138 | out_mask = attention_mask * mask 139 | else: 140 | out_mask = attention_mask 141 | 142 | return out_mask 143 | -------------------------------------------------------------------------------- /optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 
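  # Illustrative schedule, assuming init_lr=5e-5, num_warmup_steps=1000 and
  # num_train_steps=10000: at step 500 the warmup branch yields
  # 0.5 * 5e-5 = 2.5e-5; from step 1000 onward the polynomial-decay branch
  # applies, falling linearly from 4.5e-5 at step 1000 to 0.0 at step 10000.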
59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | # Normally the global step update is done inside of `apply_gradients`. 80 | # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use 81 | # a different optimizer, you should probably take this line out. 82 | new_global_step = global_step + 1 83 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 84 | return train_op 85 | 86 | 87 | class AdamWeightDecayOptimizer(tf.compat.v1.train.Optimizer): 88 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 89 | 90 | def __init__(self, 91 | learning_rate, 92 | weight_decay_rate=0.0, 93 | beta_1=0.9, 94 | beta_2=0.999, 95 | epsilon=1e-6, 96 | exclude_from_weight_decay=None, 97 | name="AdamWeightDecayOptimizer"): 98 | """Constructs a AdamWeightDecayOptimizer.""" 99 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 100 | 101 | self.learning_rate = learning_rate 102 | self.weight_decay_rate = weight_decay_rate 103 | self.beta_1 = beta_1 104 | self.beta_2 = beta_2 105 | self.epsilon = epsilon 106 | self.exclude_from_weight_decay = exclude_from_weight_decay 107 | 108 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 109 | """See base class.""" 110 | assignments = [] 111 | for (grad, param) in grads_and_vars: 112 | if grad is None or param is None: 113 | continue 114 | 115 | param_name = self._get_variable_name(param.name) 116 | 117 | m = tf.get_variable( 118 | name=param_name + "/adam_m", 119 | shape=param.shape.as_list(), 120 | dtype=tf.float32, 121 | trainable=False, 122 | initializer=tf.zeros_initializer()) 123 | v = tf.get_variable( 124 | name=param_name + "/adam_v", 125 | shape=param.shape.as_list(), 126 | dtype=tf.float32, 127 | trainable=False, 128 | initializer=tf.zeros_initializer()) 129 | 130 | # Standard Adam update. 131 | next_m = ( 132 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 133 | next_v = ( 134 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 135 | tf.square(grad))) 136 | 137 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 138 | 139 | # Just adding the square of the weights to the loss function is *not* 140 | # the correct way of using L2 regularization/weight decay with Adam, 141 | # since that will interact with the m and v parameters in strange ways. 142 | # 143 | # Instead we want ot decay the weights in a manner that doesn't interact 144 | # with the m/v parameters. This is equivalent to adding the square 145 | # of the weights to the loss with plain (non-momentum) SGD. 
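      # Net update applied below (decoupled weight decay, no Adam bias
      # correction):  param <- param - lr * (m / (sqrt(v) + eps) + wd * param),
      # where wd is zeroed for any variable matching `exclude_from_weight_decay`.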
146 | if self._do_use_weight_decay(param_name): 147 | update += self.weight_decay_rate * param 148 | 149 | update_with_lr = self.learning_rate * update 150 | 151 | next_param = param - update_with_lr 152 | 153 | assignments.extend( 154 | [param.assign(next_param), 155 | m.assign(next_m), 156 | v.assign(next_v)]) 157 | return tf.group(*assignments, name=name) 158 | 159 | def _do_use_weight_decay(self, param_name): 160 | """Whether to use L2 weight decay for `param_name`.""" 161 | if not self.weight_decay_rate: 162 | return False 163 | if self.exclude_from_weight_decay: 164 | for r in self.exclude_from_weight_decay: 165 | if re.search(r, param_name) is not None: 166 | return False 167 | return True 168 | 169 | def _get_variable_name(self, param_name): 170 | """Get the variable name from the tensor name.""" 171 | m = re.match("^(.*):\\d+$", param_name) 172 | if m is not None: 173 | param_name = m.group(1) 174 | return param_name 175 | -------------------------------------------------------------------------------- /tokenization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding:utf-8 -*- 3 | 4 | # coding=utf-8 5 | # Copyright 2018 The Google AI Language Team Authors. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | """Tokenization classes.""" 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import collections 25 | import unicodedata 26 | import six 27 | import tensorflow as tf 28 | 29 | 30 | def convert_to_unicode(text): 31 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 32 | if six.PY3: 33 | if isinstance(text, str): 34 | return text 35 | elif isinstance(text, bytes): 36 | return text.decode("utf-8", "ignore") 37 | else: 38 | raise ValueError("Unsupported string type: %s" % (type(text))) 39 | elif six.PY2: 40 | if isinstance(text, str): 41 | return text.decode("utf-8", "ignore") 42 | elif isinstance(text, unicode): 43 | return text 44 | else: 45 | raise ValueError("Unsupported string type: %s" % (type(text))) 46 | else: 47 | raise ValueError("Not running on Python2 or Python 3?") 48 | 49 | 50 | def printable_text(text): 51 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 52 | 53 | # These functions want `str` for both Python2 and Python3, but in one case 54 | # it's a Unicode string and in the other it's a byte string. 
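  # Under Python 3 this behaves exactly like `convert_to_unicode`; only under
  # Python 2 does it differ, returning a UTF-8 encoded byte string, which is
  # what `print` and `tf.logging` expect there.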
55 | if six.PY3: 56 | if isinstance(text, str): 57 | return text 58 | elif isinstance(text, bytes): 59 | return text.decode("utf-8", "ignore") 60 | else: 61 | raise ValueError("Unsupported string type: %s" % (type(text))) 62 | elif six.PY2: 63 | if isinstance(text, str): 64 | return text 65 | elif isinstance(text, unicode): 66 | return text.encode("utf-8") 67 | else: 68 | raise ValueError("Unsupported string type: %s" % (type(text))) 69 | else: 70 | raise ValueError("Not running on Python2 or Python 3?") 71 | 72 | 73 | def load_vocab(vocab_file): 74 | """Loads a vocabulary file into a dictionary.""" 75 | vocab = collections.OrderedDict() 76 | index = 0 77 | with open(vocab_file, 'r', encoding='utf-8') as reader: 78 | #with tf.gfile.GFile(vocab_file, "r") as reader: 79 | while True: 80 | token = convert_to_unicode(reader.readline()) 81 | if not token: 82 | break 83 | token = token.strip() 84 | vocab[token] = index 85 | index += 1 86 | return vocab 87 | 88 | 89 | def convert_by_vocab(vocab, items): 90 | """Converts a sequence of [tokens|ids] using the vocab.""" 91 | output = [] 92 | for item in items: 93 | output.append(vocab[item]) 94 | return output 95 | 96 | 97 | def convert_tokens_to_ids(vocab, tokens): 98 | return convert_by_vocab(vocab, tokens) 99 | 100 | 101 | def convert_ids_to_tokens(inv_vocab, ids): 102 | return convert_by_vocab(inv_vocab, ids) 103 | 104 | 105 | def whitespace_tokenize(text): 106 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 107 | text = text.strip() 108 | if not text: 109 | return [] 110 | tokens = text.split() 111 | return tokens 112 | 113 | 114 | class FullTokenizer(object): 115 | """Runs end-to-end tokenziation.""" 116 | 117 | def __init__(self, vocab_file, do_lower_case=True): 118 | self.vocab = load_vocab(vocab_file) 119 | #for key, value in self.vocab.items(): 120 | # print(str(key) + ' - ' + str(type(key))) 121 | #print("vocab: " + str(self.vocab)) 122 | #print(self.vocab['이순신']) 123 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 124 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 125 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 126 | 127 | def tokenize(self, text): 128 | split_tokens = [] 129 | for token in self.basic_tokenizer.tokenize(text): 130 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 131 | split_tokens.append(sub_token) 132 | 133 | return split_tokens 134 | 135 | def convert_tokens_to_ids(self, tokens): 136 | return convert_by_vocab(self.vocab, tokens) 137 | 138 | def convert_ids_to_tokens(self, ids): 139 | return convert_by_vocab(self.inv_vocab, ids) 140 | 141 | 142 | class BasicTokenizer(object): 143 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 144 | 145 | def __init__(self, do_lower_case=True): 146 | """Constructs a BasicTokenizer. 147 | 148 | Args: 149 | do_lower_case: Whether to lower case the input. 150 | """ 151 | self.do_lower_case = do_lower_case 152 | 153 | def tokenize(self, text): 154 | """Tokenizes a piece of text.""" 155 | text = convert_to_unicode(text) 156 | text = self._clean_text(text) 157 | 158 | # This was added on November 1st, 2018 for the multilingual and Chinese 159 | # models. 
This is also applied to the English models now, but it doesn't 160 | # matter since the English models were not trained on any Chinese data 161 | # and generally don't have any Chinese data in them (there are Chinese 162 | # characters in the vocabulary because Wikipedia does have some Chinese 163 | # words in the English Wikipedia.). 164 | 165 | #text = self._tokenize_chinese_chars(text) 166 | orig_tokens = whitespace_tokenize(text) 167 | #split_tokens = [] 168 | #for token in orig_tokens: 169 | # if self.do_lower_case: 170 | # token = token.lower() 171 | # token = self._run_strip_accents(token) 172 | # split_tokens.extend(self._run_split_on_punc(token)) 173 | 174 | #output_tokens = whitespace_tokenize(" ".join(split_tokens)) 175 | #return output_tokens 176 | return orig_tokens 177 | 178 | def _run_strip_accents(self, text): 179 | """Strips accents from a piece of text.""" 180 | text = unicodedata.normalize("NFD", text) 181 | output = [] 182 | for char in text: 183 | cat = unicodedata.category(char) 184 | if cat == "Mn": 185 | continue 186 | output.append(char) 187 | return "".join(output) 188 | 189 | def _run_split_on_punc(self, text): 190 | """Splits punctuation on a piece of text.""" 191 | chars = list(text) 192 | i = 0 193 | start_new_word = True 194 | output = [] 195 | while i < len(chars): 196 | char = chars[i] 197 | if _is_punctuation(char): 198 | output.append([char]) 199 | start_new_word = True 200 | else: 201 | if start_new_word: 202 | output.append([]) 203 | start_new_word = False 204 | output[-1].append(char) 205 | i += 1 206 | 207 | return ["".join(x) for x in output] 208 | 209 | def _tokenize_chinese_chars(self, text): 210 | """Adds whitespace around any CJK character.""" 211 | output = [] 212 | for char in text: 213 | cp = ord(char) 214 | if self._is_chinese_char(cp): 215 | output.append(" ") 216 | output.append(char) 217 | output.append(" ") 218 | else: 219 | output.append(char) 220 | return "".join(output) 221 | 222 | def _is_chinese_char(self, cp): 223 | """Checks whether CP is the codepoint of a CJK character.""" 224 | # This defines a "chinese character" as anything in the CJK Unicode block: 225 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 226 | # 227 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 228 | # despite its name. The modern Korean Hangul alphabet is a different block, 229 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 230 | # space-separated words, so they are not treated specially and handled 231 | # like the all of the other languages. 
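    # Korean text never matches this check: Hangul Syllables (U+AC00-U+D7A3)
    # and Hangul Jamo (U+1100-U+11FF) lie outside the CJK ranges below, and
    # `_tokenize_chinese_chars` is commented out in `tokenize()` above, so
    # Korean tokens are handled by whitespace splitting plus WordPiece only.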
232 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 233 | (cp >= 0x3400 and cp <= 0x4DBF) or # 234 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 235 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 236 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 237 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 238 | (cp >= 0xF900 and cp <= 0xFAFF) or # 239 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 240 | return True 241 | 242 | return False 243 | 244 | def _clean_text(self, text): 245 | """Performs invalid character removal and whitespace cleanup on text.""" 246 | output = [] 247 | for char in text: 248 | cp = ord(char) 249 | if cp == 0 or cp == 0xfffd or _is_control(char): 250 | continue 251 | if _is_whitespace(char): 252 | output.append(" ") 253 | else: 254 | output.append(char) 255 | return "".join(output) 256 | 257 | 258 | class WordpieceTokenizer(object): 259 | """Runs WordPiece tokenziation.""" 260 | 261 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 262 | self.vocab = vocab 263 | #print(vocab) 264 | self.unk_token = unk_token 265 | self.max_input_chars_per_word = max_input_chars_per_word 266 | 267 | def tokenize(self, text): 268 | """Tokenizes a piece of text into its word pieces. 269 | 270 | This uses a greedy longest-match-first algorithm to perform tokenization 271 | using the given vocabulary. 272 | 273 | For example: 274 | input = "unaffable" 275 | output = ["un", "##aff", "##able"] 276 | 277 | Args: 278 | text: A single token or whitespace separated tokens. This should have 279 | already been passed through `BasicTokenizer. 280 | 281 | Returns: 282 | A list of wordpiece tokens. 283 | """ 284 | 285 | #text = convert_to_unicode(text) 286 | 287 | output_tokens = [] 288 | for token in whitespace_tokenize(text): 289 | #print("current token: " + token) 290 | #print(type(token)) 291 | #token = token.encode('utf-8') 292 | #print(self.vocab['이순신']) 293 | #print(self.vocab[token.replace(' ','')]) 294 | #if token in self.vocab: 295 | # print("----------in vocab") 296 | chars = list(token) 297 | if len(chars) > self.max_input_chars_per_word: 298 | output_tokens.append(self.unk_token) 299 | continue 300 | 301 | is_bad = False 302 | start = 0 303 | sub_tokens = [] 304 | while start < len(chars): 305 | end = len(chars) 306 | cur_substr = None 307 | while start < end: 308 | substr = "".join(chars[start:end]) 309 | if start > 0: 310 | substr = "##" + substr 311 | if substr in self.vocab: 312 | cur_substr = substr 313 | break 314 | end -= 1 315 | if cur_substr is None: 316 | is_bad = True 317 | break 318 | sub_tokens.append(cur_substr) 319 | start = end 320 | 321 | if is_bad: 322 | output_tokens.append(self.unk_token) 323 | else: 324 | output_tokens.extend(sub_tokens) 325 | return output_tokens 326 | 327 | 328 | def _is_whitespace(char): 329 | """Checks whether `chars` is a whitespace character.""" 330 | # \t, \n, and \r are technically contorl characters but we treat them 331 | # as whitespace since they are generally considered as such. 332 | if char == " " or char == "\t" or char == "\n" or char == "\r": 333 | return True 334 | cat = unicodedata.category(char) 335 | if cat == "Zs": 336 | return True 337 | return False 338 | 339 | 340 | def _is_control(char): 341 | """Checks whether `chars` is a control character.""" 342 | # These are technically control characters but we count them as whitespace 343 | # characters. 
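  # \t, \n and \r have Unicode category "Cc", but `_is_whitespace` above
  # already treats them as whitespace, so they are explicitly excluded here.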
344 | if char == "\t" or char == "\n" or char == "\r": 345 | return False 346 | cat = unicodedata.category(char) 347 | if cat.startswith("C"): 348 | return True 349 | return False 350 | 351 | 352 | def _is_punctuation(char): 353 | """Checks whether `chars` is a punctuation character.""" 354 | cp = ord(char) 355 | # We treat all non-letter/number ASCII as punctuation. 356 | # Characters such as "^", "$", and "`" are not in the Unicode 357 | # Punctuation class but we treat them as punctuation anyways, for 358 | # consistency. 359 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 360 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 361 | return True 362 | cat = unicodedata.category(char) 363 | if cat.startswith("P"): 364 | return True 365 | return False 366 | -------------------------------------------------------------------------------- /run_classifier.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """BERT finetuning runner.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import csv 23 | import os 24 | import modeling 25 | import optimization 26 | import tokenization 27 | import tensorflow as tf 28 | 29 | flags = tf.compat.v1.app.flags 30 | 31 | FLAGS = flags.FLAGS 32 | 33 | ## Required parameters 34 | flags.DEFINE_string( 35 | "data_dir", None, 36 | "The input data dir. Should contain the .tsv files (or other data files) " 37 | "for the task.") 38 | 39 | flags.DEFINE_string( 40 | "bert_config_file", None, 41 | "The config json file corresponding to the pre-trained BERT model. " 42 | "This specifies the model architecture.") 43 | 44 | flags.DEFINE_string("task_name", None, "The name of the task to train.") 45 | 46 | flags.DEFINE_string("vocab_file", None, 47 | "The vocabulary file that the BERT model was trained on.") 48 | 49 | flags.DEFINE_string( 50 | "output_dir", None, 51 | "The output directory where the model checkpoints will be written.") 52 | 53 | ## Other parameters 54 | 55 | flags.DEFINE_string( 56 | "init_checkpoint", None, 57 | "Initial checkpoint (usually from a pre-trained BERT model).") 58 | 59 | flags.DEFINE_bool( 60 | "do_lower_case", True, 61 | "Whether to lower case the input text. Should be True for uncased " 62 | "models and False for cased models.") 63 | 64 | flags.DEFINE_integer( 65 | "max_seq_length", 128, 66 | "The maximum total input sequence length after WordPiece tokenization. 
" 67 | "Sequences longer than this will be truncated, and sequences shorter " 68 | "than this will be padded.") 69 | 70 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 71 | 72 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 73 | 74 | flags.DEFINE_bool( 75 | "do_predict", False, 76 | "Whether to run the model in inference mode on the test set.") 77 | 78 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 79 | 80 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") 81 | 82 | flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.") 83 | 84 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 85 | 86 | flags.DEFINE_float("num_train_epochs", 3.0, 87 | "Total number of training epochs to perform.") 88 | 89 | flags.DEFINE_float( 90 | "warmup_proportion", 0.1, 91 | "Proportion of training to perform linear learning rate warmup for. " 92 | "E.g., 0.1 = 10% of training.") 93 | 94 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 95 | "How often to save the model checkpoint.") 96 | 97 | flags.DEFINE_integer("iterations_per_loop", 1000, 98 | "How many steps to make in each estimator call.") 99 | 100 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 101 | 102 | flags.DEFINE_string( 103 | "tpu_name", None, 104 | "The Cloud TPU to use for training. This should be either the name " 105 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 106 | "url.") 107 | 108 | flags.DEFINE_string( 109 | "tpu_zone", None, 110 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 111 | "specified, we will attempt to automatically detect the GCE project from " 112 | "metadata.") 113 | 114 | flags.DEFINE_string( 115 | "gcp_project", None, 116 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 117 | "specified, we will attempt to automatically detect the GCE project from " 118 | "metadata.") 119 | 120 | flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 121 | 122 | flags.DEFINE_integer( 123 | "num_tpu_cores", 8, 124 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 125 | 126 | 127 | class InputExample(object): 128 | """A single training/test example for simple sequence classification.""" 129 | 130 | def __init__(self, guid, text_a, text_b=None, label=None): 131 | """Constructs a InputExample. 132 | 133 | Args: 134 | guid: Unique id for the example. 135 | text_a: string. The untokenized text of the first sequence. For single 136 | sequence tasks, only this sequence must be specified. 137 | text_b: (Optional) string. The untokenized text of the second sequence. 138 | Only must be specified for sequence pair tasks. 139 | label: (Optional) string. The label of the example. This should be 140 | specified for train and dev examples, but not for test examples. 141 | """ 142 | self.guid = guid 143 | self.text_a = text_a 144 | self.text_b = text_b 145 | self.label = label 146 | 147 | 148 | class PaddingInputExample(object): 149 | """Fake example so the num input examples is a multiple of the batch size. 150 | 151 | When running eval/predict on the TPU, we need to pad the number of examples 152 | to be a multiple of the batch size, because the TPU requires a fixed batch 153 | size. The alternative is to drop the last batch, which is bad because it means 154 | the entire output data won't be generated. 
155 | 156 | We use this class instead of `None` because treating `None` as padding 157 | battches could cause silent errors. 158 | """ 159 | 160 | 161 | class InputFeatures(object): 162 | """A single set of features of data.""" 163 | 164 | def __init__(self, 165 | input_ids, 166 | input_mask, 167 | segment_ids, 168 | label_id, 169 | is_real_example=True): 170 | self.input_ids = input_ids 171 | self.input_mask = input_mask 172 | self.segment_ids = segment_ids 173 | self.label_id = label_id 174 | self.is_real_example = is_real_example 175 | 176 | 177 | class DataProcessor(object): 178 | """Base class for data converters for sequence classification data sets.""" 179 | 180 | def get_train_examples(self, data_dir): 181 | """Gets a collection of `InputExample`s for the train set.""" 182 | raise NotImplementedError() 183 | 184 | def get_dev_examples(self, data_dir): 185 | """Gets a collection of `InputExample`s for the dev set.""" 186 | raise NotImplementedError() 187 | 188 | def get_test_examples(self, data_dir): 189 | """Gets a collection of `InputExample`s for prediction.""" 190 | raise NotImplementedError() 191 | 192 | def get_labels(self): 193 | """Gets the list of labels for this data set.""" 194 | raise NotImplementedError() 195 | 196 | @classmethod 197 | def _read_tsv(cls, input_file, quotechar=None): 198 | """Reads a tab separated value file.""" 199 | with tf.gfile.Open(input_file, "r") as f: 200 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 201 | lines = [] 202 | for line in reader: 203 | lines.append(line) 204 | return lines 205 | 206 | 207 | class XnliProcessor(DataProcessor): 208 | """Processor for the XNLI data set.""" 209 | 210 | def __init__(self): 211 | self.language = "zh" 212 | 213 | def get_train_examples(self, data_dir): 214 | """See base class.""" 215 | lines = self._read_tsv( 216 | os.path.join(data_dir, "multinli", 217 | "multinli.train.%s.tsv" % self.language)) 218 | examples = [] 219 | for (i, line) in enumerate(lines): 220 | if i == 0: 221 | continue 222 | guid = "train-%d" % (i) 223 | text_a = tokenization.convert_to_unicode(line[0]) 224 | text_b = tokenization.convert_to_unicode(line[1]) 225 | label = tokenization.convert_to_unicode(line[2]) 226 | if label == tokenization.convert_to_unicode("contradictory"): 227 | label = tokenization.convert_to_unicode("contradiction") 228 | examples.append( 229 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 230 | return examples 231 | 232 | def get_dev_examples(self, data_dir): 233 | """See base class.""" 234 | lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv")) 235 | examples = [] 236 | for (i, line) in enumerate(lines): 237 | if i == 0: 238 | continue 239 | guid = "dev-%d" % (i) 240 | language = tokenization.convert_to_unicode(line[0]) 241 | if language != tokenization.convert_to_unicode(self.language): 242 | continue 243 | text_a = tokenization.convert_to_unicode(line[6]) 244 | text_b = tokenization.convert_to_unicode(line[7]) 245 | label = tokenization.convert_to_unicode(line[1]) 246 | examples.append( 247 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 248 | return examples 249 | 250 | def get_labels(self): 251 | """See base class.""" 252 | return ["contradiction", "entailment", "neutral"] 253 | 254 | 255 | class MnliProcessor(DataProcessor): 256 | """Processor for the MultiNLI data set (GLUE version).""" 257 | 258 | def get_train_examples(self, data_dir): 259 | """See base class.""" 260 | return self._create_examples( 261 | 
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 262 | 263 | def get_dev_examples(self, data_dir): 264 | """See base class.""" 265 | return self._create_examples( 266 | self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), 267 | "dev_matched") 268 | 269 | def get_test_examples(self, data_dir): 270 | """See base class.""" 271 | return self._create_examples( 272 | self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test") 273 | 274 | def get_labels(self): 275 | """See base class.""" 276 | return ["contradiction", "entailment", "neutral"] 277 | 278 | def _create_examples(self, lines, set_type): 279 | """Creates examples for the training and dev sets.""" 280 | examples = [] 281 | for (i, line) in enumerate(lines): 282 | if i == 0: 283 | continue 284 | guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0])) 285 | text_a = tokenization.convert_to_unicode(line[8]) 286 | text_b = tokenization.convert_to_unicode(line[9]) 287 | if set_type == "test": 288 | label = "contradiction" 289 | else: 290 | label = tokenization.convert_to_unicode(line[-1]) 291 | examples.append( 292 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 293 | return examples 294 | 295 | 296 | class MrpcProcessor(DataProcessor): 297 | """Processor for the MRPC data set (GLUE version).""" 298 | 299 | def get_train_examples(self, data_dir): 300 | """See base class.""" 301 | return self._create_examples( 302 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 303 | 304 | def get_dev_examples(self, data_dir): 305 | """See base class.""" 306 | return self._create_examples( 307 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 308 | 309 | def get_test_examples(self, data_dir): 310 | """See base class.""" 311 | return self._create_examples( 312 | self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 313 | 314 | def get_labels(self): 315 | """See base class.""" 316 | return ["0", "1"] 317 | 318 | def _create_examples(self, lines, set_type): 319 | """Creates examples for the training and dev sets.""" 320 | examples = [] 321 | for (i, line) in enumerate(lines): 322 | if i == 0: 323 | continue 324 | guid = "%s-%s" % (set_type, i) 325 | text_a = tokenization.convert_to_unicode(line[3]) 326 | text_b = tokenization.convert_to_unicode(line[4]) 327 | if set_type == "test": 328 | label = "0" 329 | else: 330 | label = tokenization.convert_to_unicode(line[0]) 331 | examples.append( 332 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 333 | return examples 334 | 335 | 336 | class ColaProcessor(DataProcessor): 337 | """Processor for the CoLA data set (GLUE version).""" 338 | 339 | def get_train_examples(self, data_dir): 340 | """See base class.""" 341 | return self._create_examples( 342 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") 343 | 344 | def get_dev_examples(self, data_dir): 345 | """See base class.""" 346 | return self._create_examples( 347 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") 348 | 349 | def get_test_examples(self, data_dir): 350 | """See base class.""" 351 | return self._create_examples( 352 | self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") 353 | 354 | def get_labels(self): 355 | """See base class.""" 356 | return ["0", "1"] 357 | 358 | def _create_examples(self, lines, set_type): 359 | """Creates examples for the training and dev sets.""" 360 | examples = [] 361 | for (i, line) in enumerate(lines): 362 | # Only the test set has a header 363 | if set_type == "test" and i == 0: 364 | 
continue 365 | guid = "%s-%s" % (set_type, i) 366 | if set_type == "test": 367 | text_a = tokenization.convert_to_unicode(line[1]) 368 | label = "0" 369 | else: 370 | text_a = tokenization.convert_to_unicode(line[3]) 371 | label = tokenization.convert_to_unicode(line[1]) 372 | examples.append( 373 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 374 | return examples 375 | 376 | 377 | def convert_single_example(ex_index, example, label_list, max_seq_length, 378 | tokenizer): 379 | """Converts a single `InputExample` into a single `InputFeatures`.""" 380 | 381 | if isinstance(example, PaddingInputExample): 382 | return InputFeatures( 383 | input_ids=[0] * max_seq_length, 384 | input_mask=[0] * max_seq_length, 385 | segment_ids=[0] * max_seq_length, 386 | label_id=0, 387 | is_real_example=False) 388 | 389 | label_map = {} 390 | for (i, label) in enumerate(label_list): 391 | label_map[label] = i 392 | 393 | tokens_a = tokenizer.tokenize(example.text_a) 394 | tokens_b = None 395 | if example.text_b: 396 | tokens_b = tokenizer.tokenize(example.text_b) 397 | 398 | if tokens_b: 399 | # Modifies `tokens_a` and `tokens_b` in place so that the total 400 | # length is less than the specified length. 401 | # Account for [CLS], [SEP], [SEP] with "- 3" 402 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 403 | else: 404 | # Account for [CLS] and [SEP] with "- 2" 405 | if len(tokens_a) > max_seq_length - 2: 406 | tokens_a = tokens_a[0:(max_seq_length - 2)] 407 | 408 | # The convention in BERT is: 409 | # (a) For sequence pairs: 410 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 411 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 412 | # (b) For single sequences: 413 | # tokens: [CLS] the dog is hairy . [SEP] 414 | # type_ids: 0 0 0 0 0 0 0 415 | # 416 | # Where "type_ids" are used to indicate whether this is the first 417 | # sequence or the second sequence. The embedding vectors for `type=0` and 418 | # `type=1` were learned during pre-training and are added to the wordpiece 419 | # embedding vector (and position vector). This is not *strictly* necessary 420 | # since the [SEP] token unambiguously separates the sequences, but it makes 421 | # it easier for the model to learn the concept of sequences. 422 | # 423 | # For classification tasks, the first vector (corresponding to [CLS]) is 424 | # used as the "sentence vector". Note that this only makes sense because 425 | # the entire model is fine-tuned. 426 | tokens = [] 427 | segment_ids = [] 428 | tokens.append("[CLS]") 429 | segment_ids.append(0) 430 | for token in tokens_a: 431 | tokens.append(token) 432 | segment_ids.append(0) 433 | tokens.append("[SEP]") 434 | segment_ids.append(0) 435 | 436 | if tokens_b: 437 | for token in tokens_b: 438 | tokens.append(token) 439 | segment_ids.append(1) 440 | tokens.append("[SEP]") 441 | segment_ids.append(1) 442 | 443 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 444 | 445 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 446 | # tokens are attended to. 447 | input_mask = [1] * len(input_ids) 448 | 449 | # Zero-pad up to the sequence length. 
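  # e.g. with max_seq_length=8 and tokens ["[CLS]", "이순신", "은", "[SEP]"]
  # the loop below appends four zeros, giving
  #   input_mask  = [1, 1, 1, 1, 0, 0, 0, 0]
  #   segment_ids = [0, 0, 0, 0, 0, 0, 0, 0]
  # and pads input_ids with the id 0 to the same length.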
450 | while len(input_ids) < max_seq_length: 451 | input_ids.append(0) 452 | input_mask.append(0) 453 | segment_ids.append(0) 454 | 455 | assert len(input_ids) == max_seq_length 456 | assert len(input_mask) == max_seq_length 457 | assert len(segment_ids) == max_seq_length 458 | 459 | label_id = label_map[example.label] 460 | if ex_index < 5: 461 | tf.logging.info("*** Example ***") 462 | tf.logging.info("guid: %s" % (example.guid)) 463 | tf.logging.info("tokens: %s" % " ".join( 464 | [tokenization.printable_text(x) for x in tokens])) 465 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 466 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 467 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 468 | tf.logging.info("label: %s (id = %d)" % (example.label, label_id)) 469 | 470 | feature = InputFeatures( 471 | input_ids=input_ids, 472 | input_mask=input_mask, 473 | segment_ids=segment_ids, 474 | label_id=label_id, 475 | is_real_example=True) 476 | return feature 477 | 478 | 479 | def file_based_convert_examples_to_features( 480 | examples, label_list, max_seq_length, tokenizer, output_file): 481 | """Convert a set of `InputExample`s to a TFRecord file.""" 482 | 483 | writer = tf.python_io.TFRecordWriter(output_file) 484 | 485 | for (ex_index, example) in enumerate(examples): 486 | if ex_index % 10000 == 0: 487 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 488 | 489 | feature = convert_single_example(ex_index, example, label_list, 490 | max_seq_length, tokenizer) 491 | 492 | def create_int_feature(values): 493 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 494 | return f 495 | 496 | features = collections.OrderedDict() 497 | features["input_ids"] = create_int_feature(feature.input_ids) 498 | features["input_mask"] = create_int_feature(feature.input_mask) 499 | features["segment_ids"] = create_int_feature(feature.segment_ids) 500 | features["label_ids"] = create_int_feature([feature.label_id]) 501 | features["is_real_example"] = create_int_feature( 502 | [int(feature.is_real_example)]) 503 | 504 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 505 | writer.write(tf_example.SerializeToString()) 506 | writer.close() 507 | 508 | 509 | def file_based_input_fn_builder(input_file, seq_length, is_training, 510 | drop_remainder): 511 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 512 | 513 | name_to_features = { 514 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64), 515 | "input_mask": tf.FixedLenFeature([seq_length], tf.int64), 516 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64), 517 | "label_ids": tf.FixedLenFeature([], tf.int64), 518 | "is_real_example": tf.FixedLenFeature([], tf.int64), 519 | } 520 | 521 | def _decode_record(record, name_to_features): 522 | """Decodes a record to a TensorFlow example.""" 523 | example = tf.parse_single_example(record, name_to_features) 524 | 525 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 526 | # So cast all int64 to int32. 527 | for name in list(example.keys()): 528 | t = example[name] 529 | if t.dtype == tf.int64: 530 | t = tf.to_int32(t) 531 | example[name] = t 532 | 533 | return example 534 | 535 | def input_fn(params): 536 | """The actual input function.""" 537 | batch_size = params["batch_size"] 538 | 539 | # For training, we want a lot of parallel reading and shuffling. 
540 | # For eval, we want no shuffling and parallel reading doesn't matter. 541 | d = tf.data.TFRecordDataset(input_file) 542 | if is_training: 543 | d = d.repeat() 544 | d = d.shuffle(buffer_size=100) 545 | 546 | d = d.apply( 547 | tf.contrib.data.map_and_batch( 548 | lambda record: _decode_record(record, name_to_features), 549 | batch_size=batch_size, 550 | drop_remainder=drop_remainder)) 551 | 552 | return d 553 | 554 | return input_fn 555 | 556 | 557 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 558 | """Truncates a sequence pair in place to the maximum length.""" 559 | 560 | # This is a simple heuristic which will always truncate the longer sequence 561 | # one token at a time. This makes more sense than truncating an equal percent 562 | # of tokens from each, since if one sequence is very short then each token 563 | # that's truncated likely contains more information than a longer sequence. 564 | while True: 565 | total_length = len(tokens_a) + len(tokens_b) 566 | if total_length <= max_length: 567 | break 568 | if len(tokens_a) > len(tokens_b): 569 | tokens_a.pop() 570 | else: 571 | tokens_b.pop() 572 | 573 | 574 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, 575 | labels, num_labels, use_one_hot_embeddings): 576 | """Creates a classification model.""" 577 | model = modeling.BertModel( 578 | config=bert_config, 579 | is_training=is_training, 580 | input_ids=input_ids, 581 | input_mask=input_mask, 582 | token_type_ids=segment_ids, 583 | use_one_hot_embeddings=use_one_hot_embeddings) 584 | 585 | # In the demo, we are doing a simple classification task on the entire 586 | # segment. 587 | # 588 | # If you want to use the token-level output, use model.get_sequence_output() 589 | # instead. 590 | output_layer = model.get_pooled_output() 591 | 592 | hidden_size = output_layer.shape[-1].value 593 | 594 | output_weights = tf.get_variable( 595 | "output_weights", [num_labels, hidden_size], 596 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 597 | 598 | output_bias = tf.get_variable( 599 | "output_bias", [num_labels], initializer=tf.zeros_initializer()) 600 | 601 | with tf.variable_scope("loss"): 602 | if is_training: 603 | # I.e., 0.1 dropout 604 | output_layer = tf.nn.dropout(output_layer, keep_prob=0.9) 605 | 606 | logits = tf.matmul(output_layer, output_weights, transpose_b=True) 607 | logits = tf.nn.bias_add(logits, output_bias) 608 | probabilities = tf.nn.softmax(logits, axis=-1) 609 | log_probs = tf.nn.log_softmax(logits, axis=-1) 610 | 611 | one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) 612 | 613 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 614 | loss = tf.reduce_mean(per_example_loss) 615 | 616 | return (loss, per_example_loss, logits, probabilities) 617 | 618 | 619 | def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate, 620 | num_train_steps, num_warmup_steps, use_tpu, 621 | use_one_hot_embeddings): 622 | """Returns `model_fn` closure for TPUEstimator.""" 623 | 624 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 625 | """The `model_fn` for TPUEstimator.""" 626 | 627 | tf.logging.info("*** Features ***") 628 | for name in sorted(features.keys()): 629 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 630 | 631 | input_ids = features["input_ids"] 632 | input_mask = features["input_mask"] 633 | segment_ids = features["segment_ids"] 634 | label_ids = features["label_ids"] 635 | is_real_example = 
None 636 | if "is_real_example" in features: 637 | is_real_example = tf.cast(features["is_real_example"], dtype=tf.float32) 638 | else: 639 | is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32) 640 | 641 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 642 | 643 | (total_loss, per_example_loss, logits, probabilities) = create_model( 644 | bert_config, is_training, input_ids, input_mask, segment_ids, label_ids, 645 | num_labels, use_one_hot_embeddings) 646 | 647 | tvars = tf.trainable_variables() 648 | initialized_variable_names = {} 649 | scaffold_fn = None 650 | if init_checkpoint: 651 | (assignment_map, initialized_variable_names 652 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 653 | if use_tpu: 654 | 655 | def tpu_scaffold(): 656 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 657 | return tf.train.Scaffold() 658 | 659 | scaffold_fn = tpu_scaffold 660 | else: 661 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 662 | 663 | tf.logging.info("**** Trainable Variables ****") 664 | for var in tvars: 665 | init_string = "" 666 | if var.name in initialized_variable_names: 667 | init_string = ", *INIT_FROM_CKPT*" 668 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 669 | init_string) 670 | 671 | output_spec = None 672 | if mode == tf.estimator.ModeKeys.TRAIN: 673 | 674 | train_op = optimization.create_optimizer( 675 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 676 | 677 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 678 | mode=mode, 679 | loss=total_loss, 680 | train_op=train_op, 681 | scaffold_fn=scaffold_fn) 682 | elif mode == tf.estimator.ModeKeys.EVAL: 683 | 684 | def metric_fn(per_example_loss, label_ids, logits, is_real_example): 685 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) 686 | accuracy = tf.metrics.accuracy( 687 | labels=label_ids, predictions=predictions, weights=is_real_example) 688 | loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example) 689 | return { 690 | "eval_accuracy": accuracy, 691 | "eval_loss": loss, 692 | } 693 | 694 | eval_metrics = (metric_fn, 695 | [per_example_loss, label_ids, logits, is_real_example]) 696 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 697 | mode=mode, 698 | loss=total_loss, 699 | eval_metrics=eval_metrics, 700 | scaffold_fn=scaffold_fn) 701 | else: 702 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 703 | mode=mode, 704 | predictions={"probabilities": probabilities}, 705 | scaffold_fn=scaffold_fn) 706 | return output_spec 707 | 708 | return model_fn 709 | 710 | 711 | # This function is not used by this file but is still used by the Colab and 712 | # people who depend on it. 713 | def input_fn_builder(features, seq_length, is_training, drop_remainder): 714 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 715 | 716 | all_input_ids = [] 717 | all_input_mask = [] 718 | all_segment_ids = [] 719 | all_label_ids = [] 720 | 721 | for feature in features: 722 | all_input_ids.append(feature.input_ids) 723 | all_input_mask.append(feature.input_mask) 724 | all_segment_ids.append(feature.segment_ids) 725 | all_label_ids.append(feature.label_id) 726 | 727 | def input_fn(params): 728 | """The actual input function.""" 729 | batch_size = params["batch_size"] 730 | 731 | num_examples = len(features) 732 | 733 | # This is for demo purposes and does NOT scale to large data sets. 
We do 734 | # not use Dataset.from_generator() because that uses tf.py_func which is 735 | # not TPU compatible. The right way to load data is with TFRecordReader. 736 | d = tf.data.Dataset.from_tensor_slices({ 737 | "input_ids": 738 | tf.constant( 739 | all_input_ids, shape=[num_examples, seq_length], 740 | dtype=tf.int32), 741 | "input_mask": 742 | tf.constant( 743 | all_input_mask, 744 | shape=[num_examples, seq_length], 745 | dtype=tf.int32), 746 | "segment_ids": 747 | tf.constant( 748 | all_segment_ids, 749 | shape=[num_examples, seq_length], 750 | dtype=tf.int32), 751 | "label_ids": 752 | tf.constant(all_label_ids, shape=[num_examples], dtype=tf.int32), 753 | }) 754 | 755 | if is_training: 756 | d = d.repeat() 757 | d = d.shuffle(buffer_size=100) 758 | 759 | d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder) 760 | return d 761 | 762 | return input_fn 763 | 764 | 765 | # This function is not used by this file but is still used by the Colab and 766 | # people who depend on it. 767 | def convert_examples_to_features(examples, label_list, max_seq_length, 768 | tokenizer): 769 | """Convert a set of `InputExample`s to a list of `InputFeatures`.""" 770 | 771 | features = [] 772 | for (ex_index, example) in enumerate(examples): 773 | if ex_index % 10000 == 0: 774 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) 775 | 776 | feature = convert_single_example(ex_index, example, label_list, 777 | max_seq_length, tokenizer) 778 | 779 | features.append(feature) 780 | return features 781 | 782 | 783 | def main(_): 784 | tf.logging.set_verbosity(tf.logging.INFO) 785 | 786 | processors = { 787 | "cola": ColaProcessor, 788 | "mnli": MnliProcessor, 789 | "mrpc": MrpcProcessor, 790 | "xnli": XnliProcessor, 791 | } 792 | 793 | tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case, 794 | FLAGS.init_checkpoint) 795 | 796 | if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict: 797 | raise ValueError( 798 | "At least one of `do_train`, `do_eval` or `do_predict' must be True.") 799 | 800 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 801 | 802 | if FLAGS.max_seq_length > bert_config.max_position_embeddings: 803 | raise ValueError( 804 | "Cannot use sequence length %d because the BERT model " 805 | "was only trained up to sequence length %d" % 806 | (FLAGS.max_seq_length, bert_config.max_position_embeddings)) 807 | 808 | tf.gfile.MakeDirs(FLAGS.output_dir) 809 | 810 | task_name = FLAGS.task_name.lower() 811 | 812 | if task_name not in processors: 813 | raise ValueError("Task not found: %s" % (task_name)) 814 | 815 | processor = processors[task_name]() 816 | 817 | label_list = processor.get_labels() 818 | 819 | tokenizer = tokenization.FullTokenizer( 820 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 821 | 822 | tpu_cluster_resolver = None 823 | if FLAGS.use_tpu and FLAGS.tpu_name: 824 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 825 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 826 | 827 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 828 | run_config = tf.contrib.tpu.RunConfig( 829 | cluster=tpu_cluster_resolver, 830 | master=FLAGS.master, 831 | model_dir=FLAGS.output_dir, 832 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 833 | tpu_config=tf.contrib.tpu.TPUConfig( 834 | iterations_per_loop=FLAGS.iterations_per_loop, 835 | num_shards=FLAGS.num_tpu_cores, 836 | per_host_input_for_training=is_per_host)) 837 | 838 | train_examples 
= None 839 | num_train_steps = None 840 | num_warmup_steps = None 841 | if FLAGS.do_train: 842 | train_examples = processor.get_train_examples(FLAGS.data_dir) 843 | num_train_steps = int( 844 | len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs) 845 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 846 | 847 | model_fn = model_fn_builder( 848 | bert_config=bert_config, 849 | num_labels=len(label_list), 850 | init_checkpoint=FLAGS.init_checkpoint, 851 | learning_rate=FLAGS.learning_rate, 852 | num_train_steps=num_train_steps, 853 | num_warmup_steps=num_warmup_steps, 854 | use_tpu=FLAGS.use_tpu, 855 | use_one_hot_embeddings=FLAGS.use_tpu) 856 | 857 | # If TPU is not available, this will fall back to normal Estimator on CPU 858 | # or GPU. 859 | estimator = tf.contrib.tpu.TPUEstimator( 860 | use_tpu=FLAGS.use_tpu, 861 | model_fn=model_fn, 862 | config=run_config, 863 | train_batch_size=FLAGS.train_batch_size, 864 | eval_batch_size=FLAGS.eval_batch_size, 865 | predict_batch_size=FLAGS.predict_batch_size) 866 | 867 | if FLAGS.do_train: 868 | train_file = os.path.join(FLAGS.output_dir, "train.tf_record") 869 | file_based_convert_examples_to_features( 870 | train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file) 871 | tf.logging.info("***** Running training *****") 872 | tf.logging.info(" Num examples = %d", len(train_examples)) 873 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 874 | tf.logging.info(" Num steps = %d", num_train_steps) 875 | train_input_fn = file_based_input_fn_builder( 876 | input_file=train_file, 877 | seq_length=FLAGS.max_seq_length, 878 | is_training=True, 879 | drop_remainder=True) 880 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) 881 | 882 | if FLAGS.do_eval: 883 | eval_examples = processor.get_dev_examples(FLAGS.data_dir) 884 | num_actual_eval_examples = len(eval_examples) 885 | if FLAGS.use_tpu: 886 | # TPU requires a fixed batch size for all batches, therefore the number 887 | # of examples must be a multiple of the batch size, or else examples 888 | # will get dropped. So we pad with fake examples which are ignored 889 | # later on. These do NOT count towards the metric (all tf.metrics 890 | # support a per-instance weight, and these get a weight of 0.0). 891 | while len(eval_examples) % FLAGS.eval_batch_size != 0: 892 | eval_examples.append(PaddingInputExample()) 893 | 894 | eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record") 895 | file_based_convert_examples_to_features( 896 | eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file) 897 | 898 | tf.logging.info("***** Running evaluation *****") 899 | tf.logging.info(" Num examples = %d (%d actual, %d padding)", 900 | len(eval_examples), num_actual_eval_examples, 901 | len(eval_examples) - num_actual_eval_examples) 902 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 903 | 904 | # This tells the estimator to run through the entire set. 905 | eval_steps = None 906 | # However, if running eval on the TPU, you will need to specify the 907 | # number of steps. 
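    # e.g. with eval_batch_size=8 and 1,000 eval examples (already padded to a
    # multiple of 8 above), eval_steps = 1000 // 8 = 125.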
908 | if FLAGS.use_tpu: 909 | assert len(eval_examples) % FLAGS.eval_batch_size == 0 910 | eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size) 911 | 912 | eval_drop_remainder = True if FLAGS.use_tpu else False 913 | eval_input_fn = file_based_input_fn_builder( 914 | input_file=eval_file, 915 | seq_length=FLAGS.max_seq_length, 916 | is_training=False, 917 | drop_remainder=eval_drop_remainder) 918 | 919 | result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps) 920 | 921 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 922 | with tf.gfile.GFile(output_eval_file, "w") as writer: 923 | tf.logging.info("***** Eval results *****") 924 | for key in sorted(result.keys()): 925 | tf.logging.info(" %s = %s", key, str(result[key])) 926 | writer.write("%s = %s\n" % (key, str(result[key]))) 927 | 928 | if FLAGS.do_predict: 929 | predict_examples = processor.get_test_examples(FLAGS.data_dir) 930 | num_actual_predict_examples = len(predict_examples) 931 | if FLAGS.use_tpu: 932 | # TPU requires a fixed batch size for all batches, therefore the number 933 | # of examples must be a multiple of the batch size, or else examples 934 | # will get dropped. So we pad with fake examples which are ignored 935 | # later on. 936 | while len(predict_examples) % FLAGS.predict_batch_size != 0: 937 | predict_examples.append(PaddingInputExample()) 938 | 939 | predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record") 940 | file_based_convert_examples_to_features(predict_examples, label_list, 941 | FLAGS.max_seq_length, tokenizer, 942 | predict_file) 943 | 944 | tf.logging.info("***** Running prediction*****") 945 | tf.logging.info(" Num examples = %d (%d actual, %d padding)", 946 | len(predict_examples), num_actual_predict_examples, 947 | len(predict_examples) - num_actual_predict_examples) 948 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) 949 | 950 | predict_drop_remainder = True if FLAGS.use_tpu else False 951 | predict_input_fn = file_based_input_fn_builder( 952 | input_file=predict_file, 953 | seq_length=FLAGS.max_seq_length, 954 | is_training=False, 955 | drop_remainder=predict_drop_remainder) 956 | 957 | result = estimator.predict(input_fn=predict_input_fn) 958 | 959 | output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv") 960 | with tf.gfile.GFile(output_predict_file, "w") as writer: 961 | num_written_lines = 0 962 | tf.logging.info("***** Predict results *****") 963 | for (i, prediction) in enumerate(result): 964 | probabilities = prediction["probabilities"] 965 | if i >= num_actual_predict_examples: 966 | break 967 | output_line = "\t".join( 968 | str(class_probability) 969 | for class_probability in probabilities) + "\n" 970 | writer.write(output_line) 971 | num_written_lines += 1 972 | assert num_written_lines == num_actual_predict_examples 973 | 974 | 975 | if __name__ == "__main__": 976 | flags.mark_flag_as_required("data_dir") 977 | flags.mark_flag_as_required("task_name") 978 | flags.mark_flag_as_required("vocab_file") 979 | flags.mark_flag_as_required("bert_config_file") 980 | flags.mark_flag_as_required("output_dir") 981 | tf.compat.v1.app.run() 982 | -------------------------------------------------------------------------------- /run_squad.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Run BERT on SQuAD 1.1 and SQuAD 2.0.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import json 23 | import math 24 | import os 25 | import random 26 | import modeling 27 | import optimization 28 | import tokenization 29 | import six 30 | import tensorflow as tf 31 | 32 | flags = tf.compat.v1.app.flags 33 | 34 | FLAGS = flags.FLAGS 35 | 36 | ## Required parameters 37 | flags.DEFINE_string( 38 | "bert_config_file", None, 39 | "The config json file corresponding to the pre-trained BERT model. " 40 | "This specifies the model architecture.") 41 | 42 | flags.DEFINE_string("vocab_file", None, 43 | "The vocabulary file that the BERT model was trained on.") 44 | 45 | flags.DEFINE_string( 46 | "output_dir", None, 47 | "The output directory where the model checkpoints will be written.") 48 | 49 | ## Other parameters 50 | flags.DEFINE_string("train_file", None, 51 | "SQuAD json for training. E.g., train-v1.1.json") 52 | 53 | flags.DEFINE_string( 54 | "predict_file", None, 55 | "SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json") 56 | 57 | flags.DEFINE_string( 58 | "init_checkpoint", None, 59 | "Initial checkpoint (usually from a pre-trained BERT model).") 60 | 61 | flags.DEFINE_bool( 62 | "do_lower_case", True, 63 | "Whether to lower case the input text. Should be True for uncased " 64 | "models and False for cased models.") 65 | 66 | flags.DEFINE_integer( 67 | "max_seq_length", 384, 68 | "The maximum total input sequence length after WordPiece tokenization. " 69 | "Sequences longer than this will be truncated, and sequences shorter " 70 | "than this will be padded.") 71 | 72 | flags.DEFINE_integer( 73 | "doc_stride", 128, 74 | "When splitting up a long document into chunks, how much stride to " 75 | "take between chunks.") 76 | 77 | flags.DEFINE_integer( 78 | "max_query_length", 64, 79 | "The maximum number of tokens for the question. Questions longer than " 80 | "this will be truncated to this length.") 81 | 82 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 83 | 84 | flags.DEFINE_bool("do_predict", False, "Whether to run eval on the dev set.") 85 | 86 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 87 | 88 | flags.DEFINE_integer("predict_batch_size", 8, 89 | "Total batch size for predictions.") 90 | 91 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 92 | 93 | flags.DEFINE_float("num_train_epochs", 3.0, 94 | "Total number of training epochs to perform.") 95 | 96 | flags.DEFINE_float( 97 | "warmup_proportion", 0.1, 98 | "Proportion of training to perform linear learning rate warmup for. 
" 99 | "E.g., 0.1 = 10% of training.") 100 | 101 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 102 | "How often to save the model checkpoint.") 103 | 104 | flags.DEFINE_integer("iterations_per_loop", 1000, 105 | "How many steps to make in each estimator call.") 106 | 107 | flags.DEFINE_integer( 108 | "n_best_size", 20, 109 | "The total number of n-best predictions to generate in the " 110 | "nbest_predictions.json output file.") 111 | 112 | flags.DEFINE_integer( 113 | "max_answer_length", 30, 114 | "The maximum length of an answer that can be generated. This is needed " 115 | "because the start and end predictions are not conditioned on one another.") 116 | 117 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 118 | 119 | flags.DEFINE_string( 120 | "tpu_name", None, 121 | "The Cloud TPU to use for training. This should be either the name " 122 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 123 | "url.") 124 | 125 | flags.DEFINE_string( 126 | "tpu_zone", None, 127 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 128 | "specified, we will attempt to automatically detect the GCE project from " 129 | "metadata.") 130 | 131 | flags.DEFINE_string( 132 | "gcp_project", None, 133 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 134 | "specified, we will attempt to automatically detect the GCE project from " 135 | "metadata.") 136 | 137 | flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 138 | 139 | flags.DEFINE_integer( 140 | "num_tpu_cores", 8, 141 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 142 | 143 | flags.DEFINE_bool( 144 | "verbose_logging", False, 145 | "If true, all of the warnings related to data processing will be printed. " 146 | "A number of warnings are expected for a normal SQuAD evaluation.") 147 | 148 | flags.DEFINE_bool( 149 | "version_2_with_negative", False, 150 | "If true, the SQuAD examples contain some that do not have an answer.") 151 | 152 | flags.DEFINE_float( 153 | "null_score_diff_threshold", 0.0, 154 | "If null_score - best_non_null is greater than the threshold predict null.") 155 | 156 | 157 | class SquadExample(object): 158 | """A single training/test example for simple sequence classification. 159 | 160 | For examples without an answer, the start and end position are -1. 
161 | """ 162 | 163 | def __init__(self, 164 | qas_id, 165 | question_text, 166 | doc_tokens, 167 | orig_answer_text=None, 168 | start_position=None, 169 | end_position=None, 170 | is_impossible=False): 171 | self.qas_id = qas_id 172 | self.question_text = question_text 173 | self.doc_tokens = doc_tokens 174 | self.orig_answer_text = orig_answer_text 175 | self.start_position = start_position 176 | self.end_position = end_position 177 | self.is_impossible = is_impossible 178 | 179 | def __str__(self): 180 | return self.__repr__() 181 | 182 | def __repr__(self): 183 | s = "" 184 | s += "qas_id: %s" % (tokenization.printable_text(self.qas_id)) 185 | s += ", question_text: %s" % ( 186 | tokenization.printable_text(self.question_text)) 187 | s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens)) 188 | if self.start_position: 189 | s += ", start_position: %d" % (self.start_position) 190 | if self.start_position: 191 | s += ", end_position: %d" % (self.end_position) 192 | if self.start_position: 193 | s += ", is_impossible: %r" % (self.is_impossible) 194 | return s 195 | 196 | 197 | class InputFeatures(object): 198 | """A single set of features of data.""" 199 | 200 | def __init__(self, 201 | unique_id, 202 | example_index, 203 | doc_span_index, 204 | tokens, 205 | token_to_orig_map, 206 | token_is_max_context, 207 | input_ids, 208 | input_mask, 209 | segment_ids, 210 | start_position=None, 211 | end_position=None, 212 | is_impossible=None): 213 | self.unique_id = unique_id 214 | self.example_index = example_index 215 | self.doc_span_index = doc_span_index 216 | self.tokens = tokens 217 | self.token_to_orig_map = token_to_orig_map 218 | self.token_is_max_context = token_is_max_context 219 | self.input_ids = input_ids 220 | self.input_mask = input_mask 221 | self.segment_ids = segment_ids 222 | self.start_position = start_position 223 | self.end_position = end_position 224 | self.is_impossible = is_impossible 225 | 226 | 227 | def read_squad_examples(input_file, is_training): 228 | """Read a SQuAD json file into a list of SquadExample.""" 229 | with tf.gfile.Open(input_file, "r") as reader: 230 | input_data = json.load(reader)["data"] 231 | 232 | def is_whitespace(c): 233 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: 234 | return True 235 | return False 236 | 237 | examples = [] 238 | for entry in input_data: 239 | for paragraph in entry["paragraphs"]: 240 | paragraph_text = paragraph["context"] 241 | doc_tokens = [] 242 | char_to_word_offset = [] 243 | prev_is_whitespace = True 244 | for c in paragraph_text: 245 | if is_whitespace(c): 246 | prev_is_whitespace = True 247 | else: 248 | if prev_is_whitespace: 249 | doc_tokens.append(c) 250 | else: 251 | doc_tokens[-1] += c 252 | prev_is_whitespace = False 253 | char_to_word_offset.append(len(doc_tokens) - 1) 254 | 255 | for qa in paragraph["qas"]: 256 | qas_id = qa["id"] 257 | question_text = qa["question"] 258 | start_position = None 259 | end_position = None 260 | orig_answer_text = None 261 | is_impossible = False 262 | if is_training: 263 | 264 | if FLAGS.version_2_with_negative: 265 | is_impossible = qa["is_impossible"] 266 | if (len(qa["answers"]) != 1) and (not is_impossible): 267 | raise ValueError( 268 | "For training, each question should have exactly 1 answer.") 269 | if not is_impossible: 270 | answer = qa["answers"][0] 271 | orig_answer_text = answer["text"] 272 | answer_offset = answer["answer_start"] 273 | answer_length = len(orig_answer_text) 274 | start_position = char_to_word_offset[answer_offset] 275 
| end_position = char_to_word_offset[answer_offset + answer_length - 276 | 1] 277 | # Only add answers where the text can be exactly recovered from the 278 | # document. If this CAN'T happen it's likely due to weird Unicode 279 | # stuff so we will just skip the example. 280 | # 281 | # Note that this means for training mode, every example is NOT 282 | # guaranteed to be preserved. 283 | actual_text = " ".join( 284 | doc_tokens[start_position:(end_position + 1)]) 285 | cleaned_answer_text = " ".join( 286 | tokenization.whitespace_tokenize(orig_answer_text)) 287 | if actual_text.find(cleaned_answer_text) == -1: 288 | tf.logging.warning("Could not find answer: '%s' vs. '%s'", 289 | actual_text, cleaned_answer_text) 290 | continue 291 | else: 292 | start_position = -1 293 | end_position = -1 294 | orig_answer_text = "" 295 | 296 | example = SquadExample( 297 | qas_id=qas_id, 298 | question_text=question_text, 299 | doc_tokens=doc_tokens, 300 | orig_answer_text=orig_answer_text, 301 | start_position=start_position, 302 | end_position=end_position, 303 | is_impossible=is_impossible) 304 | examples.append(example) 305 | 306 | return examples 307 | 308 | 309 | def convert_examples_to_features(examples, tokenizer, max_seq_length, 310 | doc_stride, max_query_length, is_training, 311 | output_fn): 312 | """Loads a data file into a list of `InputBatch`s.""" 313 | 314 | unique_id = 1000000000 315 | 316 | for (example_index, example) in enumerate(examples): 317 | query_tokens = tokenizer.tokenize(example.question_text) 318 | 319 | if len(query_tokens) > max_query_length: 320 | query_tokens = query_tokens[0:max_query_length] 321 | 322 | tok_to_orig_index = [] 323 | orig_to_tok_index = [] 324 | all_doc_tokens = [] 325 | for (i, token) in enumerate(example.doc_tokens): 326 | orig_to_tok_index.append(len(all_doc_tokens)) 327 | sub_tokens = tokenizer.tokenize(token) 328 | for sub_token in sub_tokens: 329 | tok_to_orig_index.append(i) 330 | all_doc_tokens.append(sub_token) 331 | 332 | tok_start_position = None 333 | tok_end_position = None 334 | if is_training and example.is_impossible: 335 | tok_start_position = -1 336 | tok_end_position = -1 337 | if is_training and not example.is_impossible: 338 | tok_start_position = orig_to_tok_index[example.start_position] 339 | if example.end_position < len(example.doc_tokens) - 1: 340 | tok_end_position = orig_to_tok_index[example.end_position + 1] - 1 341 | else: 342 | tok_end_position = len(all_doc_tokens) - 1 343 | (tok_start_position, tok_end_position) = _improve_answer_span( 344 | all_doc_tokens, tok_start_position, tok_end_position, tokenizer, 345 | example.orig_answer_text) 346 | 347 | # The -3 accounts for [CLS], [SEP] and [SEP] 348 | max_tokens_for_doc = max_seq_length - len(query_tokens) - 3 349 | 350 | # We can have documents that are longer than the maximum sequence length. 351 | # To deal with this we do a sliding window approach, where we take chunks 352 | # of the up to our max length with a stride of `doc_stride`. 
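    # Worked example (hypothetical numbers): with 500 document tokens,
    # max_seq_length=384, a 64-token question (max_tokens_for_doc = 384 - 64 - 3
    # = 317) and doc_stride=128, the loop below produces the (start, length)
    # spans (0, 317), (128, 317), (256, 244); every document token lands in at
    # least one span, and most land in two overlapping spans.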
353 | _DocSpan = collections.namedtuple( # pylint: disable=invalid-name 354 | "DocSpan", ["start", "length"]) 355 | doc_spans = [] 356 | start_offset = 0 357 | while start_offset < len(all_doc_tokens): 358 | length = len(all_doc_tokens) - start_offset 359 | if length > max_tokens_for_doc: 360 | length = max_tokens_for_doc 361 | doc_spans.append(_DocSpan(start=start_offset, length=length)) 362 | if start_offset + length == len(all_doc_tokens): 363 | break 364 | start_offset += min(length, doc_stride) 365 | 366 | for (doc_span_index, doc_span) in enumerate(doc_spans): 367 | tokens = [] 368 | token_to_orig_map = {} 369 | token_is_max_context = {} 370 | segment_ids = [] 371 | tokens.append("[CLS]") 372 | segment_ids.append(0) 373 | for token in query_tokens: 374 | tokens.append(token) 375 | segment_ids.append(0) 376 | tokens.append("[SEP]") 377 | segment_ids.append(0) 378 | 379 | for i in range(doc_span.length): 380 | split_token_index = doc_span.start + i 381 | token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index] 382 | 383 | is_max_context = _check_is_max_context(doc_spans, doc_span_index, 384 | split_token_index) 385 | token_is_max_context[len(tokens)] = is_max_context 386 | tokens.append(all_doc_tokens[split_token_index]) 387 | segment_ids.append(1) 388 | tokens.append("[SEP]") 389 | segment_ids.append(1) 390 | 391 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 392 | 393 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 394 | # tokens are attended to. 395 | input_mask = [1] * len(input_ids) 396 | 397 | # Zero-pad up to the sequence length. 398 | while len(input_ids) < max_seq_length: 399 | input_ids.append(0) 400 | input_mask.append(0) 401 | segment_ids.append(0) 402 | 403 | assert len(input_ids) == max_seq_length 404 | assert len(input_mask) == max_seq_length 405 | assert len(segment_ids) == max_seq_length 406 | 407 | start_position = None 408 | end_position = None 409 | if is_training and not example.is_impossible: 410 | # For training, if our document chunk does not contain an annotation 411 | # we throw it out, since there is nothing to predict. 
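        # Worked example (hypothetical numbers): with a 10-token question, the
        # span tokens start at doc_offset = 10 + 2 = 12 ([CLS] + query + [SEP]).
        # If tok_start_position=130, tok_end_position=133 and doc_span.start=128,
        # the labels below become start_position = 130 - 128 + 12 = 14 and
        # end_position = 133 - 128 + 12 = 17. A chunk whose span does not
        # contain the answer is labeled start_position = end_position = 0,
        # i.e. it points at the [CLS] token.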
412 | doc_start = doc_span.start 413 | doc_end = doc_span.start + doc_span.length - 1 414 | out_of_span = False 415 | if not (tok_start_position >= doc_start and 416 | tok_end_position <= doc_end): 417 | out_of_span = True 418 | if out_of_span: 419 | start_position = 0 420 | end_position = 0 421 | else: 422 | doc_offset = len(query_tokens) + 2 423 | start_position = tok_start_position - doc_start + doc_offset 424 | end_position = tok_end_position - doc_start + doc_offset 425 | 426 | if is_training and example.is_impossible: 427 | start_position = 0 428 | end_position = 0 429 | 430 | if example_index < 20: 431 | tf.logging.info("*** Example ***") 432 | tf.logging.info("unique_id: %s" % (unique_id)) 433 | tf.logging.info("example_index: %s" % (example_index)) 434 | tf.logging.info("doc_span_index: %s" % (doc_span_index)) 435 | tf.logging.info("tokens: %s" % " ".join( 436 | [tokenization.printable_text(x) for x in tokens])) 437 | tf.logging.info("token_to_orig_map: %s" % " ".join( 438 | ["%d:%d" % (x, y) for (x, y) in six.iteritems(token_to_orig_map)])) 439 | tf.logging.info("token_is_max_context: %s" % " ".join([ 440 | "%d:%s" % (x, y) for (x, y) in six.iteritems(token_is_max_context) 441 | ])) 442 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 443 | tf.logging.info( 444 | "input_mask: %s" % " ".join([str(x) for x in input_mask])) 445 | tf.logging.info( 446 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 447 | if is_training and example.is_impossible: 448 | tf.logging.info("impossible example") 449 | if is_training and not example.is_impossible: 450 | answer_text = " ".join(tokens[start_position:(end_position + 1)]) 451 | tf.logging.info("start_position: %d" % (start_position)) 452 | tf.logging.info("end_position: %d" % (end_position)) 453 | tf.logging.info( 454 | "answer: %s" % (tokenization.printable_text(answer_text))) 455 | 456 | feature = InputFeatures( 457 | unique_id=unique_id, 458 | example_index=example_index, 459 | doc_span_index=doc_span_index, 460 | tokens=tokens, 461 | token_to_orig_map=token_to_orig_map, 462 | token_is_max_context=token_is_max_context, 463 | input_ids=input_ids, 464 | input_mask=input_mask, 465 | segment_ids=segment_ids, 466 | start_position=start_position, 467 | end_position=end_position, 468 | is_impossible=example.is_impossible) 469 | 470 | # Run callback 471 | output_fn(feature) 472 | 473 | unique_id += 1 474 | 475 | 476 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, 477 | orig_answer_text): 478 | """Returns tokenized answer spans that better match the annotated answer.""" 479 | 480 | # The SQuAD annotations are character based. We first project them to 481 | # whitespace-tokenized words. But then after WordPiece tokenization, we can 482 | # often find a "better match". For example: 483 | # 484 | # Question: What year was John Smith born? 485 | # Context: The leader was John Smith (1895-1943). 486 | # Answer: 1895 487 | # 488 | # The original whitespace-tokenized answer will be "(1895-1943).". However 489 | # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match 490 | # the exact answer, 1895. 491 | # 492 | # However, this is not always possible. Consider the following: 493 | # 494 | # Question: What country is the top exporter of electornics? 495 | # Context: The Japanese electronics industry is the lagest in the world. 496 | # Answer: Japan 497 | # 498 | # In this case, the annotator chose "Japan" as a character sub-span of 499 | # the word "Japanese". 
Since our WordPiece tokenizer does not split 500 | # "Japanese", we just use "Japanese" as the annotation. This is fairly rare 501 | # in SQuAD, but does happen. 502 | tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text)) 503 | 504 | for new_start in range(input_start, input_end + 1): 505 | for new_end in range(input_end, new_start - 1, -1): 506 | text_span = " ".join(doc_tokens[new_start:(new_end + 1)]) 507 | if text_span == tok_answer_text: 508 | return (new_start, new_end) 509 | 510 | return (input_start, input_end) 511 | 512 | 513 | def _check_is_max_context(doc_spans, cur_span_index, position): 514 | """Check if this is the 'max context' doc span for the token.""" 515 | 516 | # Because of the sliding window approach taken to scoring documents, a single 517 | # token can appear in multiple documents. E.g. 518 | # Doc: the man went to the store and bought a gallon of milk 519 | # Span A: the man went to the 520 | # Span B: to the store and bought 521 | # Span C: and bought a gallon of 522 | # ... 523 | # 524 | # Now the word 'bought' will have two scores from spans B and C. We only 525 | # want to consider the score with "maximum context", which we define as 526 | # the *minimum* of its left and right context (the *sum* of left and 527 | # right context will always be the same, of course). 528 | # 529 | # In the example the maximum context for 'bought' would be span C since 530 | # it has 1 left context and 3 right context, while span B has 4 left context 531 | # and 0 right context. 532 | best_score = None 533 | best_span_index = None 534 | for (span_index, doc_span) in enumerate(doc_spans): 535 | end = doc_span.start + doc_span.length - 1 536 | if position < doc_span.start: 537 | continue 538 | if position > end: 539 | continue 540 | num_left_context = position - doc_span.start 541 | num_right_context = end - position 542 | score = min(num_left_context, num_right_context) + 0.01 * doc_span.length 543 | if best_score is None or score > best_score: 544 | best_score = score 545 | best_span_index = span_index 546 | 547 | return cur_span_index == best_span_index 548 | 549 | 550 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, 551 | use_one_hot_embeddings): 552 | """Creates a classification model.""" 553 | model = modeling.BertModel( 554 | config=bert_config, 555 | is_training=is_training, 556 | input_ids=input_ids, 557 | input_mask=input_mask, 558 | token_type_ids=segment_ids, 559 | use_one_hot_embeddings=use_one_hot_embeddings) 560 | 561 | final_hidden = model.get_sequence_output() 562 | 563 | final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3) 564 | batch_size = final_hidden_shape[0] 565 | seq_length = final_hidden_shape[1] 566 | hidden_size = final_hidden_shape[2] 567 | 568 | output_weights = tf.get_variable( 569 | "cls/squad/output_weights", [2, hidden_size], 570 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 571 | 572 | output_bias = tf.get_variable( 573 | "cls/squad/output_bias", [2], initializer=tf.zeros_initializer()) 574 | 575 | final_hidden_matrix = tf.reshape(final_hidden, 576 | [batch_size * seq_length, hidden_size]) 577 | logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True) 578 | logits = tf.nn.bias_add(logits, output_bias) 579 | 580 | logits = tf.reshape(logits, [batch_size, seq_length, 2]) 581 | logits = tf.transpose(logits, [2, 0, 1]) 582 | 583 | unstacked_logits = tf.unstack(logits, axis=0) 584 | 585 | (start_logits, end_logits) = (unstacked_logits[0], unstacked_logits[1]) 
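  # Shape walk-through (illustrative numbers): with batch_size=8, seq_length=384
  # and hidden_size=1024, final_hidden_matrix is [3072, 1024], the matmul with
  # output_weights^T gives [3072, 2], the reshape/transpose above yields
  # [2, 8, 384], and the unstack leaves start_logits and end_logits each of
  # shape [8, 384]: one start score and one end score per token position.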
586 | 587 | return (start_logits, end_logits) 588 | 589 | 590 | def model_fn_builder(bert_config, init_checkpoint, learning_rate, 591 | num_train_steps, num_warmup_steps, use_tpu, 592 | use_one_hot_embeddings): 593 | """Returns `model_fn` closure for TPUEstimator.""" 594 | 595 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 596 | """The `model_fn` for TPUEstimator.""" 597 | 598 | tf.logging.info("*** Features ***") 599 | for name in sorted(features.keys()): 600 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 601 | 602 | unique_ids = features["unique_ids"] 603 | input_ids = features["input_ids"] 604 | input_mask = features["input_mask"] 605 | segment_ids = features["segment_ids"] 606 | 607 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 608 | 609 | (start_logits, end_logits) = create_model( 610 | bert_config=bert_config, 611 | is_training=is_training, 612 | input_ids=input_ids, 613 | input_mask=input_mask, 614 | segment_ids=segment_ids, 615 | use_one_hot_embeddings=use_one_hot_embeddings) 616 | 617 | tvars = tf.trainable_variables() 618 | 619 | initialized_variable_names = {} 620 | scaffold_fn = None 621 | if init_checkpoint: 622 | (assignment_map, initialized_variable_names 623 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 624 | if use_tpu: 625 | 626 | def tpu_scaffold(): 627 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 628 | return tf.train.Scaffold() 629 | 630 | scaffold_fn = tpu_scaffold 631 | else: 632 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 633 | 634 | tf.logging.info("**** Trainable Variables ****") 635 | for var in tvars: 636 | init_string = "" 637 | if var.name in initialized_variable_names: 638 | init_string = ", *INIT_FROM_CKPT*" 639 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 640 | init_string) 641 | 642 | output_spec = None 643 | if mode == tf.estimator.ModeKeys.TRAIN: 644 | seq_length = modeling.get_shape_list(input_ids)[1] 645 | 646 | def compute_loss(logits, positions): 647 | one_hot_positions = tf.one_hot( 648 | positions, depth=seq_length, dtype=tf.float32) 649 | log_probs = tf.nn.log_softmax(logits, axis=-1) 650 | loss = -tf.reduce_mean( 651 | tf.reduce_sum(one_hot_positions * log_probs, axis=-1)) 652 | return loss 653 | 654 | start_positions = features["start_positions"] 655 | end_positions = features["end_positions"] 656 | 657 | start_loss = compute_loss(start_logits, start_positions) 658 | end_loss = compute_loss(end_logits, end_positions) 659 | 660 | total_loss = (start_loss + end_loss) / 2.0 661 | 662 | train_op = optimization.create_optimizer( 663 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 664 | 665 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 666 | mode=mode, 667 | loss=total_loss, 668 | train_op=train_op, 669 | scaffold_fn=scaffold_fn) 670 | elif mode == tf.estimator.ModeKeys.PREDICT: 671 | predictions = { 672 | "unique_ids": unique_ids, 673 | "start_logits": start_logits, 674 | "end_logits": end_logits, 675 | } 676 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 677 | mode=mode, predictions=predictions, scaffold_fn=scaffold_fn) 678 | else: 679 | raise ValueError( 680 | "Only TRAIN and PREDICT modes are supported: %s" % (mode)) 681 | 682 | return output_spec 683 | 684 | return model_fn 685 | 686 | 687 | def input_fn_builder(input_file, seq_length, is_training, drop_remainder): 688 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 689 | 690 | 
name_to_features = { 691 | "unique_ids": tf.FixedLenFeature([], tf.int64), 692 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64), 693 | "input_mask": tf.FixedLenFeature([seq_length], tf.int64), 694 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64), 695 | } 696 | 697 | if is_training: 698 | name_to_features["start_positions"] = tf.FixedLenFeature([], tf.int64) 699 | name_to_features["end_positions"] = tf.FixedLenFeature([], tf.int64) 700 | 701 | def _decode_record(record, name_to_features): 702 | """Decodes a record to a TensorFlow example.""" 703 | example = tf.parse_single_example(record, name_to_features) 704 | 705 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 706 | # So cast all int64 to int32. 707 | for name in list(example.keys()): 708 | t = example[name] 709 | if t.dtype == tf.int64: 710 | t = tf.to_int32(t) 711 | example[name] = t 712 | 713 | return example 714 | 715 | def input_fn(params): 716 | """The actual input function.""" 717 | batch_size = params["batch_size"] 718 | 719 | # For training, we want a lot of parallel reading and shuffling. 720 | # For eval, we want no shuffling and parallel reading doesn't matter. 721 | d = tf.data.TFRecordDataset(input_file) 722 | if is_training: 723 | d = d.repeat() 724 | d = d.shuffle(buffer_size=100) 725 | 726 | d = d.apply( 727 | tf.contrib.data.map_and_batch( 728 | lambda record: _decode_record(record, name_to_features), 729 | batch_size=batch_size, 730 | drop_remainder=drop_remainder)) 731 | 732 | return d 733 | 734 | return input_fn 735 | 736 | 737 | RawResult = collections.namedtuple("RawResult", 738 | ["unique_id", "start_logits", "end_logits"]) 739 | 740 | 741 | def write_predictions(all_examples, all_features, all_results, n_best_size, 742 | max_answer_length, do_lower_case, output_prediction_file, 743 | output_nbest_file, output_null_log_odds_file): 744 | """Write final predictions to the json file and log-odds of null if needed.""" 745 | tf.logging.info("Writing predictions to: %s" % (output_prediction_file)) 746 | tf.logging.info("Writing nbest to: %s" % (output_nbest_file)) 747 | 748 | example_index_to_features = collections.defaultdict(list) 749 | for feature in all_features: 750 | example_index_to_features[feature.example_index].append(feature) 751 | 752 | unique_id_to_result = {} 753 | for result in all_results: 754 | unique_id_to_result[result.unique_id] = result 755 | 756 | _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name 757 | "PrelimPrediction", 758 | ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]) 759 | 760 | all_predictions = collections.OrderedDict() 761 | all_nbest_json = collections.OrderedDict() 762 | scores_diff_json = collections.OrderedDict() 763 | 764 | for (example_index, example) in enumerate(all_examples): 765 | features = example_index_to_features[example_index] 766 | 767 | prelim_predictions = [] 768 | # keep track of the minimum score of null start+end of position 0 769 | score_null = 1000000 # large and positive 770 | min_null_feature_index = 0 # the paragraph slice with min mull score 771 | null_start_logit = 0 # the start logit at the slice with min null score 772 | null_end_logit = 0 # the end logit at the slice with min null score 773 | for (feature_index, feature) in enumerate(features): 774 | result = unique_id_to_result[feature.unique_id] 775 | start_indexes = _get_best_indexes(result.start_logits, n_best_size) 776 | end_indexes = _get_best_indexes(result.end_logits, n_best_size) 777 | # if we could 
have irrelevant answers, get the min score of irrelevant 778 | if FLAGS.version_2_with_negative: 779 | feature_null_score = result.start_logits[0] + result.end_logits[0] 780 | if feature_null_score < score_null: 781 | score_null = feature_null_score 782 | min_null_feature_index = feature_index 783 | null_start_logit = result.start_logits[0] 784 | null_end_logit = result.end_logits[0] 785 | for start_index in start_indexes: 786 | for end_index in end_indexes: 787 | # We could hypothetically create invalid predictions, e.g., predict 788 | # that the start of the span is in the question. We throw out all 789 | # invalid predictions. 790 | if start_index >= len(feature.tokens): 791 | continue 792 | if end_index >= len(feature.tokens): 793 | continue 794 | if start_index not in feature.token_to_orig_map: 795 | continue 796 | if end_index not in feature.token_to_orig_map: 797 | continue 798 | if not feature.token_is_max_context.get(start_index, False): 799 | continue 800 | if end_index < start_index: 801 | continue 802 | length = end_index - start_index + 1 803 | if length > max_answer_length: 804 | continue 805 | prelim_predictions.append( 806 | _PrelimPrediction( 807 | feature_index=feature_index, 808 | start_index=start_index, 809 | end_index=end_index, 810 | start_logit=result.start_logits[start_index], 811 | end_logit=result.end_logits[end_index])) 812 | 813 | if FLAGS.version_2_with_negative: 814 | prelim_predictions.append( 815 | _PrelimPrediction( 816 | feature_index=min_null_feature_index, 817 | start_index=0, 818 | end_index=0, 819 | start_logit=null_start_logit, 820 | end_logit=null_end_logit)) 821 | prelim_predictions = sorted( 822 | prelim_predictions, 823 | key=lambda x: (x.start_logit + x.end_logit), 824 | reverse=True) 825 | 826 | _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name 827 | "NbestPrediction", ["text", "start_logit", "end_logit"]) 828 | 829 | seen_predictions = {} 830 | nbest = [] 831 | for pred in prelim_predictions: 832 | if len(nbest) >= n_best_size: 833 | break 834 | feature = features[pred.feature_index] 835 | if pred.start_index > 0: # this is a non-null prediction 836 | tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)] 837 | orig_doc_start = feature.token_to_orig_map[pred.start_index] 838 | orig_doc_end = feature.token_to_orig_map[pred.end_index] 839 | orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)] 840 | tok_text = " ".join(tok_tokens) 841 | 842 | # De-tokenize WordPieces that have been split off. 843 | tok_text = tok_text.replace(" ##", "") 844 | tok_text = tok_text.replace("##", "") 845 | 846 | # Clean whitespace 847 | tok_text = tok_text.strip() 848 | tok_text = " ".join(tok_text.split()) 849 | orig_text = " ".join(orig_tokens) 850 | 851 | final_text = get_final_text(tok_text, orig_text, do_lower_case) 852 | if final_text in seen_predictions: 853 | continue 854 | 855 | seen_predictions[final_text] = True 856 | else: 857 | final_text = "" 858 | seen_predictions[final_text] = True 859 | 860 | nbest.append( 861 | _NbestPrediction( 862 | text=final_text, 863 | start_logit=pred.start_logit, 864 | end_logit=pred.end_logit)) 865 | 866 | # if we didn't inlude the empty option in the n-best, inlcude it 867 | if FLAGS.version_2_with_negative: 868 | if "" not in seen_predictions: 869 | nbest.append( 870 | _NbestPrediction( 871 | text="", start_logit=null_start_logit, 872 | end_logit=null_end_logit)) 873 | # In very rare edge cases we could have no valid predictions. 
So we 874 | # just create a nonce prediction in this case to avoid failure. 875 | if not nbest: 876 | nbest.append( 877 | _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0)) 878 | 879 | assert len(nbest) >= 1 880 | 881 | total_scores = [] 882 | best_non_null_entry = None 883 | for entry in nbest: 884 | total_scores.append(entry.start_logit + entry.end_logit) 885 | if not best_non_null_entry: 886 | if entry.text: 887 | best_non_null_entry = entry 888 | 889 | probs = _compute_softmax(total_scores) 890 | 891 | nbest_json = [] 892 | for (i, entry) in enumerate(nbest): 893 | output = collections.OrderedDict() 894 | output["text"] = entry.text 895 | output["probability"] = probs[i] 896 | output["start_logit"] = entry.start_logit 897 | output["end_logit"] = entry.end_logit 898 | nbest_json.append(output) 899 | 900 | assert len(nbest_json) >= 1 901 | 902 | if not FLAGS.version_2_with_negative: 903 | all_predictions[example.qas_id] = nbest_json[0]["text"] 904 | else: 905 | # predict "" iff the null score - the score of best non-null > threshold 906 | score_diff = score_null - best_non_null_entry.start_logit - ( 907 | best_non_null_entry.end_logit) 908 | scores_diff_json[example.qas_id] = score_diff 909 | if score_diff > FLAGS.null_score_diff_threshold: 910 | all_predictions[example.qas_id] = "" 911 | else: 912 | all_predictions[example.qas_id] = best_non_null_entry.text 913 | 914 | all_nbest_json[example.qas_id] = nbest_json 915 | 916 | with tf.gfile.GFile(output_prediction_file, "w") as writer: 917 | writer.write(json.dumps(all_predictions, indent=4) + "\n") 918 | 919 | with tf.gfile.GFile(output_nbest_file, "w") as writer: 920 | writer.write(json.dumps(all_nbest_json, indent=4) + "\n") 921 | 922 | if FLAGS.version_2_with_negative: 923 | with tf.gfile.GFile(output_null_log_odds_file, "w") as writer: 924 | writer.write(json.dumps(scores_diff_json, indent=4) + "\n") 925 | 926 | 927 | def get_final_text(pred_text, orig_text, do_lower_case): 928 | """Project the tokenized prediction back to the original text.""" 929 | 930 | # When we created the data, we kept track of the alignment between original 931 | # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So 932 | # now `orig_text` contains the span of our original text corresponding to the 933 | # span that we predicted. 934 | # 935 | # However, `orig_text` may contain extra characters that we don't want in 936 | # our prediction. 937 | # 938 | # For example, let's say: 939 | # pred_text = steve smith 940 | # orig_text = Steve Smith's 941 | # 942 | # We don't want to return `orig_text` because it contains the extra "'s". 943 | # 944 | # We don't want to return `pred_text` because it's already been normalized 945 | # (the SQuAD eval script also does punctuation stripping/lower casing but 946 | # our tokenizer does additional normalization like stripping accent 947 | # characters). 948 | # 949 | # What we really want to return is "Steve Smith". 950 | # 951 | # Therefore, we have to apply a semi-complicated alignment heruistic between 952 | # `pred_text` and `orig_text` to get a character-to-charcter alignment. This 953 | # can fail in certain cases in which case we just return `orig_text`. 
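  # Worked example (assuming do_lower_case=True): for pred_text="steve smith"
  # and orig_text="Steve Smith's", BasicTokenizer gives tok_text="steve smith ' s".
  # pred_text is found at tok_text position 0, so end_position=10. After the
  # _strip_spaces helper below, "SteveSmith's" and "stevesmith's" both have 12
  # characters, the character maps send tok positions 0 and 10 to orig positions
  # 0 and 10, and the function returns orig_text[0:11] == "Steve Smith".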
954 | 955 | def _strip_spaces(text): 956 | ns_chars = [] 957 | ns_to_s_map = collections.OrderedDict() 958 | for (i, c) in enumerate(text): 959 | if c == " ": 960 | continue 961 | ns_to_s_map[len(ns_chars)] = i 962 | ns_chars.append(c) 963 | ns_text = "".join(ns_chars) 964 | return (ns_text, ns_to_s_map) 965 | 966 | # We first tokenize `orig_text`, strip whitespace from the result 967 | # and `pred_text`, and check if they are the same length. If they are 968 | # NOT the same length, the heuristic has failed. If they are the same 969 | # length, we assume the characters are one-to-one aligned. 970 | tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case) 971 | 972 | tok_text = " ".join(tokenizer.tokenize(orig_text)) 973 | 974 | start_position = tok_text.find(pred_text) 975 | if start_position == -1: 976 | if FLAGS.verbose_logging: 977 | tf.logging.info( 978 | "Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) 979 | return orig_text 980 | end_position = start_position + len(pred_text) - 1 981 | 982 | (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) 983 | (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) 984 | 985 | if len(orig_ns_text) != len(tok_ns_text): 986 | if FLAGS.verbose_logging: 987 | tf.logging.info("Length not equal after stripping spaces: '%s' vs '%s'", 988 | orig_ns_text, tok_ns_text) 989 | return orig_text 990 | 991 | # We then project the characters in `pred_text` back to `orig_text` using 992 | # the character-to-character alignment. 993 | tok_s_to_ns_map = {} 994 | for (i, tok_index) in six.iteritems(tok_ns_to_s_map): 995 | tok_s_to_ns_map[tok_index] = i 996 | 997 | orig_start_position = None 998 | if start_position in tok_s_to_ns_map: 999 | ns_start_position = tok_s_to_ns_map[start_position] 1000 | if ns_start_position in orig_ns_to_s_map: 1001 | orig_start_position = orig_ns_to_s_map[ns_start_position] 1002 | 1003 | if orig_start_position is None: 1004 | if FLAGS.verbose_logging: 1005 | tf.logging.info("Couldn't map start position") 1006 | return orig_text 1007 | 1008 | orig_end_position = None 1009 | if end_position in tok_s_to_ns_map: 1010 | ns_end_position = tok_s_to_ns_map[end_position] 1011 | if ns_end_position in orig_ns_to_s_map: 1012 | orig_end_position = orig_ns_to_s_map[ns_end_position] 1013 | 1014 | if orig_end_position is None: 1015 | if FLAGS.verbose_logging: 1016 | tf.logging.info("Couldn't map end position") 1017 | return orig_text 1018 | 1019 | output_text = orig_text[orig_start_position:(orig_end_position + 1)] 1020 | return output_text 1021 | 1022 | 1023 | def _get_best_indexes(logits, n_best_size): 1024 | """Get the n-best logits from a list.""" 1025 | index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True) 1026 | 1027 | best_indexes = [] 1028 | for i in range(len(index_and_score)): 1029 | if i >= n_best_size: 1030 | break 1031 | best_indexes.append(index_and_score[i][0]) 1032 | return best_indexes 1033 | 1034 | 1035 | def _compute_softmax(scores): 1036 | """Compute softmax probability over raw logits.""" 1037 | if not scores: 1038 | return [] 1039 | 1040 | max_score = None 1041 | for score in scores: 1042 | if max_score is None or score > max_score: 1043 | max_score = score 1044 | 1045 | exp_scores = [] 1046 | total_sum = 0.0 1047 | for score in scores: 1048 | x = math.exp(score - max_score) 1049 | exp_scores.append(x) 1050 | total_sum += x 1051 | 1052 | probs = [] 1053 | for score in exp_scores: 1054 | probs.append(score / total_sum) 1055 | return probs 1056 | 1057 | 1058 | class 
FeatureWriter(object): 1059 | """Writes InputFeature to TF example file.""" 1060 | 1061 | def __init__(self, filename, is_training): 1062 | self.filename = filename 1063 | self.is_training = is_training 1064 | self.num_features = 0 1065 | self._writer = tf.python_io.TFRecordWriter(filename) 1066 | 1067 | def process_feature(self, feature): 1068 | """Write a InputFeature to the TFRecordWriter as a tf.train.Example.""" 1069 | self.num_features += 1 1070 | 1071 | def create_int_feature(values): 1072 | feature = tf.train.Feature( 1073 | int64_list=tf.train.Int64List(value=list(values))) 1074 | return feature 1075 | 1076 | features = collections.OrderedDict() 1077 | features["unique_ids"] = create_int_feature([feature.unique_id]) 1078 | features["input_ids"] = create_int_feature(feature.input_ids) 1079 | features["input_mask"] = create_int_feature(feature.input_mask) 1080 | features["segment_ids"] = create_int_feature(feature.segment_ids) 1081 | 1082 | if self.is_training: 1083 | features["start_positions"] = create_int_feature([feature.start_position]) 1084 | features["end_positions"] = create_int_feature([feature.end_position]) 1085 | impossible = 0 1086 | if feature.is_impossible: 1087 | impossible = 1 1088 | features["is_impossible"] = create_int_feature([impossible]) 1089 | 1090 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 1091 | self._writer.write(tf_example.SerializeToString()) 1092 | 1093 | def close(self): 1094 | self._writer.close() 1095 | 1096 | 1097 | def validate_flags_or_throw(bert_config): 1098 | """Validate the input FLAGS or throw an exception.""" 1099 | tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case, 1100 | FLAGS.init_checkpoint) 1101 | 1102 | if not FLAGS.do_train and not FLAGS.do_predict: 1103 | raise ValueError("At least one of `do_train` or `do_predict` must be True.") 1104 | 1105 | if FLAGS.do_train: 1106 | if not FLAGS.train_file: 1107 | raise ValueError( 1108 | "If `do_train` is True, then `train_file` must be specified.") 1109 | if FLAGS.do_predict: 1110 | if not FLAGS.predict_file: 1111 | raise ValueError( 1112 | "If `do_predict` is True, then `predict_file` must be specified.") 1113 | 1114 | if FLAGS.max_seq_length > bert_config.max_position_embeddings: 1115 | raise ValueError( 1116 | "Cannot use sequence length %d because the BERT model " 1117 | "was only trained up to sequence length %d" % 1118 | (FLAGS.max_seq_length, bert_config.max_position_embeddings)) 1119 | 1120 | if FLAGS.max_seq_length <= FLAGS.max_query_length + 3: 1121 | raise ValueError( 1122 | "The max_seq_length (%d) must be greater than max_query_length " 1123 | "(%d) + 3" % (FLAGS.max_seq_length, FLAGS.max_query_length)) 1124 | 1125 | 1126 | def main(_): 1127 | tf.logging.set_verbosity(tf.logging.INFO) 1128 | 1129 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 1130 | 1131 | validate_flags_or_throw(bert_config) 1132 | 1133 | tf.gfile.MakeDirs(FLAGS.output_dir) 1134 | 1135 | tokenizer = tokenization.FullTokenizer( 1136 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 1137 | 1138 | tpu_cluster_resolver = None 1139 | if FLAGS.use_tpu and FLAGS.tpu_name: 1140 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 1141 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 1142 | 1143 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 1144 | run_config = tf.contrib.tpu.RunConfig( 1145 | cluster=tpu_cluster_resolver, 1146 | master=FLAGS.master, 1147 | 
model_dir=FLAGS.output_dir, 1148 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 1149 | tpu_config=tf.contrib.tpu.TPUConfig( 1150 | iterations_per_loop=FLAGS.iterations_per_loop, 1151 | num_shards=FLAGS.num_tpu_cores, 1152 | per_host_input_for_training=is_per_host)) 1153 | 1154 | train_examples = None 1155 | num_train_steps = None 1156 | num_warmup_steps = None 1157 | if FLAGS.do_train: 1158 | train_examples = read_squad_examples( 1159 | input_file=FLAGS.train_file, is_training=True) 1160 | num_train_steps = int( 1161 | len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs) 1162 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 1163 | 1164 | # Pre-shuffle the input to avoid having to make a very large shuffle 1165 | # buffer in in the `input_fn`. 1166 | rng = random.Random(12345) 1167 | rng.shuffle(train_examples) 1168 | 1169 | model_fn = model_fn_builder( 1170 | bert_config=bert_config, 1171 | init_checkpoint=FLAGS.init_checkpoint, 1172 | learning_rate=FLAGS.learning_rate, 1173 | num_train_steps=num_train_steps, 1174 | num_warmup_steps=num_warmup_steps, 1175 | use_tpu=FLAGS.use_tpu, 1176 | use_one_hot_embeddings=FLAGS.use_tpu) 1177 | 1178 | # If TPU is not available, this will fall back to normal Estimator on CPU 1179 | # or GPU. 1180 | estimator = tf.contrib.tpu.TPUEstimator( 1181 | use_tpu=FLAGS.use_tpu, 1182 | model_fn=model_fn, 1183 | config=run_config, 1184 | train_batch_size=FLAGS.train_batch_size, 1185 | predict_batch_size=FLAGS.predict_batch_size) 1186 | 1187 | if FLAGS.do_train: 1188 | # We write to a temporary file to avoid storing very large constant tensors 1189 | # in memory. 1190 | train_writer = FeatureWriter( 1191 | filename=os.path.join(FLAGS.output_dir, "train.tf_record"), 1192 | is_training=True) 1193 | convert_examples_to_features( 1194 | examples=train_examples, 1195 | tokenizer=tokenizer, 1196 | max_seq_length=FLAGS.max_seq_length, 1197 | doc_stride=FLAGS.doc_stride, 1198 | max_query_length=FLAGS.max_query_length, 1199 | is_training=True, 1200 | output_fn=train_writer.process_feature) 1201 | train_writer.close() 1202 | 1203 | tf.logging.info("***** Running training *****") 1204 | tf.logging.info(" Num orig examples = %d", len(train_examples)) 1205 | tf.logging.info(" Num split examples = %d", train_writer.num_features) 1206 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 1207 | tf.logging.info(" Num steps = %d", num_train_steps) 1208 | del train_examples 1209 | 1210 | train_input_fn = input_fn_builder( 1211 | input_file=train_writer.filename, 1212 | seq_length=FLAGS.max_seq_length, 1213 | is_training=True, 1214 | drop_remainder=True) 1215 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) 1216 | 1217 | if FLAGS.do_predict: 1218 | eval_examples = read_squad_examples( 1219 | input_file=FLAGS.predict_file, is_training=False) 1220 | 1221 | eval_writer = FeatureWriter( 1222 | filename=os.path.join(FLAGS.output_dir, "eval.tf_record"), 1223 | is_training=False) 1224 | eval_features = [] 1225 | 1226 | def append_feature(feature): 1227 | eval_features.append(feature) 1228 | eval_writer.process_feature(feature) 1229 | 1230 | convert_examples_to_features( 1231 | examples=eval_examples, 1232 | tokenizer=tokenizer, 1233 | max_seq_length=FLAGS.max_seq_length, 1234 | doc_stride=FLAGS.doc_stride, 1235 | max_query_length=FLAGS.max_query_length, 1236 | is_training=False, 1237 | output_fn=append_feature) 1238 | eval_writer.close() 1239 | 1240 | tf.logging.info("***** Running predictions *****") 1241 | 
tf.logging.info(" Num orig examples = %d", len(eval_examples)) 1242 | tf.logging.info(" Num split examples = %d", len(eval_features)) 1243 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) 1244 | 1245 | all_results = [] 1246 | 1247 | predict_input_fn = input_fn_builder( 1248 | input_file=eval_writer.filename, 1249 | seq_length=FLAGS.max_seq_length, 1250 | is_training=False, 1251 | drop_remainder=False) 1252 | 1253 | # If running eval on the TPU, you will need to specify the number of 1254 | # steps. 1255 | all_results = [] 1256 | for result in estimator.predict( 1257 | predict_input_fn, yield_single_examples=True): 1258 | if len(all_results) % 1000 == 0: 1259 | tf.logging.info("Processing example: %d" % (len(all_results))) 1260 | unique_id = int(result["unique_ids"]) 1261 | start_logits = [float(x) for x in result["start_logits"].flat] 1262 | end_logits = [float(x) for x in result["end_logits"].flat] 1263 | all_results.append( 1264 | RawResult( 1265 | unique_id=unique_id, 1266 | start_logits=start_logits, 1267 | end_logits=end_logits)) 1268 | 1269 | output_prediction_file = os.path.join(FLAGS.output_dir, "predictions.json") 1270 | output_nbest_file = os.path.join(FLAGS.output_dir, "nbest_predictions.json") 1271 | output_null_log_odds_file = os.path.join(FLAGS.output_dir, "null_odds.json") 1272 | 1273 | write_predictions(eval_examples, eval_features, all_results, 1274 | FLAGS.n_best_size, FLAGS.max_answer_length, 1275 | FLAGS.do_lower_case, output_prediction_file, 1276 | output_nbest_file, output_null_log_odds_file) 1277 | 1278 | 1279 | if __name__ == "__main__": 1280 | flags.mark_flag_as_required("vocab_file") 1281 | flags.mark_flag_as_required("bert_config_file") 1282 | flags.mark_flag_as_required("output_dir") 1283 | tf.compat.v1.app.run() 1284 | -------------------------------------------------------------------------------- /modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """The main BERT model and related functions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import copy 23 | import json 24 | import math 25 | import re 26 | import numpy as np 27 | import six 28 | import tensorflow as tf 29 | import bert_utils 30 | 31 | class BertConfig(object): 32 | """Configuration for `BertModel`.""" 33 | 34 | def __init__(self, 35 | vocab_size, 36 | hidden_size=768, 37 | num_hidden_layers=12, 38 | num_attention_heads=12, 39 | intermediate_size=3072, 40 | hidden_act="gelu", 41 | hidden_dropout_prob=0.1, 42 | attention_probs_dropout_prob=0.1, 43 | max_position_embeddings=512, 44 | type_vocab_size=16, 45 | initializer_range=0.02): 46 | """Constructs BertConfig. 47 | 48 | Args: 49 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. 
50 | hidden_size: Size of the encoder layers and the pooler layer. 51 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 52 | num_attention_heads: Number of attention heads for each attention layer in 53 | the Transformer encoder. 54 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 55 | layer in the Transformer encoder. 56 | hidden_act: The non-linear activation function (function or string) in the 57 | encoder and pooler. 58 | hidden_dropout_prob: The dropout probability for all fully connected 59 | layers in the embeddings, encoder, and pooler. 60 | attention_probs_dropout_prob: The dropout ratio for the attention 61 | probabilities. 62 | max_position_embeddings: The maximum sequence length that this model might 63 | ever be used with. Typically set this to something large just in case 64 | (e.g., 512 or 1024 or 2048). 65 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 66 | `BertModel`. 67 | initializer_range: The stdev of the truncated_normal_initializer for 68 | initializing all weight matrices. 69 | """ 70 | self.vocab_size = vocab_size 71 | self.hidden_size = hidden_size 72 | self.num_hidden_layers = num_hidden_layers 73 | self.num_attention_heads = num_attention_heads 74 | self.hidden_act = hidden_act 75 | self.intermediate_size = intermediate_size 76 | self.hidden_dropout_prob = hidden_dropout_prob 77 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 78 | self.max_position_embeddings = max_position_embeddings 79 | self.type_vocab_size = type_vocab_size 80 | self.initializer_range = initializer_range 81 | 82 | @classmethod 83 | def from_dict(cls, json_object): 84 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 85 | config = BertConfig(vocab_size=None) 86 | for (key, value) in six.iteritems(json_object): 87 | config.__dict__[key] = value 88 | return config 89 | 90 | @classmethod 91 | def from_json_file(cls, json_file): 92 | """Constructs a `BertConfig` from a json file of parameters.""" 93 | with tf.gfile.GFile(json_file, "r") as reader: 94 | text = reader.read() 95 | return cls.from_dict(json.loads(text)) 96 | 97 | def to_dict(self): 98 | """Serializes this instance to a Python dictionary.""" 99 | output = copy.deepcopy(self.__dict__) 100 | return output 101 | 102 | def to_json_string(self): 103 | """Serializes this instance to a JSON string.""" 104 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 105 | 106 | 107 | class BertModel(object): 108 | """BERT model ("Bidirectional Encoder Representations from Transformers"). 109 | 110 | Example usage: 111 | 112 | ```python 113 | # Already been converted into WordPiece token ids 114 | input_ids = tf.constant([[31, 51, 99], [15, 5, 0]]) 115 | input_mask = tf.constant([[1, 1, 1], [1, 1, 0]]) 116 | token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]]) 117 | 118 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512, 119 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 120 | 121 | model = modeling.BertModel(config=config, is_training=True, 122 | input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids) 123 | 124 | label_embeddings = tf.get_variable(...) 125 | pooled_output = model.get_pooled_output() 126 | logits = tf.matmul(pooled_output, label_embeddings) 127 | ... 
128 | ``` 129 | """ 130 | 131 | def __init__(self, 132 | config, 133 | is_training, 134 | input_ids, 135 | input_mask=None, 136 | token_type_ids=None, 137 | use_one_hot_embeddings=False, 138 | scope=None): 139 | """Constructor for BertModel. 140 | 141 | Args: 142 | config: `BertConfig` instance. 143 | is_training: bool. true for training model, false for eval model. Controls 144 | whether dropout will be applied. 145 | input_ids: int32 Tensor of shape [batch_size, seq_length]. 146 | input_mask: (optional) int32 Tensor of shape [batch_size, seq_length]. 147 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 148 | use_one_hot_embeddings: (optional) bool. Whether to use one-hot word 149 | embeddings or tf.embedding_lookup() for the word embeddings. 150 | scope: (optional) variable scope. Defaults to "bert". 151 | 152 | Raises: 153 | ValueError: The config is invalid or one of the input tensor shapes 154 | is invalid. 155 | """ 156 | config = copy.deepcopy(config) 157 | if not is_training: 158 | config.hidden_dropout_prob = 0.0 159 | config.attention_probs_dropout_prob = 0.0 160 | 161 | input_shape = get_shape_list(input_ids, expected_rank=2) 162 | batch_size = input_shape[0] 163 | seq_length = input_shape[1] 164 | 165 | if input_mask is None: 166 | input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32) 167 | 168 | if token_type_ids is None: 169 | token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32) 170 | 171 | with tf.variable_scope(scope, default_name="bert"): 172 | with tf.variable_scope("embeddings"): 173 | # Perform embedding lookup on the word ids, but use stype of factorized embedding parameterization from albert. add by brightmart, 2019-09-28 174 | (self.embedding_output, self.embedding_table,self.embedding_table_2) = embedding_lookup_factorized( 175 | input_ids=input_ids, 176 | vocab_size=config.vocab_size, 177 | hidden_size=config.hidden_size, 178 | embedding_size=config.embedding_size, 179 | initializer_range=config.initializer_range, 180 | word_embedding_name="word_embeddings", 181 | use_one_hot_embeddings=use_one_hot_embeddings) 182 | 183 | # Add positional embeddings and token type embeddings, then layer 184 | # normalize and perform dropout. 185 | self.embedding_output = embedding_postprocessor( 186 | input_tensor=self.embedding_output, 187 | use_token_type=True, 188 | token_type_ids=token_type_ids, 189 | token_type_vocab_size=config.type_vocab_size, 190 | token_type_embedding_name="token_type_embeddings", 191 | use_position_embeddings=True, 192 | position_embedding_name="position_embeddings", 193 | initializer_range=config.initializer_range, 194 | max_position_embeddings=config.max_position_embeddings, 195 | dropout_prob=config.hidden_dropout_prob) 196 | 197 | with tf.variable_scope("encoder"): 198 | # This converts a 2D mask of shape [batch_size, seq_length] to a 3D 199 | # mask of shape [batch_size, seq_length, seq_length] which is used 200 | # for the attention scores. 201 | attention_mask = create_attention_mask_from_input_mask( 202 | input_ids, input_mask) 203 | 204 | # Run the stacked transformer. 205 | # `sequence_output` shape = [batch_size, seq_length, hidden_size]. 
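        # Note: `config.ln_type` on the line below assumes the JSON config file
        # defines an "ln_type" field; BertConfig.from_dict only copies keys that
        # are present, so a config without it would raise AttributeError here.
        # A defensive sketch (not in the original code) would be:
        #   ln_type = getattr(config, "ln_type", None)
        # which falls through to the post-LN `transformer_model` branch below.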
206 | ln_type = config.ln_type 207 | print("ln_type:", ln_type) 208 | if ln_type == 'postln' or ln_type is None: # the base and large ALBERT models currently use the post-LN structure 209 | print("using the original transformer_model, which applies post-layer normalization (post-LN)") 210 | self.all_encoder_layers = transformer_model( 211 | input_tensor=self.embedding_output, 212 | attention_mask=attention_mask, 213 | hidden_size=config.hidden_size, 214 | num_hidden_layers=config.num_hidden_layers, 215 | num_attention_heads=config.num_attention_heads, 216 | intermediate_size=config.intermediate_size, 217 | intermediate_act_fn=get_activation(config.hidden_act), 218 | hidden_dropout_prob=config.hidden_dropout_prob, 219 | attention_probs_dropout_prob=config.attention_probs_dropout_prob, 220 | initializer_range=config.initializer_range, 221 | do_return_all_layers=True) 222 | else: # the xlarge and xxlarge ALBERT models use the pre-LN structure 223 | print("using prelln_transformer_model, which applies pre-layer normalization (pre-LN)") 224 | self.all_encoder_layers = prelln_transformer_model( # changed by brightmart, Oct 4, 2019: pre-layer normalization converges faster and better; see the paper "On Layer Normalization in the Transformer Architecture" 225 | input_tensor=self.embedding_output, 226 | attention_mask=attention_mask, 227 | hidden_size=config.hidden_size, 228 | num_hidden_layers=config.num_hidden_layers, 229 | num_attention_heads=config.num_attention_heads, 230 | intermediate_size=config.intermediate_size, 231 | intermediate_act_fn=get_activation(config.hidden_act), 232 | hidden_dropout_prob=config.hidden_dropout_prob, 233 | attention_probs_dropout_prob=config.attention_probs_dropout_prob, 234 | initializer_range=config.initializer_range, 235 | do_return_all_layers=True, 236 | shared_type='all') 237 | 238 | self.sequence_output = self.all_encoder_layers[-1] # [batch_size, seq_length, hidden_size] 239 | # The "pooler" converts the encoded sequence tensor of shape 240 | # [batch_size, seq_length, hidden_size] to a tensor of shape 241 | # [batch_size, hidden_size]. This is necessary for segment-level 242 | # (or segment-pair-level) classification tasks where we need a fixed 243 | # dimensional representation of the segment. 244 | with tf.variable_scope("pooler"): 245 | # We "pool" the model by simply taking the hidden state corresponding 246 | # to the first token. We assume that this has been pre-trained. 247 | first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1) 248 | self.pooled_output = tf.layers.dense( 249 | first_token_tensor, 250 | config.hidden_size, 251 | activation=tf.tanh, 252 | kernel_initializer=create_initializer(config.initializer_range)) 253 | 254 | def get_pooled_output(self): 255 | return self.pooled_output 256 | 257 | def get_sequence_output(self): 258 | """Gets the final hidden layer of the encoder. 259 | 260 | Returns: 261 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 262 | to the final hidden state of the transformer encoder. 263 | """ 264 | return self.sequence_output 265 | 266 | def get_all_encoder_layers(self): 267 | return self.all_encoder_layers 268 | 269 | def get_embedding_output(self): 270 | """Gets output of the embedding lookup (i.e., input to the transformer).
271 | 272 | Returns: 273 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 274 | to the output of the embedding layer, after summing the word 275 | embeddings with the positional embeddings and the token type embeddings, 276 | then performing layer normalization. This is the input to the transformer. 277 | """ 278 | return self.embedding_output 279 | 280 | def get_embedding_table(self): 281 | return self.embedding_table 282 | 283 | def get_embedding_table_2(self): 284 | return self.embedding_table_2 285 | 286 | def gelu(x): 287 | """Gaussian Error Linear Unit. 288 | 289 | This is a smoother version of the RELU. 290 | Original paper: https://arxiv.org/abs/1606.08415 291 | Args: 292 | x: float Tensor to perform activation. 293 | 294 | Returns: 295 | `x` with the GELU activation applied. 296 | """ 297 | cdf = 0.5 * (1.0 + tf.tanh( 298 | (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))) 299 | return x * cdf 300 | 301 | 302 | def get_activation(activation_string): 303 | """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`. 304 | 305 | Args: 306 | activation_string: String name of the activation function. 307 | 308 | Returns: 309 | A Python function corresponding to the activation function. If 310 | `activation_string` is None, empty, or "linear", this will return None. 311 | If `activation_string` is not a string, it will return `activation_string`. 312 | 313 | Raises: 314 | ValueError: The `activation_string` does not correspond to a known 315 | activation. 316 | """ 317 | 318 | # We assume that anything that"s not a string is already an activation 319 | # function, so we just return it. 320 | if not isinstance(activation_string, six.string_types): 321 | return activation_string 322 | 323 | if not activation_string: 324 | return None 325 | 326 | act = activation_string.lower() 327 | if act == "linear": 328 | return None 329 | elif act == "relu": 330 | return tf.nn.relu 331 | elif act == "gelu": 332 | return gelu 333 | elif act == "tanh": 334 | return tf.tanh 335 | else: 336 | raise ValueError("Unsupported activation: %s" % act) 337 | 338 | 339 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint): 340 | """Compute the union of the current variables and checkpoint variables.""" 341 | assignment_map = {} 342 | initialized_variable_names = {} 343 | 344 | name_to_variable = collections.OrderedDict() 345 | for var in tvars: 346 | name = var.name 347 | m = re.match("^(.*):\\d+$", name) 348 | if m is not None: 349 | name = m.group(1) 350 | name_to_variable[name] = var 351 | 352 | init_vars = tf.train.list_variables(init_checkpoint) 353 | 354 | assignment_map = collections.OrderedDict() 355 | for x in init_vars: 356 | (name, var) = (x[0], x[1]) 357 | if name not in name_to_variable: 358 | continue 359 | assignment_map[name] = name 360 | initialized_variable_names[name] = 1 361 | initialized_variable_names[name + ":0"] = 1 362 | 363 | return (assignment_map, initialized_variable_names) 364 | 365 | 366 | def dropout(input_tensor, dropout_prob): 367 | """Perform dropout. 368 | 369 | Args: 370 | input_tensor: float Tensor. 371 | dropout_prob: Python float. The probability of dropping out a value (NOT of 372 | *keeping* a dimension as in `tf.nn.dropout`). 373 | 374 | Returns: 375 | A version of `input_tensor` with dropout applied. 
376 | """ 377 | if dropout_prob is None or dropout_prob == 0.0: 378 | return input_tensor 379 | 380 | output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob) 381 | return output 382 | 383 | 384 | def layer_norm(input_tensor, name=None): 385 | """Run layer normalization on the last dimension of the tensor.""" 386 | return tf.contrib.layers.layer_norm( 387 | inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name) 388 | 389 | 390 | def layer_norm_and_dropout(input_tensor, dropout_prob, name=None): 391 | """Runs layer normalization followed by dropout.""" 392 | output_tensor = layer_norm(input_tensor, name) 393 | output_tensor = dropout(output_tensor, dropout_prob) 394 | return output_tensor 395 | 396 | 397 | def create_initializer(initializer_range=0.02): 398 | """Creates a `truncated_normal_initializer` with the given range.""" 399 | return tf.truncated_normal_initializer(stddev=initializer_range) 400 | 401 | 402 | def embedding_lookup(input_ids, 403 | vocab_size, 404 | embedding_size=128, 405 | initializer_range=0.02, 406 | word_embedding_name="word_embeddings", 407 | use_one_hot_embeddings=False): 408 | """Looks up words embeddings for id tensor. 409 | 410 | Args: 411 | input_ids: int32 Tensor of shape [batch_size, seq_length] containing word 412 | ids. 413 | vocab_size: int. Size of the embedding vocabulary. 414 | embedding_size: int. Width of the word embeddings. 415 | initializer_range: float. Embedding initialization range. 416 | word_embedding_name: string. Name of the embedding table. 417 | use_one_hot_embeddings: bool. If True, use one-hot method for word 418 | embeddings. If False, use `tf.gather()`. 419 | 420 | Returns: 421 | float Tensor of shape [batch_size, seq_length, embedding_size]. 422 | """ 423 | # This function assumes that the input is of shape [batch_size, seq_length, 424 | # num_inputs]. 425 | # 426 | # If the input is a 2D tensor of shape [batch_size, seq_length], we 427 | # reshape to [batch_size, seq_length, 1]. 428 | if input_ids.shape.ndims == 2: 429 | input_ids = tf.expand_dims(input_ids, axis=[-1]) # shape of input_ids is:[ batch_size, seq_length, 1] 430 | 431 | embedding_table = tf.get_variable( # [vocab_size, embedding_size] 432 | name=word_embedding_name, 433 | shape=[vocab_size, embedding_size], 434 | initializer=create_initializer(initializer_range)) 435 | 436 | flat_input_ids = tf.reshape(input_ids, [-1]) # one rank. 
shape as (batch_size * sequence_length,) 437 | if use_one_hot_embeddings: 438 | one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size) # one_hot_input_ids=[batch_size * sequence_length,vocab_size] 439 | output = tf.matmul(one_hot_input_ids, embedding_table) # output=[batch_size * sequence_length,embedding_size] 440 | else: 441 | output = tf.gather(embedding_table, flat_input_ids) # [vocab_size, embedding_size]*[batch_size * sequence_length,]--->[batch_size * sequence_length,embedding_size] 442 | 443 | input_shape = get_shape_list(input_ids) # input_shape=[ batch_size, seq_length, 1] 444 | 445 | output = tf.reshape(output,input_shape[0:-1] + [input_shape[-1] * embedding_size]) # output=[batch_size,sequence_length,embedding_size] 446 | return (output, embedding_table) 447 | 448 | def embedding_lookup_factorized(input_ids, # factorized embedding parameterization from ALBERT 449 | vocab_size, 450 | hidden_size, 451 | embedding_size=128, 452 | initializer_range=0.02, 453 | word_embedding_name="word_embeddings", 454 | use_one_hot_embeddings=False): 455 | """Looks up word embeddings for an id tensor using the factorized embedding parameterization from ALBERT, which greatly reduces the number of embedding parameters. 456 | See the "Factorized embedding parameterization" section of the paper. 457 | 458 | Args: 459 | input_ids: int32 Tensor of shape [batch_size, seq_length] containing word 460 | ids. 461 | vocab_size: int. Size of the embedding vocabulary. hidden_size: int. Size of the hidden space the embeddings are projected to. 462 | embedding_size: int. Width of the word embeddings. 463 | initializer_range: float. Embedding initialization range. 464 | word_embedding_name: string. Name of the embedding table. 465 | use_one_hot_embeddings: bool. If True, use one-hot method for word 466 | embeddings. If False, use `tf.gather()`. 467 | 468 | Returns: 469 | float Tensor of shape [batch_size, seq_length, hidden_size]. 470 | """ 471 | # This function assumes that the input is of shape [batch_size, seq_length, 472 | # num_inputs]. 473 | # 474 | # If the input is a 2D tensor of shape [batch_size, seq_length], we 475 | # reshape to [batch_size, seq_length, 1]. 476 | 477 | # 1. first project one-hot vectors into a lower-dimensional embedding space of size E 478 | print("embedding_lookup_factorized. factorized embedding parameterization is used.") 479 | if input_ids.shape.ndims == 2: 480 | input_ids = tf.expand_dims(input_ids, axis=[-1]) # shape of input_ids is:[ batch_size, seq_length, 1] 481 | 482 | embedding_table = tf.get_variable( # [vocab_size, embedding_size] 483 | name=word_embedding_name, 484 | shape=[vocab_size, embedding_size], 485 | initializer=create_initializer(initializer_range)) 486 | 487 | flat_input_ids = tf.reshape(input_ids, [-1]) # one rank. shape as (batch_size * sequence_length,) 488 | if use_one_hot_embeddings: 489 | one_hot_input_ids = tf.one_hot(flat_input_ids,depth=vocab_size) # one_hot_input_ids=[batch_size * sequence_length,vocab_size] 490 | output_middle = tf.matmul(one_hot_input_ids, embedding_table) # output=[batch_size * sequence_length,embedding_size] 491 | else: 492 | output_middle = tf.gather(embedding_table,flat_input_ids) # [vocab_size, embedding_size]*[batch_size * sequence_length,]--->[batch_size * sequence_length,embedding_size] 493 | 494 | # 2.
project vector(output_middle) to the hidden space 495 | project_variable = tf.get_variable( # [embedding_size, hidden_size] 496 | name=word_embedding_name+"_2", 497 | shape=[embedding_size, hidden_size], 498 | initializer=create_initializer(initializer_range)) 499 | output = tf.matmul(output_middle, project_variable) # ([batch_size * sequence_length, embedding_size] * [embedding_size, hidden_size])--->[batch_size * sequence_length, hidden_size] 500 | # reshape back to 3 rank 501 | input_shape = get_shape_list(input_ids) # input_shape=[ batch_size, seq_length, 1] 502 | batch_size, sequene_length, _=input_shape 503 | output = tf.reshape(output, (batch_size,sequene_length,hidden_size)) # output=[batch_size, sequence_length, hidden_size] 504 | return (output, embedding_table, project_variable) 505 | 506 | 507 | def embedding_postprocessor(input_tensor, 508 | use_token_type=False, 509 | token_type_ids=None, 510 | token_type_vocab_size=16, 511 | token_type_embedding_name="token_type_embeddings", 512 | use_position_embeddings=True, 513 | position_embedding_name="position_embeddings", 514 | initializer_range=0.02, 515 | max_position_embeddings=512, 516 | dropout_prob=0.1): 517 | """Performs various post-processing on a word embedding tensor. 518 | 519 | Args: 520 | input_tensor: float Tensor of shape [batch_size, seq_length, 521 | embedding_size]. 522 | use_token_type: bool. Whether to add embeddings for `token_type_ids`. 523 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 524 | Must be specified if `use_token_type` is True. 525 | token_type_vocab_size: int. The vocabulary size of `token_type_ids`. 526 | token_type_embedding_name: string. The name of the embedding table variable 527 | for token type ids. 528 | use_position_embeddings: bool. Whether to add position embeddings for the 529 | position of each token in the sequence. 530 | position_embedding_name: string. The name of the embedding table variable 531 | for positional embeddings. 532 | initializer_range: float. Range of the weight initialization. 533 | max_position_embeddings: int. Maximum sequence length that might ever be 534 | used with this model. This can be longer than the sequence length of 535 | input_tensor, but cannot be shorter. 536 | dropout_prob: float. Dropout probability applied to the final output tensor. 537 | 538 | Returns: 539 | float tensor with same shape as `input_tensor`. 540 | 541 | Raises: 542 | ValueError: One of the tensor shapes or input values is invalid. 543 | """ 544 | input_shape = get_shape_list(input_tensor, expected_rank=3) 545 | batch_size = input_shape[0] 546 | seq_length = input_shape[1] 547 | width = input_shape[2] 548 | 549 | output = input_tensor 550 | 551 | if use_token_type: 552 | if token_type_ids is None: 553 | raise ValueError("`token_type_ids` must be specified if" 554 | "`use_token_type` is True.") 555 | token_type_table = tf.get_variable( 556 | name=token_type_embedding_name, 557 | shape=[token_type_vocab_size, width], 558 | initializer=create_initializer(initializer_range)) 559 | # This vocab will be small so we always do one-hot here, since it is always 560 | # faster for a small vocabulary. 
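# flat one-hot ids: [batch_size * seq_length, token_type_vocab_size] matmul
# token_type_table: [token_type_vocab_size, width] -> [batch_size * seq_length, width],
# then reshaped back to [batch_size, seq_length, width] below.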
561 | flat_token_type_ids = tf.reshape(token_type_ids, [-1]) 562 | one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size) 563 | token_type_embeddings = tf.matmul(one_hot_ids, token_type_table) 564 | token_type_embeddings = tf.reshape(token_type_embeddings, 565 | [batch_size, seq_length, width]) 566 | output += token_type_embeddings 567 | 568 | if use_position_embeddings: 569 | assert_op = tf.assert_less_equal(seq_length, max_position_embeddings) 570 | with tf.control_dependencies([assert_op]): 571 | full_position_embeddings = tf.get_variable( 572 | name=position_embedding_name, 573 | shape=[max_position_embeddings, width], 574 | initializer=create_initializer(initializer_range)) 575 | # Since the position embedding table is a learned variable, we create it 576 | # using a (long) sequence length `max_position_embeddings`. The actual 577 | # sequence length might be shorter than this, for faster training of 578 | # tasks that do not have long sequences. 579 | # 580 | # So `full_position_embeddings` is effectively an embedding table 581 | # for position [0, 1, 2, ..., max_position_embeddings-1], and the current 582 | # sequence has positions [0, 1, 2, ... seq_length-1], so we can just 583 | # perform a slice. 584 | position_embeddings = tf.slice(full_position_embeddings, [0, 0], 585 | [seq_length, -1]) 586 | num_dims = len(output.shape.as_list()) 587 | 588 | # Only the last two dimensions are relevant (`seq_length` and `width`), so 589 | # we broadcast among the first dimensions, which is typically just 590 | # the batch size. 591 | position_broadcast_shape = [] 592 | for _ in range(num_dims - 2): 593 | position_broadcast_shape.append(1) 594 | position_broadcast_shape.extend([seq_length, width]) 595 | position_embeddings = tf.reshape(position_embeddings, 596 | position_broadcast_shape) 597 | output += position_embeddings 598 | 599 | output = layer_norm_and_dropout(output, dropout_prob) 600 | return output 601 | 602 | 603 | def create_attention_mask_from_input_mask(from_tensor, to_mask): 604 | """Create 3D attention mask from a 2D tensor mask. 605 | 606 | Args: 607 | from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...]. 608 | to_mask: int32 Tensor of shape [batch_size, to_seq_length]. 609 | 610 | Returns: 611 | float Tensor of shape [batch_size, from_seq_length, to_seq_length]. 612 | """ 613 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 614 | batch_size = from_shape[0] 615 | from_seq_length = from_shape[1] 616 | 617 | to_shape = get_shape_list(to_mask, expected_rank=2) 618 | to_seq_length = to_shape[1] 619 | 620 | to_mask = tf.cast( 621 | tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32) 622 | 623 | # We don't assume that `from_tensor` is a mask (although it could be). We 624 | # don't actually care if we attend *from* padding tokens (only *to* padding) 625 | # tokens so we create a tensor of all ones. 626 | # 627 | # `broadcast_ones` = [batch_size, from_seq_length, 1] 628 | broadcast_ones = tf.ones( 629 | shape=[batch_size, from_seq_length, 1], dtype=tf.float32) 630 | 631 | # Here we broadcast along two dimensions to create the mask. 
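# broadcast_ones: [batch_size, from_seq_length, 1] * to_mask: [batch_size, 1, to_seq_length]
#   -> mask: [batch_size, from_seq_length, to_seq_length]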
632 | mask = broadcast_ones * to_mask 633 | 634 | return mask 635 | 636 | 637 | def attention_layer(from_tensor, 638 | to_tensor, 639 | attention_mask=None, 640 | num_attention_heads=1, 641 | size_per_head=512, 642 | query_act=None, 643 | key_act=None, 644 | value_act=None, 645 | attention_probs_dropout_prob=0.0, 646 | initializer_range=0.02, 647 | do_return_2d_tensor=False, 648 | batch_size=None, 649 | from_seq_length=None, 650 | to_seq_length=None): 651 | """Performs multi-headed attention from `from_tensor` to `to_tensor`. 652 | 653 | This is an implementation of multi-headed attention based on "Attention 654 | Is All You Need". If `from_tensor` and `to_tensor` are the same, then 655 | this is self-attention. Each timestep in `from_tensor` attends to the 656 | corresponding sequence in `to_tensor`, and returns a fixed-width vector. 657 | 658 | This function first projects `from_tensor` into a "query" tensor and 659 | `to_tensor` into "key" and "value" tensors. These are (effectively) a list 660 | of tensors of length `num_attention_heads`, where each tensor is of shape 661 | [batch_size, seq_length, size_per_head]. 662 | 663 | Then, the query and key tensors are dot-producted and scaled. These are 664 | softmaxed to obtain attention probabilities. The value tensors are then 665 | interpolated by these probabilities, then concatenated back to a single 666 | tensor and returned. 667 | 668 | In practice, the multi-headed attention is done with transposes and 669 | reshapes rather than actual separate tensors. 670 | 671 | Args: 672 | from_tensor: float Tensor of shape [batch_size, from_seq_length, 673 | from_width]. 674 | to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width]. 675 | attention_mask: (optional) int32 Tensor of shape [batch_size, 676 | from_seq_length, to_seq_length]. The values should be 1 or 0. The 677 | attention scores will effectively be set to -infinity for any positions in 678 | the mask that are 0, and will be unchanged for positions that are 1. 679 | num_attention_heads: int. Number of attention heads. 680 | size_per_head: int. Size of each attention head. 681 | query_act: (optional) Activation function for the query transform. 682 | key_act: (optional) Activation function for the key transform. 683 | value_act: (optional) Activation function for the value transform. 684 | attention_probs_dropout_prob: (optional) float. Dropout probability of the 685 | attention probabilities. 686 | initializer_range: float. Range of the weight initializer. 687 | do_return_2d_tensor: bool. If True, the output will be of shape [batch_size 688 | * from_seq_length, num_attention_heads * size_per_head]. If False, the 689 | output will be of shape [batch_size, from_seq_length, num_attention_heads 690 | * size_per_head]. 691 | batch_size: (Optional) int. If the input is 2D, this might be the batch size 692 | of the 3D version of the `from_tensor` and `to_tensor`. 693 | from_seq_length: (Optional) If the input is 2D, this might be the seq length 694 | of the 3D version of the `from_tensor`. 695 | to_seq_length: (Optional) If the input is 2D, this might be the seq length 696 | of the 3D version of the `to_tensor`. 697 | 698 | Returns: 699 | float Tensor of shape [batch_size, from_seq_length, 700 | num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is 701 | true, this will be of shape [batch_size * from_seq_length, 702 | num_attention_heads * size_per_head]). 703 | 704 | Raises: 705 | ValueError: Any of the arguments or tensor shapes are invalid.
706 | """ 707 | 708 | def transpose_for_scores(input_tensor, batch_size, num_attention_heads, 709 | seq_length, width): 710 | output_tensor = tf.reshape( 711 | input_tensor, [batch_size, seq_length, num_attention_heads, width]) 712 | 713 | output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3]) 714 | return output_tensor 715 | 716 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 717 | to_shape = get_shape_list(to_tensor, expected_rank=[2, 3]) 718 | 719 | if len(from_shape) != len(to_shape): 720 | raise ValueError( 721 | "The rank of `from_tensor` must match the rank of `to_tensor`.") 722 | 723 | if len(from_shape) == 3: 724 | batch_size = from_shape[0] 725 | from_seq_length = from_shape[1] 726 | to_seq_length = to_shape[1] 727 | elif len(from_shape) == 2: 728 | if (batch_size is None or from_seq_length is None or to_seq_length is None): 729 | raise ValueError( 730 | "When passing in rank 2 tensors to attention_layer, the values " 731 | "for `batch_size`, `from_seq_length`, and `to_seq_length` " 732 | "must all be specified.") 733 | 734 | # Scalar dimensions referenced here: 735 | # B = batch size (number of sequences) 736 | # F = `from_tensor` sequence length 737 | # T = `to_tensor` sequence length 738 | # N = `num_attention_heads` 739 | # H = `size_per_head` 740 | 741 | from_tensor_2d = reshape_to_matrix(from_tensor) 742 | to_tensor_2d = reshape_to_matrix(to_tensor) 743 | 744 | # `query_layer` = [B*F, N*H] 745 | query_layer = tf.layers.dense( 746 | from_tensor_2d, 747 | num_attention_heads * size_per_head, 748 | activation=query_act, 749 | name="query", 750 | kernel_initializer=create_initializer(initializer_range)) 751 | 752 | # `key_layer` = [B*T, N*H] 753 | key_layer = tf.layers.dense( 754 | to_tensor_2d, 755 | num_attention_heads * size_per_head, 756 | activation=key_act, 757 | name="key", 758 | kernel_initializer=create_initializer(initializer_range)) 759 | 760 | # `value_layer` = [B*T, N*H] 761 | value_layer = tf.layers.dense( 762 | to_tensor_2d, 763 | num_attention_heads * size_per_head, 764 | activation=value_act, 765 | name="value", 766 | kernel_initializer=create_initializer(initializer_range)) 767 | 768 | # `query_layer` = [B, N, F, H] 769 | query_layer = transpose_for_scores(query_layer, batch_size, 770 | num_attention_heads, from_seq_length, 771 | size_per_head) 772 | 773 | # `key_layer` = [B, N, T, H] 774 | key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads, 775 | to_seq_length, size_per_head) 776 | 777 | # Take the dot product between "query" and "key" to get the raw 778 | # attention scores. 779 | # `attention_scores` = [B, N, F, T] 780 | attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) 781 | attention_scores = tf.multiply(attention_scores, 782 | 1.0 / math.sqrt(float(size_per_head))) 783 | 784 | if attention_mask is not None: 785 | # `attention_mask` = [B, 1, F, T] 786 | attention_mask = tf.expand_dims(attention_mask, axis=[1]) 787 | 788 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 789 | # masked positions, this operation will create a tensor which is 0.0 for 790 | # positions we want to attend and -10000.0 for masked positions. 791 | adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0 792 | 793 | # Since we are adding it to the raw scores before the softmax, this is 794 | # effectively the same as removing these entirely. 795 | attention_scores += adder 796 | 797 | # Normalize the attention scores to probabilities. 
798 | # `attention_probs` = [B, N, F, T] 799 | attention_probs = tf.nn.softmax(attention_scores) 800 | 801 | # This is actually dropping out entire tokens to attend to, which might 802 | # seem a bit unusual, but is taken from the original Transformer paper. 803 | attention_probs = dropout(attention_probs, attention_probs_dropout_prob) 804 | 805 | # `value_layer` = [B, T, N, H] 806 | value_layer = tf.reshape( 807 | value_layer, 808 | [batch_size, to_seq_length, num_attention_heads, size_per_head]) 809 | 810 | # `value_layer` = [B, N, T, H] 811 | value_layer = tf.transpose(value_layer, [0, 2, 1, 3]) 812 | 813 | # `context_layer` = [B, N, F, H] 814 | context_layer = tf.matmul(attention_probs, value_layer) 815 | 816 | # `context_layer` = [B, F, N, H] 817 | context_layer = tf.transpose(context_layer, [0, 2, 1, 3]) 818 | 819 | if do_return_2d_tensor: 820 | # `context_layer` = [B*F, N*H] 821 | context_layer = tf.reshape( 822 | context_layer, 823 | [batch_size * from_seq_length, num_attention_heads * size_per_head]) 824 | else: 825 | # `context_layer` = [B, F, N*H] 826 | context_layer = tf.reshape( 827 | context_layer, 828 | [batch_size, from_seq_length, num_attention_heads * size_per_head]) 829 | 830 | return context_layer 831 | 832 | 833 | def transformer_model(input_tensor, 834 | attention_mask=None, 835 | hidden_size=768, 836 | num_hidden_layers=12, 837 | num_attention_heads=12, 838 | intermediate_size=3072, 839 | intermediate_act_fn=gelu, 840 | hidden_dropout_prob=0.1, 841 | attention_probs_dropout_prob=0.1, 842 | initializer_range=0.02, 843 | do_return_all_layers=False, 844 | share_parameter_across_layers=True): 845 | """Multi-headed, multi-layer Transformer from "Attention is All You Need". 846 | 847 | This is almost an exact implementation of the original Transformer encoder. 848 | 849 | See the original paper: 850 | https://arxiv.org/abs/1706.03762 851 | 852 | Also see: 853 | https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py 854 | 855 | Args: 856 | input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size]. 857 | attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length, 858 | seq_length], with 1 for positions that can be attended to and 0 in 859 | positions that should not be. 860 | hidden_size: int. Hidden size of the Transformer. 861 | num_hidden_layers: int. Number of layers (blocks) in the Transformer. 862 | num_attention_heads: int. Number of attention heads in the Transformer. 863 | intermediate_size: int. The size of the "intermediate" (a.k.a., feed 864 | forward) layer. 865 | intermediate_act_fn: function. The non-linear activation function to apply 866 | to the output of the intermediate/feed-forward layer. 867 | hidden_dropout_prob: float. Dropout probability for the hidden layers. 868 | attention_probs_dropout_prob: float. Dropout probability of the attention 869 | probabilities. 870 | initializer_range: float. Range of the initializer (stddev of truncated 871 | normal). 872 | do_return_all_layers: Whether to also return all layers or just the final 873 | layer. 874 | 875 | Returns: 876 | float Tensor of shape [batch_size, seq_length, hidden_size], the final 877 | hidden layer of the Transformer. 878 | 879 | Raises: 880 | ValueError: A Tensor shape or parameter is invalid. 
881 | """ 882 | if hidden_size % num_attention_heads != 0: 883 | raise ValueError( 884 | "The hidden size (%d) is not a multiple of the number of attention " 885 | "heads (%d)" % (hidden_size, num_attention_heads)) 886 | 887 | attention_head_size = int(hidden_size / num_attention_heads) 888 | input_shape = get_shape_list(input_tensor, expected_rank=3) 889 | batch_size = input_shape[0] 890 | seq_length = input_shape[1] 891 | input_width = input_shape[2] 892 | 893 | # The Transformer performs sum residuals on all layers so the input needs 894 | # to be the same as the hidden size. 895 | if input_width != hidden_size: 896 | raise ValueError("The width of the input tensor (%d) != hidden size (%d)" % 897 | (input_width, hidden_size)) 898 | 899 | # We keep the representation as a 2D tensor to avoid re-shaping it back and 900 | # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on 901 | # the GPU/CPU but may not be free on the TPU, so we want to minimize them to 902 | # help the optimizer. 903 | prev_output = reshape_to_matrix(input_tensor) 904 | 905 | all_layer_outputs = [] 906 | for layer_idx in range(num_hidden_layers): 907 | if share_parameter_across_layers: 908 | name_variable_scope="layer_shared" 909 | else: 910 | name_variable_scope="layer_%d" % layer_idx 911 | # share all parameters across layers. add by brightmart, 2019-09-28. previous it is like this: "layer_%d" % layer_idx 912 | with tf.variable_scope(name_variable_scope, reuse=True if (share_parameter_across_layers and layer_idx>0) else False): 913 | 914 | layer_input = prev_output 915 | 916 | with tf.variable_scope("attention"): 917 | attention_heads = [] 918 | with tf.variable_scope("self"): 919 | attention_head = attention_layer( 920 | from_tensor=layer_input, 921 | to_tensor=layer_input, 922 | attention_mask=attention_mask, 923 | num_attention_heads=num_attention_heads, 924 | size_per_head=attention_head_size, 925 | attention_probs_dropout_prob=attention_probs_dropout_prob, 926 | initializer_range=initializer_range, 927 | do_return_2d_tensor=True, 928 | batch_size=batch_size, 929 | from_seq_length=seq_length, 930 | to_seq_length=seq_length) 931 | attention_heads.append(attention_head) 932 | 933 | attention_output = None 934 | if len(attention_heads) == 1: 935 | attention_output = attention_heads[0] 936 | else: 937 | # In the case where we have other sequences, we just concatenate 938 | # them to the self-attention head before the projection. 939 | attention_output = tf.concat(attention_heads, axis=-1) 940 | 941 | # Run a linear projection of `hidden_size` then add a residual 942 | # with `layer_input`. 943 | with tf.variable_scope("output"): 944 | attention_output = tf.layers.dense( 945 | attention_output, 946 | hidden_size, 947 | kernel_initializer=create_initializer(initializer_range)) 948 | attention_output = dropout(attention_output, hidden_dropout_prob) 949 | attention_output = layer_norm(attention_output + layer_input) 950 | 951 | # The activation is only applied to the "intermediate" hidden layer. 952 | with tf.variable_scope("intermediate"): 953 | intermediate_output = tf.layers.dense( 954 | attention_output, 955 | intermediate_size, 956 | activation=intermediate_act_fn, 957 | kernel_initializer=create_initializer(initializer_range)) 958 | 959 | # Down-project back to `hidden_size` then add the residual. 
960 | with tf.variable_scope("output"): 961 | layer_output = tf.layers.dense( 962 | intermediate_output, 963 | hidden_size, 964 | kernel_initializer=create_initializer(initializer_range)) 965 | layer_output = dropout(layer_output, hidden_dropout_prob) 966 | layer_output = layer_norm(layer_output + attention_output) 967 | prev_output = layer_output 968 | all_layer_outputs.append(layer_output) 969 | 970 | if do_return_all_layers: 971 | final_outputs = [] 972 | for layer_output in all_layer_outputs: 973 | final_output = reshape_from_matrix(layer_output, input_shape) 974 | final_outputs.append(final_output) 975 | return final_outputs 976 | else: 977 | final_output = reshape_from_matrix(prev_output, input_shape) 978 | return final_output 979 | 980 | 981 | def get_shape_list(tensor, expected_rank=None, name=None): 982 | """Returns a list of the shape of tensor, preferring static dimensions. 983 | 984 | Args: 985 | tensor: A tf.Tensor object to find the shape of. 986 | expected_rank: (optional) int. The expected rank of `tensor`. If this is 987 | specified and the `tensor` has a different rank, and exception will be 988 | thrown. 989 | name: Optional name of the tensor for the error message. 990 | 991 | Returns: 992 | A list of dimensions of the shape of tensor. All static dimensions will 993 | be returned as python integers, and dynamic dimensions will be returned 994 | as tf.Tensor scalars. 995 | """ 996 | if name is None: 997 | name = tensor.name 998 | 999 | if expected_rank is not None: 1000 | assert_rank(tensor, expected_rank, name) 1001 | 1002 | shape = tensor.shape.as_list() 1003 | 1004 | non_static_indexes = [] 1005 | for (index, dim) in enumerate(shape): 1006 | if dim is None: 1007 | non_static_indexes.append(index) 1008 | 1009 | if not non_static_indexes: 1010 | return shape 1011 | 1012 | dyn_shape = tf.shape(tensor) 1013 | for index in non_static_indexes: 1014 | shape[index] = dyn_shape[index] 1015 | return shape 1016 | 1017 | 1018 | def reshape_to_matrix(input_tensor): 1019 | """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix).""" 1020 | ndims = input_tensor.shape.ndims 1021 | if ndims < 2: 1022 | raise ValueError("Input tensor must have at least rank 2. Shape = %s" % 1023 | (input_tensor.shape)) 1024 | if ndims == 2: 1025 | return input_tensor 1026 | 1027 | width = input_tensor.shape[-1] 1028 | output_tensor = tf.reshape(input_tensor, [-1, width]) 1029 | return output_tensor 1030 | 1031 | 1032 | def reshape_from_matrix(output_tensor, orig_shape_list): 1033 | """Reshapes a rank 2 tensor back to its original rank >= 2 tensor.""" 1034 | if len(orig_shape_list) == 2: 1035 | return output_tensor 1036 | 1037 | output_shape = get_shape_list(output_tensor) 1038 | 1039 | orig_dims = orig_shape_list[0:-1] 1040 | width = output_shape[-1] 1041 | 1042 | return tf.reshape(output_tensor, orig_dims + [width]) 1043 | 1044 | 1045 | def assert_rank(tensor, expected_rank, name=None): 1046 | """Raises an exception if the tensor rank is not of the expected rank. 1047 | 1048 | Args: 1049 | tensor: A tf.Tensor to check the rank of. 1050 | expected_rank: Python integer or list of integers, expected rank. 1051 | name: Optional name of the tensor for the error message. 1052 | 1053 | Raises: 1054 | ValueError: If the expected shape doesn't match the actual shape. 
1055 | """ 1056 | if name is None: 1057 | name = tensor.name 1058 | 1059 | expected_rank_dict = {} 1060 | if isinstance(expected_rank, six.integer_types): 1061 | expected_rank_dict[expected_rank] = True 1062 | else: 1063 | for x in expected_rank: 1064 | expected_rank_dict[x] = True 1065 | 1066 | actual_rank = tensor.shape.ndims 1067 | if actual_rank not in expected_rank_dict: 1068 | scope_name = tf.get_variable_scope().name 1069 | raise ValueError( 1070 | "For the tensor `%s` in scope `%s`, the actual rank " 1071 | "`%d` (shape = %s) is not equal to the expected rank `%s`" % 1072 | (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank))) 1073 | 1074 | def prelln_transformer_model(input_tensor, 1075 | attention_mask=None, 1076 | hidden_size=768, 1077 | num_hidden_layers=12, 1078 | num_attention_heads=12, 1079 | intermediate_size=3072, 1080 | intermediate_act_fn=gelu, 1081 | hidden_dropout_prob=0.1, 1082 | attention_probs_dropout_prob=0.1, 1083 | initializer_range=0.02, 1084 | do_return_all_layers=False, 1085 | shared_type='all', # None, 1086 | adapter_fn=None): 1087 | """Multi-headed, multi-layer Transformer from "Attention is All You Need". 1088 | 1089 | This is almost an exact implementation of the original Transformer encoder. 1090 | 1091 | See the original paper: 1092 | https://arxiv.org/abs/1706.03762 1093 | 1094 | Also see: 1095 | https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py 1096 | 1097 | Args: 1098 | input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size]. 1099 | attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length, 1100 | seq_length], with 1 for positions that can be attended to and 0 in 1101 | positions that should not be. 1102 | hidden_size: int. Hidden size of the Transformer. 1103 | num_hidden_layers: int. Number of layers (blocks) in the Transformer. 1104 | num_attention_heads: int. Number of attention heads in the Transformer. 1105 | intermediate_size: int. The size of the "intermediate" (a.k.a., feed 1106 | forward) layer. 1107 | intermediate_act_fn: function. The non-linear activation function to apply 1108 | to the output of the intermediate/feed-forward layer. 1109 | hidden_dropout_prob: float. Dropout probability for the hidden layers. 1110 | attention_probs_dropout_prob: float. Dropout probability of the attention 1111 | probabilities. 1112 | initializer_range: float. Range of the initializer (stddev of truncated 1113 | normal). 1114 | do_return_all_layers: Whether to also return all layers or just the final 1115 | layer. 1116 | 1117 | Returns: 1118 | float Tensor of shape [batch_size, seq_length, hidden_size], the final 1119 | hidden layer of the Transformer. 1120 | 1121 | Raises: 1122 | ValueError: A Tensor shape or parameter is invalid. 1123 | """ 1124 | if hidden_size % num_attention_heads != 0: 1125 | raise ValueError( 1126 | "The hidden size (%d) is not a multiple of the number of attention " 1127 | "heads (%d)" % (hidden_size, num_attention_heads)) 1128 | 1129 | attention_head_size = int(hidden_size / num_attention_heads) 1130 | 1131 | input_shape = bert_utils.get_shape_list(input_tensor, expected_rank=3) 1132 | batch_size = input_shape[0] 1133 | seq_length = input_shape[1] 1134 | input_width = input_shape[2] 1135 | 1136 | # The Transformer performs sum residuals on all layers so the input needs 1137 | # to be the same as the hidden size. 
1138 | if input_width != hidden_size: 1139 | raise ValueError("The width of the input tensor (%d) != hidden size (%d)" % 1140 | (input_width, hidden_size)) 1141 | 1142 | # We keep the representation as a 2D tensor to avoid re-shaping it back and 1143 | # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on 1144 | # the GPU/CPU but may not be free on the TPU, so we want to minimize them to 1145 | # help the optimizer. 1146 | prev_output = bert_utils.reshape_to_matrix(input_tensor) 1147 | 1148 | all_layer_outputs = [] 1149 | 1150 | def layer_scope(idx, shared_type): 1151 | if shared_type == 'all': 1152 | tmp = { 1153 | "layer":"layer_shared", 1154 | 'attention':'attention', 1155 | 'intermediate':'intermediate', 1156 | 'output':'output' 1157 | } 1158 | elif shared_type == 'attention': 1159 | tmp = { 1160 | "layer":"layer_shared", 1161 | 'attention':'attention', 1162 | 'intermediate':'intermediate_{}'.format(idx), 1163 | 'output':'output_{}'.format(idx) 1164 | } 1165 | elif shared_type == 'ffn': 1166 | tmp = { 1167 | "layer":"layer_shared", 1168 | 'attention':'attention_{}'.format(idx), 1169 | 'intermediate':'intermediate', 1170 | 'output':'output' 1171 | } 1172 | else: 1173 | tmp = { 1174 | "layer":"layer_{}".format(idx), 1175 | 'attention':'attention', 1176 | 'intermediate':'intermediate', 1177 | 'output':'output' 1178 | } 1179 | 1180 | return tmp 1181 | 1182 | all_layer_outputs = [] 1183 | 1184 | for layer_idx in range(num_hidden_layers): 1185 | 1186 | idx_scope = layer_scope(layer_idx, shared_type) 1187 | 1188 | with tf.variable_scope(idx_scope['layer'], reuse=tf.AUTO_REUSE): 1189 | layer_input = prev_output 1190 | 1191 | with tf.variable_scope(idx_scope['attention'], reuse=tf.AUTO_REUSE): 1192 | attention_heads = [] 1193 | 1194 | with tf.variable_scope("output", reuse=tf.AUTO_REUSE): 1195 | layer_input_pre = layer_norm(layer_input) 1196 | 1197 | with tf.variable_scope("self"): 1198 | attention_head = attention_layer( 1199 | from_tensor=layer_input_pre, 1200 | to_tensor=layer_input_pre, 1201 | attention_mask=attention_mask, 1202 | num_attention_heads=num_attention_heads, 1203 | size_per_head=attention_head_size, 1204 | attention_probs_dropout_prob=attention_probs_dropout_prob, 1205 | initializer_range=initializer_range, 1206 | do_return_2d_tensor=True, 1207 | batch_size=batch_size, 1208 | from_seq_length=seq_length, 1209 | to_seq_length=seq_length) 1210 | attention_heads.append(attention_head) 1211 | 1212 | attention_output = None 1213 | if len(attention_heads) == 1: 1214 | attention_output = attention_heads[0] 1215 | else: 1216 | # In the case where we have other sequences, we just concatenate 1217 | # them to the self-attention head before the projection. 1218 | attention_output = tf.concat(attention_heads, axis=-1) 1219 | 1220 | # Run a linear projection of `hidden_size` then add a residual 1221 | # with `layer_input`. 1222 | with tf.variable_scope("output", reuse=tf.AUTO_REUSE): 1223 | attention_output = tf.layers.dense( 1224 | attention_output, 1225 | hidden_size, 1226 | kernel_initializer=create_initializer(initializer_range)) 1227 | attention_output = dropout(attention_output, hidden_dropout_prob) 1228 | 1229 | # attention_output = layer_norm(attention_output + layer_input) 1230 | attention_output = attention_output + layer_input 1231 | 1232 | with tf.variable_scope(idx_scope['output'], reuse=tf.AUTO_REUSE): 1233 | attention_output_pre = layer_norm(attention_output) 1234 | 1235 | # The activation is only applied to the "intermediate" hidden layer. 
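# Pre-LN block: `attention_output_pre` was layer-normalized above, and the
# residual additions below are left un-normalized, unlike the post-LN version.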
1236 | with tf.variable_scope(idx_scope['intermediate'], reuse=tf.AUTO_REUSE): 1237 | intermediate_output = tf.layers.dense( 1238 | attention_output_pre, 1239 | intermediate_size, 1240 | activation=intermediate_act_fn, 1241 | kernel_initializer=create_initializer(initializer_range)) 1242 | 1243 | # Down-project back to `hidden_size` then add the residual. 1244 | with tf.variable_scope(idx_scope['output'], reuse=tf.AUTO_REUSE): 1245 | layer_output = tf.layers.dense( 1246 | intermediate_output, 1247 | hidden_size, 1248 | kernel_initializer=create_initializer(initializer_range)) 1249 | layer_output = dropout(layer_output, hidden_dropout_prob) 1250 | 1251 | # layer_output = layer_norm(layer_output + attention_output) 1252 | layer_output = layer_output + attention_output 1253 | prev_output = layer_output 1254 | all_layer_outputs.append(layer_output) 1255 | 1256 | if do_return_all_layers: 1257 | final_outputs = [] 1258 | for layer_output in all_layer_outputs: 1259 | final_output = bert_utils.reshape_from_matrix(layer_output, input_shape) 1260 | final_outputs.append(final_output) 1261 | return final_outputs 1262 | else: 1263 | final_output = bert_utils.reshape_from_matrix(prev_output, input_shape) 1264 | return final_output 1265 | --------------------------------------------------------------------------------
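For orientation, here is a minimal usage sketch of `modeling.py` as defined above. It is an assumption-laden illustration, not part of the repository: the config filename, batch size, and sequence length are placeholders, and it assumes a TensorFlow 1.x environment and a JSON config that supplies `embedding_size` (and optionally `ln_type`), since `BertConfig.__init__` does not define those fields.

```python
# Hypothetical sketch: builds the KalBert/ALBERT graph from a JSON config.
import tensorflow as tf
import modeling

# "albert_config.json", 8, and 128 are placeholder values for this example.
config = modeling.BertConfig.from_json_file("albert_config.json")

input_ids = tf.placeholder(tf.int32, shape=[8, 128], name="input_ids")
input_mask = tf.placeholder(tf.int32, shape=[8, 128], name="input_mask")
segment_ids = tf.placeholder(tf.int32, shape=[8, 128], name="segment_ids")

model = modeling.BertModel(
    config=config,
    is_training=False,  # disables both dropout probabilities
    input_ids=input_ids,
    input_mask=input_mask,
    token_type_ids=segment_ids)

pooled = model.get_pooled_output()      # [batch_size, hidden_size]
sequence = model.get_sequence_output()  # [batch_size, seq_length, hidden_size]
```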