├── LICENSE
├── README.md
├── bert
│   ├── README.md
│   ├── __init__.py
│   ├── create_pretraining_data.py
│   ├── extract_features.py
│   ├── general_utils.py
│   ├── modeling.py
│   ├── modeling_test.py
│   ├── optimization.py
│   ├── optimization_test.py
│   ├── run_classifier.py
│   ├── run_classifier_predict_online.py
│   ├── run_pretraining.py
│   ├── theseus_replacement_scheduler.py
│   ├── tokenization.py
│   └── tokenization_test.py
├── common_utils.py
├── configs
│   ├── __init__.py
│   ├── base_config.py
│   ├── bert_config.py
│   ├── bert_mrc_config.py
│   └── event_config.py
├── data
│   └── slot_pattern
│       ├── slot_descrip
│       ├── slot_descrip_old
│       ├── vocab_all_event_type_label_map.txt
│       └── vocab_all_slot_label_noBI_map.txt
├── data_processing
│   ├── __init__.py
│   ├── basic_prepare_data.py
│   ├── bert_mrc_prepare_data.py
│   ├── bert_prepare_data.py
│   ├── data_utils.py
│   ├── event_prepare_data.py
│   ├── mrc_query_map.py
│   └── tokenize.py
├── event_predict.py
├── gen_kfold_data.py
├── models
│   ├── __init__.py
│   ├── bert_event_type_classification.py
│   ├── bert_mrc.py
│   ├── event_verify_av.py
│   ├── layers
│   │   └── __init__.py
│   ├── tf_metrics.py
│   └── utils.py
├── optimization.py
├── requirements.txt
├── run_event.py
├── run_event_classification.sh
├── run_event_role.sh
├── run_retro_eav.sh
├── run_retro_rolemrc.sh
└── train_helper.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 qiufengyuyi
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Baidu AI Studio 2020 Event Extraction Track
2 | 
3 | ------
4 | 
5 | ## Update on 2020.07.09: gen_kfold_data.py has been refactored. It now generates the data for the event-type classification task (index_type_fold_data_{}) as well as the data for the role-extraction stage (verify_neg_fold_data_{}). Make sure that all file paths match the ones you set in your config. Note that neg_fold_data_{} is the older storage format; it is identical to verify_neg_fold_data_{} and can be used in the same way.
6 | Dependencies: mainly tensorflow 1.12.0, plus bojone's bert4keras, see https://github.com/bojone/bert4keras
7 | 
8 | The remaining dependencies are listed in requirements.txt.
9 | 
10 | This project currently focuses on solving event extraction with a machine reading comprehension (MRC) approach, split into two stages:
11 | 
12 | 1. Event type extraction
13 | 
14 | 2. Event argument (role) extraction, done as an MRC task.
15 | 
16 | See the accompanying Zhihu article for details.
17 | 
18 | The bulk of the project comes from my other repo, which uses MRC for named entity recognition; see https://github.com/qiufengyuyi/sequence_tagging
19 | 
20 | The final Retro-Reader approach, using roberta-large-wwm, scores 0.856 on test1.json.
21 | 
22 | For roberta-wwm-base, the score is 0.851.
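The sections below describe each stage separately. A possible end-to-end order of the scripts is sketched here; it is illustrative only and assumes that the data paths and pre-trained model paths in `configs/` have already been set up for your environment (the exact invocation of gen_kfold_data.py may differ):

```shell
# 1. build the k-fold data for both stages (see gen_kfold_data.py; invocation shown is illustrative)
python gen_kfold_data.py

# 2. event type classification
bash run_event_classification.sh

# 3. baseline event argument (role) extraction with MRC
bash run_event_role.sh

# 4. Retro-Reader external answer verification (answerability) module
bash run_retro_eav.sh

# 5. Retro-Reader intensive-reading module
bash run_retro_rolemrc.sh
```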
23 | 
24 | ## Generating the k-fold training data:
25 | 
26 | The k-fold training data is generated separately for each of the two stages; see gen_kfold_data.py for details.
27 | 
28 | 
29 | ## Event type extraction:
30 | 
31 | ```shell
32 | bash run_event_classification.sh
33 | ```
34 | 
35 | ## Baseline event argument extraction:
36 | 
37 | ```shell
38 | bash run_event_role.sh
39 | ```
40 | 
41 | ## RetroReader-EAV answerability (is the question answerable) module:
42 | 
43 | ```shell
44 | bash run_retro_eav.sh
45 | ```
46 | 
47 | ## RetroReader intensive-reading module:
48 | 
49 | ```shell
50 | bash run_retro_rolemrc.sh
51 | ```
52 | 
--------------------------------------------------------------------------------
/bert/README.md:
--------------------------------------------------------------------------------
1 | [TOC]
2 | 
3 | # Use BERT as feature
4 | 1. How do I call BERT so that an input sentence is output as vectors?
5 | 2. If I add BERT as a bottom-layer feature in my own code, do I need as much code as the official run_classifier.py example?
6 | # Environment
7 | 
8 | ```python
9 | mac:
10 | tf==1.4.0
11 | python=2.7
12 | 
13 | windows:
14 | tf==1.12
15 | python=3.5
16 | ```
17 | 
18 | # Entry point
19 | 
20 | Call the pre-trained model to run prediction on a sentence.
21 | bert_as_feature.py
22 | Set data_root to the path of the model.
23 | Pre-trained model used: chinese_L-12_H-768_A-12
24 | Core calling code:
25 | ```python
26 | # graph
27 | input_ids = tf.placeholder(tf.int32, shape=[None, None], name='input_ids')
28 | input_mask = tf.placeholder(tf.int32, shape=[None, None], name='input_masks')
29 | segment_ids = tf.placeholder(tf.int32, shape=[None, None], name='segment_ids')
30 | 
31 | # initialize BERT
32 | model = modeling.BertModel(
33 |     config=bert_config,
34 |     is_training=False,
35 |     input_ids=input_ids,
36 |     input_mask=input_mask,
37 |     token_type_ids=segment_ids,
38 |     use_one_hot_embeddings=False)
39 | 
40 | # load the pre-trained BERT weights
41 | tvars = tf.trainable_variables()
42 | (assignment, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, init_check_point)
43 | tf.train.init_from_checkpoint(init_check_point, assignment)  # apply the assignment map
44 | # get the last layer and the second-to-last layer.
45 | encoder_last_layer = model.get_sequence_output()
46 | encoder_last2_layer = model.all_encoder_layers[-2]
47 | 
48 | with tf.Session() as sess:
49 |     sess.run(tf.global_variables_initializer())
50 | 
51 |     token = tokenization.CharTokenizer(vocab_file=bert_vocab_file)
52 |     query = u'Jack,请回答1988, UNwant\u00E9d,running'
53 |     split_tokens = token.tokenize(query)
54 |     word_ids = token.convert_tokens_to_ids(split_tokens)
55 |     word_mask = [1] * len(word_ids)
56 |     word_segment_ids = [0] * len(word_ids)
57 |     fd = {input_ids: [word_ids], input_mask: [word_mask], segment_ids: [word_segment_ids]}
58 |     last, last2 = sess.run([encoder_last_layer, encoder_last2_layer], feed_dict=fd)
59 |     print('last shape:{}, last2 shape: {}'.format(last.shape, last2.shape))
60 | ```
61 | 
62 | Full code: [bert_as_feature.py](https://github.com/InsaneLife/bert/blob/master/bert_as_feature.py)
63 | 
64 | Code repository: https://github.com/InsaneLife/bert
65 | 
66 | Chinese model download: **[`BERT-Base, Chinese`](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip)**: Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M parameters
67 | 
68 | # Final results
69 | 
70 | The last layer and the second-to-last layer:
71 | last shape:(1, 14, 768), last2 shape: (1, 14, 768)
72 | 
73 | ```
74 | # last value
75 | [[ 0.8200665 1.7532703 -0.3771637 ... -0.63692784 -0.17133102
76 |    0.01075665]
77 |  [ 0.79148203 -0.08384223 -0.51832616 ... 0.8080162 1.9931345
78 |    1.072408 ]
79 |  [-0.02546642 2.2759912 -0.6004753 ... -0.88577884 3.1459959
80 |   -0.03815675]
81 |  ...
82 |  [-0.15581022 1.154014 -0.96733016 ... -0.47922543 0.51068854
83 |    0.29749477]
84 |  [ 0.38253042 0.09779643 -0.39919692 ... 0.98277044 0.6780443
85 |   -0.52883977]
86 |  [ 0.20359193 -0.42314947 0.51891303 ... -0.23625426 0.666618
87 |    0.30184716]]
88 | ```
89 | 
90 | 
91 | 
92 | # Preprocessing
93 | 
94 | `tokenization.py` handles preprocessing of the input sentence and contains two main classes: `BasicTokenizer` and `FullTokenizer`.
95 | 
96 | `BasicTokenizer` splits the text character by character, recognizes English words, and keeps digits merged together, for example:
97 | 
98 | ```
99 | query: 'Jack,请回答1988, UNwant\u00E9d,running'
100 | token: ['jack', ',', '请', '回', '答', '1988', ',', 'unwanted', ',', 'running']
101 | ```
102 | 
103 | `FullTokenizer` additionally runs WordPiece (greedy longest-match) segmentation on English tokens, so English words are split into sub-words, e.g. running becomes run and ##ing; this mainly affects English text.
104 | 
105 | ```
106 | query: 'UNwant\u00E9d,running'
107 | token: ["un", "##want", "##ed", ",", "runn", "##ing"]
108 | ```
109 | 
110 | For Chinese data, especially NER, keeping digits and English words as whole tokens produces a lot of UNK, so they should be split apart. The desired result:
111 | 
112 | ```
113 | query: 'Jack,请回答1988'
114 | token: ['j', 'a', 'c', 'k', ',', '请', '回', '答', '1', '9', '8', '8']
115 | ```
116 | 
117 | The concrete change is as follows:
118 | 
119 | ```python
120 | class CharTokenizer(object):
121 |   """Runs end-to-end tokenization."""
122 |   def __init__(self, vocab_file, do_lower_case=True):
123 |     self.vocab = load_vocab(vocab_file)
124 |     self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
125 |     self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
126 | 
127 |   def tokenize(self, text):
128 |     split_tokens = []
129 |     for token in self.basic_tokenizer.tokenize(text):
130 |       for sub_token in token:
131 |         split_tokens.append(sub_token)
132 |     return split_tokens
133 | 
134 |   def convert_tokens_to_ids(self, tokens):
135 |     return convert_tokens_to_ids(self.vocab, tokens)
136 | ```
137 | 
138 | 
139 | 
140 | 
141 | 
142 | 
143 | 
--------------------------------------------------------------------------------
/bert/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | 
--------------------------------------------------------------------------------
/bert/create_pretraining_data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
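# (Added note: an illustrative invocation only; the paths below are placeholders.)
# The flags referenced here are the ones defined in this module:
#
#   python create_pretraining_data.py \
#     --input_file=./corpus/*.txt \
#     --output_file=./tf_examples.tfrecord \
#     --vocab_file=./chinese_L-12_H-768_A-12/vocab.txt \
#     --do_lower_case=True \
#     --max_seq_length=128 \
#     --max_predictions_per_seq=20 \
#     --masked_lm_prob=0.15 \
#     --random_seed=12345 \
#     --dupe_factor=10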
15 | """Create masked LM/next sentence masked_lm TF examples for BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import random 23 | 24 | from bert import tokenization 25 | import tensorflow as tf 26 | 27 | flags = tf.flags 28 | 29 | FLAGS = flags.FLAGS 30 | 31 | flags.DEFINE_string("input_file", None, 32 | "Input raw text file (or comma-separated list of files).") 33 | 34 | flags.DEFINE_string( 35 | "output_file", None, 36 | "Output TF example file (or comma-separated list of files).") 37 | 38 | flags.DEFINE_string("vocab_file", None, 39 | "The vocabulary file that the BERT model was trained on.") 40 | 41 | flags.DEFINE_bool( 42 | "do_lower_case", True, 43 | "Whether to lower case the input text. Should be True for uncased " 44 | "models and False for cased models.") 45 | 46 | flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.") 47 | 48 | flags.DEFINE_integer("max_predictions_per_seq", 20, 49 | "Maximum number of masked LM predictions per sequence.") 50 | 51 | flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.") 52 | 53 | flags.DEFINE_integer( 54 | "dupe_factor", 10, 55 | "Number of times to duplicate the input data (with different masks).") 56 | 57 | flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.") 58 | 59 | flags.DEFINE_float( 60 | "short_seq_prob", 0.1, 61 | "Probability of creating sequences which are shorter than the " 62 | "maximum length.") 63 | 64 | 65 | class TrainingInstance(object): 66 | """A single training instance (sentence pair).""" 67 | 68 | def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels, 69 | is_random_next): 70 | self.tokens = tokens 71 | self.segment_ids = segment_ids 72 | self.is_random_next = is_random_next 73 | self.masked_lm_positions = masked_lm_positions 74 | self.masked_lm_labels = masked_lm_labels 75 | 76 | def __str__(self): 77 | s = "" 78 | s += "tokens: %s\n" % (" ".join( 79 | [tokenization.printable_text(x) for x in self.tokens])) 80 | s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids])) 81 | s += "is_random_next: %s\n" % self.is_random_next 82 | s += "masked_lm_positions: %s\n" % (" ".join( 83 | [str(x) for x in self.masked_lm_positions])) 84 | s += "masked_lm_labels: %s\n" % (" ".join( 85 | [tokenization.printable_text(x) for x in self.masked_lm_labels])) 86 | s += "\n" 87 | return s 88 | 89 | def __repr__(self): 90 | return self.__str__() 91 | 92 | 93 | def write_instance_to_example_files(instances, tokenizer, max_seq_length, 94 | max_predictions_per_seq, output_files): 95 | """Create TF example files from `TrainingInstance`s.""" 96 | writers = [] 97 | for output_file in output_files: 98 | writers.append(tf.python_io.TFRecordWriter(output_file)) 99 | 100 | writer_index = 0 101 | 102 | total_written = 0 103 | for (inst_index, instance) in enumerate(instances): 104 | input_ids = tokenizer.convert_tokens_to_ids(instance.tokens) 105 | input_mask = [1] * len(input_ids) 106 | segment_ids = list(instance.segment_ids) 107 | assert len(input_ids) <= max_seq_length 108 | 109 | while len(input_ids) < max_seq_length: 110 | input_ids.append(0) 111 | input_mask.append(0) 112 | segment_ids.append(0) 113 | 114 | assert len(input_ids) == max_seq_length 115 | assert len(input_mask) == max_seq_length 116 | assert len(segment_ids) == max_seq_length 117 | 118 | masked_lm_positions = list(instance.masked_lm_positions) 119 | masked_lm_ids = 
tokenizer.convert_tokens_to_ids(instance.masked_lm_labels) 120 | masked_lm_weights = [1.0] * len(masked_lm_ids) 121 | 122 | while len(masked_lm_positions) < max_predictions_per_seq: 123 | masked_lm_positions.append(0) 124 | masked_lm_ids.append(0) 125 | masked_lm_weights.append(0.0) 126 | 127 | next_sentence_label = 1 if instance.is_random_next else 0 128 | 129 | features = collections.OrderedDict() 130 | features["input_ids"] = create_int_feature(input_ids) 131 | features["input_mask"] = create_int_feature(input_mask) 132 | features["segment_ids"] = create_int_feature(segment_ids) 133 | features["masked_lm_positions"] = create_int_feature(masked_lm_positions) 134 | features["masked_lm_ids"] = create_int_feature(masked_lm_ids) 135 | features["masked_lm_weights"] = create_float_feature(masked_lm_weights) 136 | features["next_sentence_labels"] = create_int_feature([next_sentence_label]) 137 | 138 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 139 | 140 | writers[writer_index].write(tf_example.SerializeToString()) 141 | writer_index = (writer_index + 1) % len(writers) 142 | 143 | total_written += 1 144 | 145 | if inst_index < 20: 146 | tf.logging.info("*** Example ***") 147 | tf.logging.info("tokens: %s" % " ".join( 148 | [tokenization.printable_text(x) for x in instance.tokens])) 149 | 150 | for feature_name in features.keys(): 151 | feature = features[feature_name] 152 | values = [] 153 | if feature.int64_list.value: 154 | values = feature.int64_list.value 155 | elif feature.float_list.value: 156 | values = feature.float_list.value 157 | tf.logging.info( 158 | "%s: %s" % (feature_name, " ".join([str(x) for x in values]))) 159 | 160 | for writer in writers: 161 | writer.close() 162 | 163 | tf.logging.info("Wrote %d total instances", total_written) 164 | 165 | 166 | def create_int_feature(values): 167 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 168 | return feature 169 | 170 | 171 | def create_float_feature(values): 172 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) 173 | return feature 174 | 175 | 176 | def create_training_instances(input_files, tokenizer, max_seq_length, 177 | dupe_factor, short_seq_prob, masked_lm_prob, 178 | max_predictions_per_seq, rng): 179 | """Create `TrainingInstance`s from raw text.""" 180 | all_documents = [[]] 181 | 182 | # Input file format: 183 | # (1) One sentence per line. These should ideally be actual sentences, not 184 | # entire paragraphs or arbitrary spans of text. (Because we use the 185 | # sentence boundaries for the "next sentence prediction" task). 186 | # (2) Blank lines between documents. Document boundaries are needed so 187 | # that the "next sentence prediction" task doesn't span between documents. 
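  # (Added note, illustrative only.) A conforming input file therefore looks like:
  #
  #   The first sentence of document one.
  #   A second sentence of document one.
  #
  #   The first sentence of document two.
  #
  # i.e. one sentence per line, with a single blank line separating documents.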
188 | for input_file in input_files: 189 | with tf.gfile.GFile(input_file, "r") as reader: 190 | while True: 191 | line = tokenization.convert_to_unicode(reader.readline()) 192 | if not line: 193 | break 194 | line = line.strip() 195 | 196 | # Empty lines are used as document delimiters 197 | if not line: 198 | all_documents.append([]) 199 | tokens = tokenizer.tokenize(line) 200 | if tokens: 201 | all_documents[-1].append(tokens) 202 | 203 | # Remove empty documents 204 | all_documents = [x for x in all_documents if x] 205 | rng.shuffle(all_documents) 206 | 207 | vocab_words = list(tokenizer.vocab.keys()) 208 | instances = [] 209 | for _ in range(dupe_factor): 210 | for document_index in range(len(all_documents)): 211 | instances.extend( 212 | create_instances_from_document( 213 | all_documents, document_index, max_seq_length, short_seq_prob, 214 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng)) 215 | 216 | rng.shuffle(instances) 217 | return instances 218 | 219 | 220 | def create_instances_from_document( 221 | all_documents, document_index, max_seq_length, short_seq_prob, 222 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng): 223 | """Creates `TrainingInstance`s for a single document.""" 224 | document = all_documents[document_index] 225 | 226 | # Account for [CLS], [SEP], [SEP] 227 | max_num_tokens = max_seq_length - 3 228 | 229 | # We *usually* want to fill up the entire sequence since we are padding 230 | # to `max_seq_length` anyways, so short sequences are generally wasted 231 | # computation. However, we *sometimes* 232 | # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter 233 | # sequences to minimize the mismatch between pre-training and fine-tuning. 234 | # The `target_seq_length` is just a rough target however, whereas 235 | # `max_seq_length` is a hard limit. 236 | target_seq_length = max_num_tokens 237 | if rng.random() < short_seq_prob: 238 | target_seq_length = rng.randint(2, max_num_tokens) 239 | 240 | # We DON'T just concatenate all of the tokens from a document into a long 241 | # sequence and choose an arbitrary split point because this would make the 242 | # next sentence prediction task too easy. Instead, we split the input into 243 | # segments "A" and "B" based on the actual "sentences" provided by the user 244 | # input. 245 | instances = [] 246 | current_chunk = [] 247 | current_length = 0 248 | i = 0 249 | while i < len(document): 250 | segment = document[i] 251 | current_chunk.append(segment) 252 | current_length += len(segment) 253 | if i == len(document) - 1 or current_length >= target_seq_length: 254 | if current_chunk: 255 | # `a_end` is how many segments from `current_chunk` go into the `A` 256 | # (first) sentence. 257 | a_end = 1 258 | if len(current_chunk) >= 2: 259 | a_end = rng.randint(1, len(current_chunk) - 1) 260 | 261 | tokens_a = [] 262 | for j in range(a_end): 263 | tokens_a.extend(current_chunk[j]) 264 | 265 | tokens_b = [] 266 | # Random next 267 | is_random_next = False 268 | if len(current_chunk) == 1 or rng.random() < 0.5: 269 | is_random_next = True 270 | target_b_length = target_seq_length - len(tokens_a) 271 | 272 | # This should rarely go for more than one iteration for large 273 | # corpora. However, just to be careful, we try to make sure that 274 | # the random document is not the same as the document 275 | # we're processing. 
276 | for _ in range(10): 277 | random_document_index = rng.randint(0, len(all_documents) - 1) 278 | if random_document_index != document_index: 279 | break 280 | 281 | random_document = all_documents[random_document_index] 282 | random_start = rng.randint(0, len(random_document) - 1) 283 | for j in range(random_start, len(random_document)): 284 | tokens_b.extend(random_document[j]) 285 | if len(tokens_b) >= target_b_length: 286 | break 287 | # We didn't actually use these segments so we "put them back" so 288 | # they don't go to waste. 289 | num_unused_segments = len(current_chunk) - a_end 290 | i -= num_unused_segments 291 | # Actual next 292 | else: 293 | is_random_next = False 294 | for j in range(a_end, len(current_chunk)): 295 | tokens_b.extend(current_chunk[j]) 296 | truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng) 297 | 298 | assert len(tokens_a) >= 1 299 | assert len(tokens_b) >= 1 300 | 301 | tokens = [] 302 | segment_ids = [] 303 | tokens.append("[CLS]") 304 | segment_ids.append(0) 305 | for token in tokens_a: 306 | tokens.append(token) 307 | segment_ids.append(0) 308 | 309 | tokens.append("[SEP]") 310 | segment_ids.append(0) 311 | 312 | for token in tokens_b: 313 | tokens.append(token) 314 | segment_ids.append(1) 315 | tokens.append("[SEP]") 316 | segment_ids.append(1) 317 | 318 | (tokens, masked_lm_positions, 319 | masked_lm_labels) = create_masked_lm_predictions( 320 | tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng) 321 | instance = TrainingInstance( 322 | tokens=tokens, 323 | segment_ids=segment_ids, 324 | is_random_next=is_random_next, 325 | masked_lm_positions=masked_lm_positions, 326 | masked_lm_labels=masked_lm_labels) 327 | instances.append(instance) 328 | current_chunk = [] 329 | current_length = 0 330 | i += 1 331 | 332 | return instances 333 | 334 | 335 | def create_masked_lm_predictions(tokens, masked_lm_prob, 336 | max_predictions_per_seq, vocab_words, rng): 337 | """Creates the predictions for the masked LM objective.""" 338 | 339 | cand_indexes = [] 340 | for (i, token) in enumerate(tokens): 341 | if token == "[CLS]" or token == "[SEP]": 342 | continue 343 | cand_indexes.append(i) 344 | 345 | rng.shuffle(cand_indexes) 346 | 347 | output_tokens = list(tokens) 348 | 349 | masked_lm = collections.namedtuple("masked_lm", ["index", "label"]) # pylint: disable=invalid-name 350 | 351 | num_to_predict = min(max_predictions_per_seq, 352 | max(1, int(round(len(tokens) * masked_lm_prob)))) 353 | 354 | masked_lms = [] 355 | covered_indexes = set() 356 | for index in cand_indexes: 357 | if len(masked_lms) >= num_to_predict: 358 | break 359 | if index in covered_indexes: 360 | continue 361 | covered_indexes.add(index) 362 | 363 | masked_token = None 364 | # 80% of the time, replace with [MASK] 365 | if rng.random() < 0.8: 366 | masked_token = "[MASK]" 367 | else: 368 | # 10% of the time, keep original 369 | if rng.random() < 0.5: 370 | masked_token = tokens[index] 371 | # 10% of the time, replace with random word 372 | else: 373 | masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)] 374 | 375 | output_tokens[index] = masked_token 376 | 377 | masked_lms.append(masked_lm(index=index, label=tokens[index])) 378 | 379 | masked_lms = sorted(masked_lms, key=lambda x: x.index) 380 | 381 | masked_lm_positions = [] 382 | masked_lm_labels = [] 383 | for p in masked_lms: 384 | masked_lm_positions.append(p.index) 385 | masked_lm_labels.append(p.label) 386 | 387 | return (output_tokens, masked_lm_positions, masked_lm_labels) 388 | 389 | 390 | def 
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng): 391 | """Truncates a pair of sequences to a maximum sequence length.""" 392 | while True: 393 | total_length = len(tokens_a) + len(tokens_b) 394 | if total_length <= max_num_tokens: 395 | break 396 | 397 | trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b 398 | assert len(trunc_tokens) >= 1 399 | 400 | # We want to sometimes truncate from the front and sometimes from the 401 | # back to add more randomness and avoid biases. 402 | if rng.random() < 0.5: 403 | del trunc_tokens[0] 404 | else: 405 | trunc_tokens.pop() 406 | 407 | 408 | def main(_): 409 | tf.logging.set_verbosity(tf.logging.INFO) 410 | 411 | tokenizer = tokenization.FullTokenizer( 412 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 413 | 414 | input_files = [] 415 | for input_pattern in FLAGS.input_file.split(","): 416 | input_files.extend(tf.gfile.Glob(input_pattern)) 417 | 418 | tf.logging.info("*** Reading from input files ***") 419 | for input_file in input_files: 420 | tf.logging.info(" %s", input_file) 421 | 422 | rng = random.Random(FLAGS.random_seed) 423 | instances = create_training_instances( 424 | input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor, 425 | FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq, 426 | rng) 427 | 428 | output_files = FLAGS.output_file.split(",") 429 | tf.logging.info("*** Writing to output files ***") 430 | for output_file in output_files: 431 | tf.logging.info(" %s", output_file) 432 | 433 | write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length, 434 | FLAGS.max_predictions_per_seq, output_files) 435 | 436 | 437 | if __name__ == "__main__": 438 | flags.mark_flag_as_required("input_file") 439 | flags.mark_flag_as_required("output_file") 440 | flags.mark_flag_as_required("vocab_file") 441 | tf.app.run() 442 | -------------------------------------------------------------------------------- /bert/extract_features.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Extract pre-computed feature vectors from BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import codecs 22 | import collections 23 | import json 24 | import re 25 | 26 | from bert import modeling 27 | from bert import tokenization 28 | import tensorflow as tf 29 | 30 | flags = tf.flags 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | flags.DEFINE_string("input_file", None, "") 35 | 36 | flags.DEFINE_string("output_file", None, "") 37 | 38 | flags.DEFINE_string("layers", "-1,-2,-3,-4", "") 39 | 40 | flags.DEFINE_string( 41 | "bert_config_file", None, 42 | "The config json file corresponding to the pre-trained BERT model. 
" 43 | "This specifies the model architecture.") 44 | 45 | flags.DEFINE_integer( 46 | "max_seq_length", 128, 47 | "The maximum total input sequence length after WordPiece tokenization. " 48 | "Sequences longer than this will be truncated, and sequences shorter " 49 | "than this will be padded.") 50 | 51 | flags.DEFINE_string( 52 | "init_checkpoint", None, 53 | "Initial checkpoint (usually from a pre-trained BERT model).") 54 | 55 | flags.DEFINE_string("vocab_file", None, 56 | "The vocabulary file that the BERT model was trained on.") 57 | 58 | flags.DEFINE_bool( 59 | "do_lower_case", True, 60 | "Whether to lower case the input text. Should be True for uncased " 61 | "models and False for cased models.") 62 | 63 | flags.DEFINE_integer("batch_size", 32, "Batch size for predictions.") 64 | 65 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 66 | 67 | flags.DEFINE_string("master", None, 68 | "If using a TPU, the address of the master.") 69 | 70 | flags.DEFINE_integer( 71 | "num_tpu_cores", 8, 72 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 73 | 74 | flags.DEFINE_bool( 75 | "use_one_hot_embeddings", False, 76 | "If True, tf.one_hot will be used for embedding lookups, otherwise " 77 | "tf.nn.embedding_lookup will be used. On TPUs, this should be True " 78 | "since it is much faster.") 79 | 80 | 81 | class InputExample(object): 82 | 83 | def __init__(self, unique_id, text_a, text_b): 84 | self.unique_id = unique_id 85 | self.text_a = text_a 86 | self.text_b = text_b 87 | 88 | 89 | class InputFeatures(object): 90 | """A single set of features of data.""" 91 | 92 | def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids): 93 | self.unique_id = unique_id 94 | self.tokens = tokens 95 | self.input_ids = input_ids 96 | self.input_mask = input_mask 97 | self.input_type_ids = input_type_ids 98 | 99 | 100 | def input_fn_builder(features, seq_length): 101 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 102 | 103 | all_unique_ids = [] 104 | all_input_ids = [] 105 | all_input_mask = [] 106 | all_input_type_ids = [] 107 | 108 | for feature in features: 109 | all_unique_ids.append(feature.unique_id) 110 | all_input_ids.append(feature.input_ids) 111 | all_input_mask.append(feature.input_mask) 112 | all_input_type_ids.append(feature.input_type_ids) 113 | 114 | def input_fn(params): 115 | """The actual input function.""" 116 | batch_size = params["batch_size"] 117 | 118 | num_examples = len(features) 119 | 120 | # This is for demo purposes and does NOT scale to large data sets. We do 121 | # not use Dataset.from_generator() because that uses tf.py_func which is 122 | # not TPU compatible. The right way to load data is with TFRecordReader. 
123 | d = tf.data.Dataset.from_tensor_slices({ 124 | "unique_ids": 125 | tf.constant(all_unique_ids, shape=[num_examples], dtype=tf.int32), 126 | "input_ids": 127 | tf.constant( 128 | all_input_ids, shape=[num_examples, seq_length], 129 | dtype=tf.int32), 130 | "input_mask": 131 | tf.constant( 132 | all_input_mask, 133 | shape=[num_examples, seq_length], 134 | dtype=tf.int32), 135 | "input_type_ids": 136 | tf.constant( 137 | all_input_type_ids, 138 | shape=[num_examples, seq_length], 139 | dtype=tf.int32), 140 | }) 141 | 142 | d = d.batch(batch_size=batch_size, drop_remainder=False) 143 | return d 144 | 145 | return input_fn 146 | 147 | 148 | def model_fn_builder(bert_config, init_checkpoint, layer_indexes, use_tpu, 149 | use_one_hot_embeddings): 150 | """Returns `model_fn` closure for TPUEstimator.""" 151 | 152 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 153 | """The `model_fn` for TPUEstimator.""" 154 | 155 | unique_ids = features["unique_ids"] 156 | input_ids = features["input_ids"] 157 | input_mask = features["input_mask"] 158 | input_type_ids = features["input_type_ids"] 159 | 160 | model = modeling.BertModel( 161 | config=bert_config, 162 | is_training=False, 163 | input_ids=input_ids, 164 | input_mask=input_mask, 165 | token_type_ids=input_type_ids, 166 | use_one_hot_embeddings=use_one_hot_embeddings) 167 | 168 | if mode != tf.estimator.ModeKeys.PREDICT: 169 | raise ValueError("Only PREDICT modes are supported: %s" % (mode)) 170 | 171 | tvars = tf.trainable_variables() 172 | scaffold_fn = None 173 | (assignment_map, _) = modeling.get_assignment_map_from_checkpoint( 174 | tvars, init_checkpoint) 175 | if use_tpu: 176 | 177 | def tpu_scaffold(): 178 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 179 | return tf.train.Scaffold() 180 | 181 | scaffold_fn = tpu_scaffold 182 | else: 183 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 184 | 185 | all_layers = model.get_all_encoder_layers() 186 | 187 | predictions = { 188 | "unique_id": unique_ids, 189 | } 190 | 191 | for (i, layer_index) in enumerate(layer_indexes): 192 | predictions["layer_output_%d" % i] = all_layers[layer_index] 193 | 194 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 195 | mode=mode, predictions=predictions, scaffold_fn=scaffold_fn) 196 | return output_spec 197 | 198 | return model_fn 199 | 200 | 201 | def convert_examples_to_features(examples, seq_length, tokenizer): 202 | """Loads a data file into a list of `InputBatch`s.""" 203 | 204 | features = [] 205 | for (ex_index, example) in enumerate(examples): 206 | tokens_a = tokenizer.tokenize(example.text_a) 207 | 208 | tokens_b = None 209 | if example.text_b: 210 | tokens_b = tokenizer.tokenize(example.text_b) 211 | 212 | if tokens_b: 213 | # Modifies `tokens_a` and `tokens_b` in place so that the total 214 | # length is less than the specified length. 215 | # Account for [CLS], [SEP], [SEP] with "- 3" 216 | _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3) 217 | else: 218 | # Account for [CLS] and [SEP] with "- 2" 219 | if len(tokens_a) > seq_length - 2: 220 | tokens_a = tokens_a[0:(seq_length - 2)] 221 | 222 | # The convention in BERT is: 223 | # (a) For sequence pairs: 224 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 225 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 226 | # (b) For single sequences: 227 | # tokens: [CLS] the dog is hairy . 
[SEP] 228 | # type_ids: 0 0 0 0 0 0 0 229 | # 230 | # Where "type_ids" are used to indicate whether this is the first 231 | # sequence or the second sequence. The embedding vectors for `type=0` and 232 | # `type=1` were learned during pre-training and are added to the wordpiece 233 | # embedding vector (and position vector). This is not *strictly* necessary 234 | # since the [SEP] token unambiguously separates the sequences, but it makes 235 | # it easier for the model to learn the concept of sequences. 236 | # 237 | # For classification tasks, the first vector (corresponding to [CLS]) is 238 | # used as as the "sentence vector". Note that this only makes sense because 239 | # the entire model is fine-tuned. 240 | tokens = [] 241 | input_type_ids = [] 242 | tokens.append("[CLS]") 243 | input_type_ids.append(0) 244 | for token in tokens_a: 245 | tokens.append(token) 246 | input_type_ids.append(0) 247 | tokens.append("[SEP]") 248 | input_type_ids.append(0) 249 | 250 | if tokens_b: 251 | for token in tokens_b: 252 | tokens.append(token) 253 | input_type_ids.append(1) 254 | tokens.append("[SEP]") 255 | input_type_ids.append(1) 256 | 257 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 258 | 259 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 260 | # tokens are attended to. 261 | input_mask = [1] * len(input_ids) 262 | 263 | # Zero-pad up to the sequence length. 264 | while len(input_ids) < seq_length: 265 | input_ids.append(0) 266 | input_mask.append(0) 267 | input_type_ids.append(0) 268 | 269 | assert len(input_ids) == seq_length 270 | assert len(input_mask) == seq_length 271 | assert len(input_type_ids) == seq_length 272 | 273 | if ex_index < 5: 274 | tf.logging.info("*** Example ***") 275 | tf.logging.info("unique_id: %s" % (example.unique_id)) 276 | tf.logging.info("tokens: %s" % " ".join([str(x) for x in tokens])) 277 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 278 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 279 | tf.logging.info( 280 | "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids])) 281 | 282 | features.append( 283 | InputFeatures( 284 | unique_id=example.unique_id, 285 | tokens=tokens, 286 | input_ids=input_ids, 287 | input_mask=input_mask, 288 | input_type_ids=input_type_ids)) 289 | return features 290 | 291 | 292 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 293 | """Truncates a sequence pair in place to the maximum length.""" 294 | 295 | # This is a simple heuristic which will always truncate the longer sequence 296 | # one token at a time. This makes more sense than truncating an equal percent 297 | # of tokens from each, since if one sequence is very short then each token 298 | # that's truncated likely contains more information than a longer sequence. 
299 | while True: 300 | total_length = len(tokens_a) + len(tokens_b) 301 | if total_length <= max_length: 302 | break 303 | if len(tokens_a) > len(tokens_b): 304 | tokens_a.pop() 305 | else: 306 | tokens_b.pop() 307 | 308 | 309 | def read_examples(input_file): 310 | """Read a list of `InputExample`s from an input file.""" 311 | examples = [] 312 | unique_id = 0 313 | with tf.gfile.GFile(input_file, "r") as reader: 314 | while True: 315 | line = tokenization.convert_to_unicode(reader.readline()) 316 | if not line: 317 | break 318 | line = line.strip() 319 | text_a = None 320 | text_b = None 321 | m = re.match(r"^(.*) \|\|\| (.*)$", line) 322 | if m is None: 323 | text_a = line 324 | else: 325 | text_a = m.group(1) 326 | text_b = m.group(2) 327 | examples.append( 328 | InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b)) 329 | unique_id += 1 330 | return examples 331 | 332 | 333 | def main(_): 334 | tf.logging.set_verbosity(tf.logging.INFO) 335 | 336 | layer_indexes = [int(x) for x in FLAGS.layers.split(",")] 337 | 338 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 339 | 340 | tokenizer = tokenization.FullTokenizer( 341 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 342 | 343 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 344 | run_config = tf.contrib.tpu.RunConfig( 345 | master=FLAGS.master, 346 | tpu_config=tf.contrib.tpu.TPUConfig( 347 | num_shards=FLAGS.num_tpu_cores, 348 | per_host_input_for_training=is_per_host)) 349 | 350 | examples = read_examples(FLAGS.input_file) 351 | 352 | features = convert_examples_to_features( 353 | examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer) 354 | 355 | unique_id_to_feature = {} 356 | for feature in features: 357 | unique_id_to_feature[feature.unique_id] = feature 358 | 359 | model_fn = model_fn_builder( 360 | bert_config=bert_config, 361 | init_checkpoint=FLAGS.init_checkpoint, 362 | layer_indexes=layer_indexes, 363 | use_tpu=FLAGS.use_tpu, 364 | use_one_hot_embeddings=FLAGS.use_one_hot_embeddings) 365 | 366 | # If TPU is not available, this will fall back to normal Estimator on CPU 367 | # or GPU. 
368 | estimator = tf.contrib.tpu.TPUEstimator( 369 | use_tpu=FLAGS.use_tpu, 370 | model_fn=model_fn, 371 | config=run_config, 372 | predict_batch_size=FLAGS.batch_size) 373 | 374 | input_fn = input_fn_builder( 375 | features=features, seq_length=FLAGS.max_seq_length) 376 | 377 | with codecs.getwriter("utf-8")(tf.gfile.Open(FLAGS.output_file, 378 | "w")) as writer: 379 | for result in estimator.predict(input_fn, yield_single_examples=True): 380 | unique_id = int(result["unique_id"]) 381 | feature = unique_id_to_feature[unique_id] 382 | output_json = collections.OrderedDict() 383 | output_json["linex_index"] = unique_id 384 | all_features = [] 385 | for (i, token) in enumerate(feature.tokens): 386 | all_layers = [] 387 | for (j, layer_index) in enumerate(layer_indexes): 388 | layer_output = result["layer_output_%d" % j] 389 | layers = collections.OrderedDict() 390 | layers["index"] = layer_index 391 | layers["values"] = [ 392 | round(float(x), 6) for x in layer_output[i:(i + 1)].flat 393 | ] 394 | all_layers.append(layers) 395 | features = collections.OrderedDict() 396 | features["token"] = token 397 | features["layers"] = all_layers 398 | all_features.append(features) 399 | output_json["features"] = all_features 400 | writer.write(json.dumps(output_json) + "\n") 401 | 402 | 403 | if __name__ == "__main__": 404 | flags.mark_flag_as_required("input_file") 405 | flags.mark_flag_as_required("vocab_file") 406 | flags.mark_flag_as_required("bert_config_file") 407 | flags.mark_flag_as_required("init_checkpoint") 408 | flags.mark_flag_as_required("output_file") 409 | tf.app.run() 410 | -------------------------------------------------------------------------------- /bert/general_utils.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | 4 | import logging 5 | import sys 6 | import time 7 | import numpy as np 8 | 9 | 10 | def get_logger(filename): 11 | logger = logging.getLogger('logger') 12 | logger.setLevel(logging.DEBUG) 13 | logging.basicConfig(format='%(message)s', level=logging.DEBUG) 14 | handler = logging.FileHandler(filename) 15 | handler.setLevel(logging.DEBUG) 16 | handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s')) 17 | logging.getLogger().addHandler(handler) 18 | 19 | return logger 20 | 21 | 22 | class Progbar(object): 23 | """Progbar class copied from keras (https://github.com/fchollet/keras/) 24 | 25 | Displays a progress bar. 26 | Small edit : added strict arg to update 27 | # Arguments 28 | target: Total number of steps expected. 29 | interval: Minimum visual progress update interval (in seconds). 30 | """ 31 | 32 | def __init__(self, target, width=30, verbose=1): 33 | self.width = width 34 | self.target = target 35 | self.sum_values = {} 36 | self.unique_values = [] 37 | self.start = time.time() 38 | self.total_width = 0 39 | self.seen_so_far = 0 40 | self.verbose = verbose 41 | 42 | def update(self, current, values=[], exact=[], strict=[]): 43 | """ 44 | Updates the progress bar. 45 | # Arguments 46 | current: Index of current step. 47 | values: List of tuples (name, value_for_last_step). 48 | The progress bar will display averages for these values. 49 | exact: List of tuples (name, value_for_last_step). 50 | The progress bar will display these values directly. 
51 | """ 52 | 53 | for k, v in values: 54 | if k not in self.sum_values: 55 | self.sum_values[k] = [v * (current - self.seen_so_far), 56 | current - self.seen_so_far] 57 | self.unique_values.append(k) 58 | else: 59 | self.sum_values[k][0] += v * (current - self.seen_so_far) 60 | self.sum_values[k][1] += (current - self.seen_so_far) 61 | for k, v in exact: 62 | if k not in self.sum_values: 63 | self.unique_values.append(k) 64 | self.sum_values[k] = [v, 1] 65 | 66 | for k, v in strict: 67 | if k not in self.sum_values: 68 | self.unique_values.append(k) 69 | self.sum_values[k] = v 70 | 71 | self.seen_so_far = current 72 | 73 | now = time.time() 74 | if self.verbose == 1: 75 | prev_total_width = self.total_width 76 | sys.stdout.write("\b" * prev_total_width) 77 | sys.stdout.write("\r") 78 | 79 | numdigits = int(np.floor(np.log10(self.target))) + 1 80 | barstr = '%%%dd/%%%dd [' % (numdigits, numdigits) 81 | bar = barstr % (current, self.target) 82 | prog = float(current)/self.target 83 | prog_width = int(self.width*prog) 84 | if prog_width > 0: 85 | bar += ('='*(prog_width-1)) 86 | if current < self.target: 87 | bar += '>' 88 | else: 89 | bar += '=' 90 | bar += ('.'*(self.width-prog_width)) 91 | bar += ']' 92 | sys.stdout.write(bar) 93 | self.total_width = len(bar) 94 | 95 | if current: 96 | time_per_unit = (now - self.start) / current 97 | else: 98 | time_per_unit = 0 99 | eta = time_per_unit*(self.target - current) 100 | info = '' 101 | if current < self.target: 102 | info += ' - ETA: %ds' % eta 103 | else: 104 | info += ' - %ds' % (now - self.start) 105 | for k in self.unique_values: 106 | if type(self.sum_values[k]) is list: 107 | info += ' - %s: %.4f' % (k, 108 | self.sum_values[k][0] / max(1, self.sum_values[k][1])) 109 | else: 110 | info += ' - %s: %s' % (k, self.sum_values[k]) 111 | 112 | self.total_width += len(info) 113 | if prev_total_width > self.total_width: 114 | info += ((prev_total_width-self.total_width) * " ") 115 | 116 | sys.stdout.write(info) 117 | sys.stdout.flush() 118 | 119 | if current >= self.target: 120 | sys.stdout.write("\n") 121 | 122 | if self.verbose == 2: 123 | if current >= self.target: 124 | info = '%ds' % (now - self.start) 125 | for k in self.unique_values: 126 | info += ' - %s: %.4f' % (k, 127 | self.sum_values[k][0] / max(1, self.sum_values[k][1])) 128 | sys.stdout.write(info + "\n") 129 | 130 | def add(self, n, values=[]): 131 | self.update(self.seen_so_far+n, values) -------------------------------------------------------------------------------- /bert/modeling_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import collections 20 | import json 21 | import random 22 | import re 23 | 24 | from bert import modeling 25 | import six 26 | import tensorflow as tf 27 | 28 | 29 | class BertModelTest(tf.test.TestCase): 30 | 31 | class BertModelTester(object): 32 | 33 | def __init__(self, 34 | parent, 35 | batch_size=13, 36 | seq_length=7, 37 | is_training=True, 38 | use_input_mask=True, 39 | use_token_type_ids=True, 40 | vocab_size=99, 41 | hidden_size=32, 42 | num_hidden_layers=5, 43 | num_attention_heads=4, 44 | intermediate_size=37, 45 | hidden_act="gelu", 46 | hidden_dropout_prob=0.1, 47 | attention_probs_dropout_prob=0.1, 48 | max_position_embeddings=512, 49 | type_vocab_size=16, 50 | initializer_range=0.02, 51 | scope=None): 52 | self.parent = parent 53 | self.batch_size = batch_size 54 | self.seq_length = seq_length 55 | self.is_training = is_training 56 | self.use_input_mask = use_input_mask 57 | self.use_token_type_ids = use_token_type_ids 58 | self.vocab_size = vocab_size 59 | self.hidden_size = hidden_size 60 | self.num_hidden_layers = num_hidden_layers 61 | self.num_attention_heads = num_attention_heads 62 | self.intermediate_size = intermediate_size 63 | self.hidden_act = hidden_act 64 | self.hidden_dropout_prob = hidden_dropout_prob 65 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 66 | self.max_position_embeddings = max_position_embeddings 67 | self.type_vocab_size = type_vocab_size 68 | self.initializer_range = initializer_range 69 | self.scope = scope 70 | 71 | def create_model(self): 72 | input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], 73 | self.vocab_size) 74 | 75 | input_mask = None 76 | if self.use_input_mask: 77 | input_mask = BertModelTest.ids_tensor( 78 | [self.batch_size, self.seq_length], vocab_size=2) 79 | 80 | token_type_ids = None 81 | if self.use_token_type_ids: 82 | token_type_ids = BertModelTest.ids_tensor( 83 | [self.batch_size, self.seq_length], self.type_vocab_size) 84 | 85 | config = modeling.BertConfig( 86 | vocab_size=self.vocab_size, 87 | hidden_size=self.hidden_size, 88 | num_hidden_layers=self.num_hidden_layers, 89 | num_attention_heads=self.num_attention_heads, 90 | intermediate_size=self.intermediate_size, 91 | hidden_act=self.hidden_act, 92 | hidden_dropout_prob=self.hidden_dropout_prob, 93 | attention_probs_dropout_prob=self.attention_probs_dropout_prob, 94 | max_position_embeddings=self.max_position_embeddings, 95 | type_vocab_size=self.type_vocab_size, 96 | initializer_range=self.initializer_range) 97 | 98 | model = modeling.BertModel( 99 | config=config, 100 | is_training=self.is_training, 101 | input_ids=input_ids, 102 | input_mask=input_mask, 103 | token_type_ids=token_type_ids, 104 | scope=self.scope) 105 | 106 | outputs = { 107 | "embedding_output": model.get_embedding_output(), 108 | "sequence_output": model.get_sequence_output(), 109 | "pooled_output": model.get_pooled_output(), 110 | "all_encoder_layers": model.get_all_encoder_layers(), 111 | } 112 | return outputs 113 | 114 | def check_output(self, result): 115 | self.parent.assertAllEqual( 116 | result["embedding_output"].shape, 117 | [self.batch_size, self.seq_length, self.hidden_size]) 118 | 119 | self.parent.assertAllEqual( 120 | result["sequence_output"].shape, 121 | [self.batch_size, self.seq_length, self.hidden_size]) 122 | 123 | self.parent.assertAllEqual(result["pooled_output"].shape, 124 | [self.batch_size, 
self.hidden_size]) 125 | 126 | def test_default(self): 127 | self.run_tester(BertModelTest.BertModelTester(self)) 128 | 129 | def test_config_to_json_string(self): 130 | config = modeling.BertConfig(vocab_size=99, hidden_size=37) 131 | obj = json.loads(config.to_json_string()) 132 | self.assertEqual(obj["vocab_size"], 99) 133 | self.assertEqual(obj["hidden_size"], 37) 134 | 135 | def run_tester(self, tester): 136 | with self.test_session() as sess: 137 | ops = tester.create_model() 138 | init_op = tf.group(tf.global_variables_initializer(), 139 | tf.local_variables_initializer()) 140 | sess.run(init_op) 141 | output_result = sess.run(ops) 142 | tester.check_output(output_result) 143 | 144 | self.assert_all_tensors_reachable(sess, [init_op, ops]) 145 | 146 | @classmethod 147 | def ids_tensor(cls, shape, vocab_size, rng=None, name=None): 148 | """Creates a random int32 tensor of the shape within the vocab size.""" 149 | if rng is None: 150 | rng = random.Random() 151 | 152 | total_dims = 1 153 | for dim in shape: 154 | total_dims *= dim 155 | 156 | values = [] 157 | for _ in range(total_dims): 158 | values.append(rng.randint(0, vocab_size - 1)) 159 | 160 | return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name) 161 | 162 | def assert_all_tensors_reachable(self, sess, outputs): 163 | """Checks that all the tensors in the graph are reachable from outputs.""" 164 | graph = sess.graph 165 | 166 | ignore_strings = [ 167 | "^.*/dilation_rate$", 168 | "^.*/Tensordot/concat$", 169 | "^.*/Tensordot/concat/axis$", 170 | "^testing/.*$", 171 | ] 172 | 173 | ignore_regexes = [re.compile(x) for x in ignore_strings] 174 | 175 | unreachable = self.get_unreachable_ops(graph, outputs) 176 | filtered_unreachable = [] 177 | for x in unreachable: 178 | do_ignore = False 179 | for r in ignore_regexes: 180 | m = r.match(x.name) 181 | if m is not None: 182 | do_ignore = True 183 | if do_ignore: 184 | continue 185 | filtered_unreachable.append(x) 186 | unreachable = filtered_unreachable 187 | 188 | self.assertEqual( 189 | len(unreachable), 0, "The following ops are unreachable: %s" % 190 | (" ".join([x.name for x in unreachable]))) 191 | 192 | @classmethod 193 | def get_unreachable_ops(cls, graph, outputs): 194 | """Finds all of the tensors in graph that are unreachable from outputs.""" 195 | outputs = cls.flatten_recursive(outputs) 196 | output_to_op = collections.defaultdict(list) 197 | op_to_all = collections.defaultdict(list) 198 | assign_out_to_in = collections.defaultdict(list) 199 | 200 | for op in graph.get_operations(): 201 | for x in op.inputs: 202 | op_to_all[op.name].append(x.name) 203 | for y in op.outputs: 204 | output_to_op[y.name].append(op.name) 205 | op_to_all[op.name].append(y.name) 206 | if str(op.type) == "Assign": 207 | for y in op.outputs: 208 | for x in op.inputs: 209 | assign_out_to_in[y.name].append(x.name) 210 | 211 | assign_groups = collections.defaultdict(list) 212 | for out_name in assign_out_to_in.keys(): 213 | name_group = assign_out_to_in[out_name] 214 | for n1 in name_group: 215 | assign_groups[n1].append(out_name) 216 | for n2 in name_group: 217 | if n1 != n2: 218 | assign_groups[n1].append(n2) 219 | 220 | seen_tensors = {} 221 | stack = [x.name for x in outputs] 222 | while stack: 223 | name = stack.pop() 224 | if name in seen_tensors: 225 | continue 226 | seen_tensors[name] = True 227 | 228 | if name in output_to_op: 229 | for op_name in output_to_op[name]: 230 | if op_name in op_to_all: 231 | for input_name in op_to_all[op_name]: 232 | if input_name not in 
stack: 233 | stack.append(input_name) 234 | 235 | expanded_names = [] 236 | if name in assign_groups: 237 | for assign_name in assign_groups[name]: 238 | expanded_names.append(assign_name) 239 | 240 | for expanded_name in expanded_names: 241 | if expanded_name not in stack: 242 | stack.append(expanded_name) 243 | 244 | unreachable_ops = [] 245 | for op in graph.get_operations(): 246 | is_unreachable = False 247 | all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs] 248 | for name in all_names: 249 | if name not in seen_tensors: 250 | is_unreachable = True 251 | if is_unreachable: 252 | unreachable_ops.append(op) 253 | return unreachable_ops 254 | 255 | @classmethod 256 | def flatten_recursive(cls, item): 257 | """Flattens (potentially nested) a tuple/dictionary/list to a list.""" 258 | output = [] 259 | if isinstance(item, list): 260 | output.extend(item) 261 | elif isinstance(item, tuple): 262 | output.extend(list(item)) 263 | elif isinstance(item, dict): 264 | for (_, v) in six.iteritems(item): 265 | output.append(v) 266 | else: 267 | return [item] 268 | 269 | flat_output = [] 270 | for x in output: 271 | flat_output.extend(cls.flatten_recursive(x)) 272 | return flat_output 273 | 274 | 275 | if __name__ == "__main__": 276 | tf.test.main() 277 | -------------------------------------------------------------------------------- /bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 
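  # (Added note.) For example, with init_lr = 5e-5 and num_warmup_steps = 1000, the
  # learning rate at global_step = 100 is (100 / 1000) * 5e-5 = 5e-6; once warmup ends,
  # the polynomial (linear) decay defined above takes over.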
42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | new_global_step = global_step + 1 80 | train_op = tf.group(train_op, global_step.assign(new_global_step)) 81 | return train_op 82 | 83 | 84 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 85 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 86 | 87 | def __init__(self, 88 | learning_rate, 89 | weight_decay_rate=0.0, 90 | beta_1=0.9, 91 | beta_2=0.999, 92 | epsilon=1e-6, 93 | exclude_from_weight_decay=None, 94 | name="AdamWeightDecayOptimizer"): 95 | """Constructs a AdamWeightDecayOptimizer.""" 96 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 97 | 98 | self.learning_rate = learning_rate 99 | self.weight_decay_rate = weight_decay_rate 100 | self.beta_1 = beta_1 101 | self.beta_2 = beta_2 102 | self.epsilon = epsilon 103 | self.exclude_from_weight_decay = exclude_from_weight_decay 104 | 105 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 106 | """See base class.""" 107 | assignments = [] 108 | for (grad, param) in grads_and_vars: 109 | if grad is None or param is None: 110 | continue 111 | 112 | param_name = self._get_variable_name(param.name) 113 | 114 | m = tf.get_variable( 115 | name=param_name + "/adam_m", 116 | shape=param.shape.as_list(), 117 | dtype=tf.float32, 118 | trainable=False, 119 | initializer=tf.zeros_initializer()) 120 | v = tf.get_variable( 121 | name=param_name + "/adam_v", 122 | shape=param.shape.as_list(), 123 | dtype=tf.float32, 124 | trainable=False, 125 | initializer=tf.zeros_initializer()) 126 | 127 | # Standard Adam update. 128 | next_m = ( 129 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 130 | next_v = ( 131 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 132 | tf.square(grad))) 133 | 134 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 135 | 136 | # Just adding the square of the weights to the loss function is *not* 137 | # the correct way of using L2 regularization/weight decay with Adam, 138 | # since that will interact with the m and v parameters in strange ways. 
139 | # 140 | # Instead we want to decay the weights in a manner that doesn't interact 141 | # with the m/v parameters. This is equivalent to adding the square 142 | # of the weights to the loss with plain (non-momentum) SGD. 143 | if self._do_use_weight_decay(param_name): 144 | update += self.weight_decay_rate * param 145 | 146 | update_with_lr = self.learning_rate * update 147 | 148 | next_param = param - update_with_lr 149 | 150 | assignments.extend( 151 | [param.assign(next_param), 152 | m.assign(next_m), 153 | v.assign(next_v)]) 154 | return tf.group(*assignments, name=name) 155 | 156 | def _do_use_weight_decay(self, param_name): 157 | """Whether to use L2 weight decay for `param_name`.""" 158 | if not self.weight_decay_rate: 159 | return False 160 | if self.exclude_from_weight_decay: 161 | for r in self.exclude_from_weight_decay: 162 | if re.search(r, param_name) is not None: 163 | return False 164 | return True 165 | 166 | def _get_variable_name(self, param_name): 167 | """Get the variable name from the tensor name.""" 168 | m = re.match("^(.*):\\d+$", param_name) 169 | if m is not None: 170 | param_name = m.group(1) 171 | return param_name 172 | -------------------------------------------------------------------------------- /bert/optimization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
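The per-parameter arithmetic that `AdamWeightDecayOptimizer.apply_gradients` performs above can be written out in a few lines of plain Python. The sketch below only mirrors the same formulas on a single scalar parameter, without TensorFlow; the function name `adamw_step` and the sample values are invented for illustration and are not code from this repository:

```python
import math

def adamw_step(param, grad, m, v, lr, wd=0.01,
               beta_1=0.9, beta_2=0.999, eps=1e-6):
    """One decoupled-weight-decay Adam update, mirroring apply_gradients above.

    Note that, exactly like the TF class, there is no bias correction of m/v.
    """
    m = beta_1 * m + (1.0 - beta_1) * grad          # next_m
    v = beta_2 * v + (1.0 - beta_2) * grad * grad   # next_v
    update = m / (math.sqrt(v) + eps)               # Adam direction
    update += wd * param                            # decay applied outside m/v
    return param - lr * update, m, v

# e.g. one step from param=0.5 with grad=0.1 and lr=0.01:
# adamw_step(0.5, 0.1, 0.0, 0.0, 0.01) -> (~0.4683, 0.01, 1e-05)
```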
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | from bert import optimization 20 | import tensorflow as tf 21 | 22 | 23 | class OptimizationTest(tf.test.TestCase): 24 | 25 | def test_adam(self): 26 | with self.test_session() as sess: 27 | w = tf.get_variable( 28 | "w", 29 | shape=[3], 30 | initializer=tf.constant_initializer([0.1, -0.2, -0.1])) 31 | x = tf.constant([0.4, 0.2, -0.5]) 32 | loss = tf.reduce_mean(tf.square(x - w)) 33 | tvars = tf.trainable_variables() 34 | grads = tf.gradients(loss, tvars) 35 | global_step = tf.train.get_or_create_global_step() 36 | optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2) 37 | train_op = optimizer.apply_gradients(zip(grads, tvars), global_step) 38 | init_op = tf.group(tf.global_variables_initializer(), 39 | tf.local_variables_initializer()) 40 | sess.run(init_op) 41 | for _ in range(100): 42 | sess.run(train_op) 43 | w_np = sess.run(w) 44 | self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2) 45 | 46 | 47 | if __name__ == "__main__": 48 | tf.test.main() 49 | -------------------------------------------------------------------------------- /bert/theseus_replacement_scheduler.py: -------------------------------------------------------------------------------- 1 | from bert_of_theseus import BertEncoder 2 | 3 | 4 | class ConstantReplacementScheduler: 5 | def __init__(self, bert_encoder: BertEncoder, replacing_rate, replacing_steps=None): 6 | self.bert_encoder = bert_encoder 7 | self.replacing_rate = replacing_rate 8 | self.replacing_steps = replacing_steps 9 | self.step_counter = 0 10 | self.bert_encoder.set_replacing_rate(replacing_rate) 11 | 12 | def step(self): 13 | self.step_counter += 1 14 | if self.replacing_steps is None or self.replacing_rate == 1.0: 15 | return self.replacing_rate 16 | else: 17 | if self.step_counter >= self.replacing_steps: 18 | self.bert_encoder.set_replacing_rate(1.0) 19 | self.replacing_rate = 1.0 20 | return self.replacing_rate 21 | 22 | 23 | class LinearReplacementScheduler: 24 | def __init__(self, bert_encoder: BertEncoder, base_replacing_rate, k): 25 | self.bert_encoder = bert_encoder 26 | self.base_replacing_rate = base_replacing_rate 27 | self.step_counter = 0 28 | self.k = k 29 | self.bert_encoder.set_replacing_rate(base_replacing_rate) 30 | 31 | def step(self): 32 | self.step_counter += 1 33 | current_replacing_rate = min(self.k * self.step_counter + self.base_replacing_rate, 1.0) 34 | self.bert_encoder.set_replacing_rate(current_replacing_rate) 35 | return current_replacing_rate 36 | -------------------------------------------------------------------------------- /bert/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
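The Theseus replacement schedulers in bert/theseus_replacement_scheduler.py above only ever call `set_replacing_rate` on the encoder they are given, so their behaviour can be checked in isolation. The following is a minimal usage sketch, not code from this repository: the `DummyEncoder` stand-in and the numbers are assumptions made for illustration, and the real `BertEncoder` comes from the external bert_of_theseus package, which is not included in this listing.

```python
from bert.theseus_replacement_scheduler import LinearReplacementScheduler
# (importing that module requires the external bert_of_theseus package,
#  since it imports BertEncoder at module load time)

class DummyEncoder:
    """Stand-in exposing the single method the schedulers rely on."""
    def set_replacing_rate(self, rate):
        self.replacing_rate = rate

encoder = DummyEncoder()
scheduler = LinearReplacementScheduler(encoder, base_replacing_rate=0.3, k=0.1)
for step in range(1, 11):
    rate = scheduler.step()          # min(k * step + base, 1.0)
    # step 1 -> 0.4, step 2 -> 0.5, ..., step 7 onwards -> 1.0
assert encoder.replacing_rate == 1.0
```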
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import six 24 | import tensorflow as tf 25 | 26 | 27 | def convert_to_unicode(text): 28 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 29 | if six.PY3: 30 | if isinstance(text, str): 31 | return text 32 | elif isinstance(text, bytes): 33 | return text.decode("utf-8", "ignore") 34 | else: 35 | raise ValueError("Unsupported string type: %s" % (type(text))) 36 | elif six.PY2: 37 | if isinstance(text, str): 38 | return text.decode("utf-8", "ignore") 39 | elif isinstance(text, unicode): 40 | return text 41 | else: 42 | raise ValueError("Unsupported string type: %s" % (type(text))) 43 | else: 44 | raise ValueError("Not running on Python2 or Python 3?") 45 | 46 | 47 | def printable_text(text): 48 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 49 | 50 | # These functions want `str` for both Python2 and Python3, but in one case 51 | # it's a Unicode string and in the other it's a byte string. 52 | if six.PY3: 53 | if isinstance(text, str): 54 | return text 55 | elif isinstance(text, bytes): 56 | return text.decode("utf-8", "ignore") 57 | else: 58 | raise ValueError("Unsupported string type: %s" % (type(text))) 59 | elif six.PY2: 60 | if isinstance(text, str): 61 | return text 62 | elif isinstance(text, unicode): 63 | return text.encode("utf-8") 64 | else: 65 | raise ValueError("Unsupported string type: %s" % (type(text))) 66 | else: 67 | raise ValueError("Not running on Python2 or Python 3?") 68 | 69 | 70 | def load_vocab(vocab_file): 71 | """Loads a vocabulary file into a dictionary.""" 72 | vocab = collections.OrderedDict() 73 | index = 0 74 | with tf.gfile.GFile(vocab_file, "r") as reader: 75 | while True: 76 | token = convert_to_unicode(reader.readline()) 77 | if not token: 78 | break 79 | token = token.strip() 80 | vocab[token] = index 81 | index += 1 82 | return vocab 83 | 84 | 85 | def convert_tokens_to_ids(vocab, tokens, unk_token="[UNK]"): 86 | """Converts a sequence of tokens into ids using the vocab.""" 87 | ids = [] 88 | for token in tokens: 89 | if token in vocab: 90 | ids.append(vocab[token]) 91 | else: 92 | ids.append(vocab[unk_token]) 93 | return ids 94 | 95 | 96 | def whitespace_tokenize(text): 97 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 98 | text = text.strip() 99 | if not text: 100 | return [] 101 | tokens = text.split() 102 | return tokens 103 | 104 | 105 | class FullTokenizer(object): 106 | """Runs end-to-end tokenziation.""" 107 | 108 | def __init__(self, vocab_file, do_lower_case=True): 109 | self.vocab = load_vocab(vocab_file) 110 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 111 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 112 | 113 | def tokenize(self, text): 114 | split_tokens = [] 115 | for token in self.basic_tokenizer.tokenize(text): 116 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 117 | split_tokens.append(sub_token) 118 | 119 | return split_tokens 120 | 121 | def convert_tokens_to_ids(self, tokens): 122 | return convert_tokens_to_ids(self.vocab, tokens) 123 | 124 | 125 | class CharTokenizer(object): 126 | """Runs end-to-end tokenziation.""" 127 | 128 | def __init__(self, vocab_file, do_lower_case=True): 129 | self.vocab = load_vocab(vocab_file) 130 | self.basic_tokenizer = 
BasicTokenizer(do_lower_case=do_lower_case) 131 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 132 | 133 | def tokenize(self, text): 134 | split_tokens = [] 135 | for token in self.basic_tokenizer.tokenize(text): 136 | for sub_token in token: 137 | split_tokens.append(sub_token) 138 | 139 | return split_tokens 140 | 141 | def convert_tokens_to_ids(self, tokens): 142 | return convert_tokens_to_ids(self.vocab, tokens) 143 | 144 | 145 | class BasicTokenizer(object): 146 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 147 | 148 | def __init__(self, do_lower_case=True): 149 | """Constructs a BasicTokenizer. 150 | 151 | Args: 152 | do_lower_case: Whether to lower case the input. 153 | """ 154 | self.do_lower_case = do_lower_case 155 | 156 | def tokenize(self, text): 157 | """Tokenizes a piece of text.""" 158 | text = convert_to_unicode(text) 159 | text = self._clean_text(text) 160 | 161 | # This was added on November 1st, 2018 for the multilingual and Chinese 162 | # models. This is also applied to the English models now, but it doesn't 163 | # matter since the English models were not trained on any Chinese data 164 | # and generally don't have any Chinese data in them (there are Chinese 165 | # characters in the vocabulary because Wikipedia does have some Chinese 166 | # words in the English Wikipedia.). 167 | text = self._tokenize_chinese_chars(text) 168 | 169 | orig_tokens = whitespace_tokenize(text) 170 | split_tokens = [] 171 | for token in orig_tokens: 172 | if self.do_lower_case: 173 | token = token.lower() 174 | token = self._run_strip_accents(token) 175 | split_tokens.extend(self._run_split_on_punc(token)) 176 | 177 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 178 | return output_tokens 179 | 180 | def _run_strip_accents(self, text): 181 | """Strips accents from a piece of text.""" 182 | text = unicodedata.normalize("NFD", text) 183 | output = [] 184 | for char in text: 185 | cat = unicodedata.category(char) 186 | if cat == "Mn": 187 | continue 188 | output.append(char) 189 | return "".join(output) 190 | 191 | def _run_split_on_punc(self, text): 192 | """Splits punctuation on a piece of text.""" 193 | chars = list(text) 194 | i = 0 195 | start_new_word = True 196 | output = [] 197 | while i < len(chars): 198 | char = chars[i] 199 | if _is_punctuation(char): 200 | output.append([char]) 201 | start_new_word = True 202 | else: 203 | if start_new_word: 204 | output.append([]) 205 | start_new_word = False 206 | output[-1].append(char) 207 | i += 1 208 | 209 | return ["".join(x) for x in output] 210 | 211 | def _tokenize_chinese_chars(self, text): 212 | """Adds whitespace around any CJK character.""" 213 | output = [] 214 | for char in text: 215 | cp = ord(char) 216 | if self._is_chinese_char(cp): 217 | output.append(" ") 218 | output.append(char) 219 | output.append(" ") 220 | else: 221 | output.append(char) 222 | return "".join(output) 223 | 224 | def _is_chinese_char(self, cp): 225 | """Checks whether CP is the codepoint of a CJK character.""" 226 | # This defines a "chinese character" as anything in the CJK Unicode block: 227 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 228 | # 229 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 230 | # despite its name. The modern Korean Hangul alphabet is a different block, 231 | # as is Japanese Hiragana and Katakana. 
Those alphabets are used to write 232 | # space-separated words, so they are not treated specially and handled 233 | # like the all of the other languages. 234 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 235 | (cp >= 0x3400 and cp <= 0x4DBF) or # 236 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 237 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 238 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 239 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 240 | (cp >= 0xF900 and cp <= 0xFAFF) or # 241 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 242 | return True 243 | 244 | return False 245 | 246 | def _clean_text(self, text): 247 | """Performs invalid character removal and whitespace cleanup on text.""" 248 | output = [] 249 | for char in text: 250 | cp = ord(char) 251 | if cp == 0 or cp == 0xfffd or _is_control(char): 252 | continue 253 | if _is_whitespace(char): 254 | output.append(" ") 255 | else: 256 | output.append(char) 257 | return "".join(output) 258 | 259 | 260 | class WordpieceTokenizer(object): 261 | """Runs WordPiece tokenziation.""" 262 | 263 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 264 | self.vocab = vocab 265 | self.unk_token = unk_token 266 | self.max_input_chars_per_word = max_input_chars_per_word 267 | 268 | def tokenize(self, text): 269 | """Tokenizes a piece of text into its word pieces. 270 | 271 | This uses a greedy longest-match-first algorithm to perform tokenization 272 | using the given vocabulary. 273 | 274 | For example: 275 | input = "unaffable" 276 | output = ["un", "##aff", "##able"] 277 | 278 | Args: 279 | text: A single token or whitespace separated tokens. This should have 280 | already been passed through `BasicTokenizer. 281 | 282 | Returns: 283 | A list of wordpiece tokens. 284 | """ 285 | 286 | text = convert_to_unicode(text) 287 | 288 | output_tokens = [] 289 | for token in whitespace_tokenize(text): 290 | chars = list(token) 291 | if len(chars) > self.max_input_chars_per_word: 292 | output_tokens.append(self.unk_token) 293 | continue 294 | 295 | is_bad = False 296 | start = 0 297 | sub_tokens = [] 298 | while start < len(chars): 299 | end = len(chars) 300 | cur_substr = None 301 | while start < end: 302 | substr = "".join(chars[start:end]) 303 | if start > 0: 304 | substr = "##" + substr 305 | if substr in self.vocab: 306 | cur_substr = substr 307 | break 308 | end -= 1 309 | if cur_substr is None: 310 | is_bad = True 311 | break 312 | sub_tokens.append(cur_substr) 313 | start = end 314 | 315 | if is_bad: 316 | output_tokens.append(self.unk_token) 317 | else: 318 | output_tokens.extend(sub_tokens) 319 | return output_tokens 320 | 321 | 322 | def _is_whitespace(char): 323 | """Checks whether `chars` is a whitespace character.""" 324 | # \t, \n, and \r are technically contorl characters but we treat them 325 | # as whitespace since they are generally considered as such. 326 | if char == " " or char == "\t" or char == "\n" or char == "\r": 327 | return True 328 | cat = unicodedata.category(char) 329 | if cat == "Zs": 330 | return True 331 | return False 332 | 333 | 334 | def _is_control(char): 335 | """Checks whether `chars` is a control character.""" 336 | # These are technically control characters but we count them as whitespace 337 | # characters. 
338 | if char == "\t" or char == "\n" or char == "\r": 339 | return False 340 | cat = unicodedata.category(char) 341 | if cat.startswith("C"): 342 | return True 343 | return False 344 | 345 | 346 | def _is_punctuation(char): 347 | """Checks whether `chars` is a punctuation character.""" 348 | cp = ord(char) 349 | # We treat all non-letter/number ASCII as punctuation. 350 | # Characters such as "^", "$", and "`" are not in the Unicode 351 | # Punctuation class but we treat them as punctuation anyways, for 352 | # consistency. 353 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 354 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 355 | return True 356 | cat = unicodedata.category(char) 357 | if cat.startswith("P"): 358 | return True 359 | return False 360 | -------------------------------------------------------------------------------- /bert/tokenization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | import tempfile 21 | 22 | from bert import tokenization 23 | import tensorflow as tf 24 | 25 | 26 | class TokenizationTest(tf.test.TestCase): 27 | 28 | def test_full_tokenizer(self): 29 | vocab_tokens = [ 30 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 31 | "##ing", "," 32 | ] 33 | with tempfile.NamedTemporaryFile(delete=False) as vocab_writer: 34 | vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) 35 | 36 | vocab_file = vocab_writer.name 37 | 38 | tokenizer = tokenization.FullTokenizer(vocab_file) 39 | os.unlink(vocab_file) 40 | 41 | tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") 42 | self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) 43 | 44 | self.assertAllEqual( 45 | tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) 46 | 47 | def test_chinese(self): 48 | tokenizer = tokenization.BasicTokenizer() 49 | 50 | self.assertAllEqual( 51 | tokenizer.tokenize(u"ah\u535A\u63A8zz"), 52 | [u"ah", u"\u535A", u"\u63A8", u"zz"]) 53 | 54 | def test_basic_tokenizer_lower(self): 55 | tokenizer = tokenization.BasicTokenizer(do_lower_case=True) 56 | 57 | self.assertAllEqual( 58 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "), 59 | ["hello", "!", "how", "are", "you", "?"]) 60 | self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"]) 61 | 62 | def test_basic_tokenizer_no_lower(self): 63 | tokenizer = tokenization.BasicTokenizer(do_lower_case=False) 64 | 65 | self.assertAllEqual( 66 | tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? 
"), 67 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 68 | 69 | def test_wordpiece_tokenizer(self): 70 | vocab_tokens = [ 71 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 72 | "##ing" 73 | ] 74 | 75 | vocab = {} 76 | for (i, token) in enumerate(vocab_tokens): 77 | vocab[token] = i 78 | tokenizer = tokenization.WordpieceTokenizer(vocab=vocab) 79 | 80 | self.assertAllEqual(tokenizer.tokenize(""), []) 81 | 82 | self.assertAllEqual( 83 | tokenizer.tokenize("unwanted running"), 84 | ["un", "##want", "##ed", "runn", "##ing"]) 85 | 86 | self.assertAllEqual( 87 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 88 | 89 | def test_convert_tokens_to_ids(self): 90 | vocab_tokens = [ 91 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 92 | "##ing" 93 | ] 94 | 95 | vocab = {} 96 | for (i, token) in enumerate(vocab_tokens): 97 | vocab[token] = i 98 | 99 | self.assertAllEqual( 100 | tokenization.convert_tokens_to_ids( 101 | vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9]) 102 | 103 | def test_is_whitespace(self): 104 | self.assertTrue(tokenization._is_whitespace(u" ")) 105 | self.assertTrue(tokenization._is_whitespace(u"\t")) 106 | self.assertTrue(tokenization._is_whitespace(u"\r")) 107 | self.assertTrue(tokenization._is_whitespace(u"\n")) 108 | self.assertTrue(tokenization._is_whitespace(u"\u00A0")) 109 | 110 | self.assertFalse(tokenization._is_whitespace(u"A")) 111 | self.assertFalse(tokenization._is_whitespace(u"-")) 112 | 113 | def test_is_control(self): 114 | self.assertTrue(tokenization._is_control(u"\u0005")) 115 | 116 | self.assertFalse(tokenization._is_control(u"A")) 117 | self.assertFalse(tokenization._is_control(u" ")) 118 | self.assertFalse(tokenization._is_control(u"\t")) 119 | self.assertFalse(tokenization._is_control(u"\r")) 120 | 121 | def test_is_punctuation(self): 122 | self.assertTrue(tokenization._is_punctuation(u"-")) 123 | self.assertTrue(tokenization._is_punctuation(u"$")) 124 | self.assertTrue(tokenization._is_punctuation(u"`")) 125 | self.assertTrue(tokenization._is_punctuation(u".")) 126 | 127 | self.assertFalse(tokenization._is_punctuation(u"A")) 128 | self.assertFalse(tokenization._is_punctuation(u" ")) 129 | 130 | 131 | if __name__ == "__main__": 132 | tf.test.main() 133 | -------------------------------------------------------------------------------- /common_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | def set_logger(context, verbose=False): 5 | if os.name == 'nt': # for Windows 6 | return NTLogger(context, verbose) 7 | 8 | logger = logging.getLogger(context) 9 | logger.setLevel(logging.DEBUG if verbose else logging.INFO) 10 | formatter = logging.Formatter( 11 | '%(levelname)-.1s:' + context + ':[%(filename).3s:%(funcName).3s:%(lineno)3d]:%(message)s', datefmt= 12 | '%m-%d %H:%M:%S') 13 | console_handler = logging.StreamHandler() 14 | console_handler.setLevel(logging.DEBUG if verbose else logging.INFO) 15 | console_handler.setFormatter(formatter) 16 | logger.handlers = [] 17 | logger.addHandler(console_handler) 18 | return logger 19 | 20 | 21 | class NTLogger: 22 | def __init__(self, context, verbose): 23 | self.context = context 24 | self.verbose = verbose 25 | 26 | def info(self, msg, **kwargs): 27 | print('I:%s:%s' % (self.context, msg), flush=True) 28 | 29 | def debug(self, msg, **kwargs): 30 | if self.verbose: 31 | print('D:%s:%s' % (self.context, msg), flush=True) 32 | 33 | def error(self, 
msg, **kwargs): 34 | print('E:%s:%s' % (self.context, msg), flush=True) 35 | 36 | def warning(self, msg, **kwargs): 37 | print('W:%s:%s' % (self.context, msg), flush=True) -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'lenovo' 2 | -------------------------------------------------------------------------------- /configs/base_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | BASE_DIR = Path('slot_extraction') 4 | config = { 5 | 'data_dir':'data', 6 | 'embedding_dir':os.path.join('data','embedding_data'), 7 | 'embedding_file_name':"sgns.financial.bigram-char", 8 | 'input_embedding_file':"input_char_embedding.npy", 9 | 'input_word_embedding_file':"input_word_embedding.npy", 10 | 'vocab_file':"vocab.txt", 11 | 'slot_list_root_path':os.path.join('data','slot_pattern'), 12 | 'slot_file_name':"base_slot_list", 13 | 'log_dir': os.path.join('output','log'), 14 | 'data_file_name':'orig_data_train.txt', 15 | 'train_valid_data_dir':'train_valid_data', 16 | 'train_data_text_name':'train_split_data_text.npy', 17 | 'valid_data_text_name':'valid_split_data_text.npy', 18 | 'train_data_tag_name':'train_split_data_tag.npy', 19 | 'valid_data_tag_name':'valid_split_data_tag.npy', 20 | 'test_data_text_name':'test_data_text.npy', 21 | 'test_data_tag_name':'test_data_tag.npy', 22 | 'train_data_text_word_name':'train_word_split_data.npy', 23 | 'valid_data_text_word_name':'valid_word_split_data.npy', 24 | 'test_data_text_word_name':'test_word_split_data.npy', 25 | 'model_dir':os.path.join('output','model','checkpoint'), 26 | 'base_model_dir':os.path.join('output','model','base_model','checkpoint'), 27 | 'orig_dev':'orig_data_dev.txt', 28 | 'orig_test':'orig_data_test.txt', 29 | "pb_model_dir":os.path.join('output','model','saved_model'), 30 | "base_pb_model_dir":os.path.join('output','model','base_model','saved_model'), 31 | "standard_slot_description":os.path.join('data','slot_pattern','slot_description.csv'), 32 | "cnn_model_dir":os.path.join('output','model','cnn_model','checkpoint'), 33 | "cnn_pb_model_dir":os.path.join('output','model','cnn_model','saved_model'), 34 | "lstm_only_dlloss_model_dir":os.path.join('output','model','lstm_only_dlloss_model','checkpoint'), 35 | "lstm_only_dlloss_pb_model_dir":os.path.join('output','model','lstm_only_dlloss_model','saved_model'), 36 | "lstmcrf_lesslabel_model_dir":os.path.join('output','model','lstm_only_dlloss_model','checkpoint'), 37 | "lstmcrf_lesslabel_pb_model_dir":os.path.join('output','model','lstm_only_dlloss_model','saved_model'), 38 | "lstmcrf_completelabel_model_dir":os.path.join('output','model','lstmcrf_completelabel_model','checkpoint'), 39 | "lstmcrf_completelabel_pb_model_dir":os.path.join('output','model','lstmcrf_completelabel_model','saved_model'), 40 | "lstmcrf_completelabel_wordemb_model_dir":os.path.join('output','model','lstmcrf_completelabel_wordemb_model','checkpoint'), 41 | "lstmcrf_completelabel_wordemb_pb_model_dir":os.path.join('output','model','lstmcrf_completelabel_wordemb_model','saved_model'), 42 | "lstmcrf_cnn_completelabel_wordemb_model_dir":os.path.join('output','model','lstmcrf_cnn_completelabel_wordemb_model','checkpoint'), 43 | "lstmcrf_cnn_completelabel_wordemb_pb_model_dir":os.path.join('output','model','lstmcrf_cnn_completelabel_wordemb_model','saved_model'), 44 | 45 | } 46 | # 
print(os.path.join(config.get("train_valid_data_dir"),config.get("train_data_text_name"))) 47 | # print(os.path.join(config.get("train_valid_data_dir"),config.get("train_data_text_name"))) -------------------------------------------------------------------------------- /configs/bert_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | BASE_DIR = Path('slot_extraction') 4 | bert_config = { 5 | 'data_dir':'data', 6 | 'vocab_file':"vocab.txt", 7 | 'slot_list_root_path':os.path.join('data','slot_pattern'), 8 | 'slot_file_name':"base_slot_list", 9 | 'bert_slot_file_name':"bert_basic_slot_list", 10 | 'bert_slot_complete_file_name':"bert_complete_slot_pattern", 11 | 'log_dir': os.path.join('output','log'), 12 | 'data_file_name':'orig_data_train.txt', 13 | 'train_valid_data_dir':'train_valid_data_bert_lesslabel_hascls', 14 | 'train_data_text_name':'train_split_data_text.npy', 15 | 'valid_data_text_name':'valid_split_data_text.npy', 16 | 'train_data_tag_name':'train_split_data_tag.npy', 17 | 'valid_data_tag_name':'valid_split_data_tag.npy', 18 | 'test_data_text_name':'test_data_text.npy', 19 | 'test_data_tag_name':'test_data_tag.npy', 20 | 'orig_dev':'orig_data_dev.txt', 21 | 'orig_test':'orig_data_test.txt', 22 | "standard_slot_description":os.path.join('data','slot_pattern','slot_description.csv'), 23 | "bert_pretrained_model_path":os.path.join('data','chinese_roberta_wwm_ext_L-12_H-768_A-12'), 24 | "bert_config_path":"bert_config.json", 25 | 'bert_init_checkpoints':'bert_model.ckpt', 26 | "bert_model_dir":os.path.join('output','model','bert_model','checkpoint'), 27 | "bert_model_pb":os.path.join('output','model','bert_model','saved_model'), 28 | "bert_ce_model_dir":os.path.join('output','model','bert_ce_model','checkpoint'), 29 | "bert_ce_model_pb":os.path.join('output','model','bert_ce_model','saved_model'), 30 | "bert_crf_model_dir":os.path.join('output','model','bert_crf_model','checkpoint'), 31 | "bert_crf_model_pb":os.path.join('output','model','bert_crf_model','saved_model'), 32 | "bert_crflstm_model_dir":os.path.join('output','model','bert_crflstm_model','checkpoint'), 33 | "bert_crflstm_model_pb":os.path.join('output','model','bert_crflstm_model','saved_model'), 34 | "bert_focaldsc_model_dir":os.path.join('output','model','bert_focaldsc_model','checkpoint'), 35 | "bert_focaldsc_model_pb":os.path.join('output','model','bert_focaldsc_model','saved_model'), 36 | "bert_lstmcrf_lesslabel_model_dir":os.path.join('output','model','bert_lstmcrf_lesslabel_model','checkpoint'), 37 | "bert_lstmcrf_lesslabel_model_pb":os.path.join('output','model','bert_lstmcrf_lesslabel_model','saved_model'), 38 | "bert_dlloss_lesslabel_model_dir":os.path.join('output','model','bert_dlloss_lesslabel_model','checkpoint'), 39 | "bert_dlloss_lesslabel_model_pb":os.path.join('output','model','bert_dlloss_lesslabel_model','saved_model'), 40 | "bert_dlloss_cslesslabel_model_dir":os.path.join('output','model','bert_dlloss_cslesslabel_model','checkpoint'), 41 | "bert_dlloss_cslesslabel_model_pb":os.path.join('output','model','bert_dlloss_cslesslabel_model','saved_model'), 42 | "bert_lstmcrf_cslesslabel_model_dir":os.path.join('output','model','bert_lstmcrf_cslesslabel_model','checkpoint'), 43 | "bert_lstmcrf_cslesslabel_model_pb":os.path.join('output','model','bert_lstmcrf_cslesslabel_model','saved_model'), 44 | "bert_celoss_cslesslabel_model_dir":os.path.join('output','model','bert_celoss_cslesslabel_model','checkpoint'), 45 | 
"bert_celoss_cslesslabel_model_pb":os.path.join('output','model','bert_celoss_cslesslabel_model','saved_model'), 46 | "bert_dscloss_cslesslabel_model_dir":os.path.join('output','model','bert_dscloss_cslesslabel_model','checkpoint'), 47 | "bert_dscloss_cslesslabel_model_pb":os.path.join('output','model','bert_dscloss_cslesslabel_model','saved_model'), 48 | 49 | } 50 | # print(os.path.join(config.get("train_valid_data_dir"),config.get("train_data_text_name"))) 51 | -------------------------------------------------------------------------------- /configs/bert_mrc_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | BASE_DIR = Path('slot_extraction') 4 | bert_mrc_config = { 5 | 'data_dir':'data', 6 | 'vocab_file':"vocab.txt", 7 | 'slot_list_root_path':os.path.join('data','slot_pattern'), 8 | 'bert_slot_file_name':"bert_basic_slot_list", 9 | 'bert_slot_complete_file_name':"bert_complete_slot_pattern", 10 | 'log_dir': os.path.join('output','log'), 11 | 'data_file_name':'orig_data_train.txt', 12 | 'train_valid_data_dir':'train_valid_data_bert_mrc', 13 | 'train_data_text_name':'train_split_data_text.npy', 14 | 'valid_data_text_name':'valid_split_data_text.npy', 15 | 'train_data_start_tag_name':'train_split_data_start_tag.npy', 16 | 'train_data_end_tag_name':'train_split_data_end_tag.npy', 17 | 'valid_data_start_tag_name':'valid_split_data_start_tag.npy', 18 | 'valid_data_end_tag_name':'valid_split_data_end_tag.npy', 19 | 'train_data_token_type_ids_name':'train_split_data_token_type_ids.npy', 20 | 'valid_data_token_type_ids_name':'valid_split_data_token_type_ids.npy', 21 | 'train_data_query_len_name':'train_split_data_query_len.npy', 22 | 'valid_data_query_len_name':'valid_split_data_query_len.npy', 23 | 'test_data_token_type_ids_name': 'test_split_data_token_type_ids.npy', 24 | 'test_data_text_name':'test_data_text.npy', 25 | 'test_data_tag_name':'test_data_tag.npy', 26 | 'test_data_query_len_name':'test_data_query_len.npy', 27 | 'test_data_query_class':'test_data_query_class.npy', 28 | 'test_data_src_sample_id':'test_data_src_sample_id.npy', 29 | 'orig_dev':'orig_data_dev.txt', 30 | 'orig_test':'orig_data_test.txt', 31 | "standard_slot_description":os.path.join('data','slot_pattern','slot_description.csv'), 32 | "bert_pretrained_model_path":os.path.join('data','chinese_roberta_wwm_ext_L-12_H-768_A-12'), 33 | "bert_config_path":"bert_config.json", 34 | 'bert_init_checkpoints':'bert_model.ckpt', 35 | "bert_mrc_model_dir":os.path.join('output','model','bert_mrc_model','checkpoint'), 36 | "bert_mrc_model_pb":os.path.join('output','model','bert_mrc_model','saved_model'), 37 | 'direct_train_data_text_name':'direct_train_split_data_text.npy', 38 | 'direct_valid_data_text_name':'direct_valid_split_data_text.npy', 39 | 'direct_train_data_start_tag_name':'direct_train_split_data_start_tag.npy', 40 | 'direct_train_data_end_tag_name':'direct_train_split_data_end_tag.npy', 41 | 'direct_valid_data_start_tag_name':'direct_valid_split_data_start_tag.npy', 42 | 'direct_valid_data_end_tag_name':'direct_valid_split_data_end_tag.npy', 43 | 'direct_train_data_token_type_ids_name':'direct_train_split_data_token_type_ids.npy', 44 | 'direct_valid_data_token_type_ids_name':'direct_valid_split_data_token_type_ids.npy', 45 | 'direct_train_data_query_len_name':'direct_train_split_data_query_len.npy', 46 | 'direct_valid_data_query_len_name':'direct_valid_split_data_query_len.npy', 47 | 
"direct_bert_mrc_model_dir":os.path.join('output','model','direct_bert_mrc_model','checkpoint'), 48 | "direct_bert_mrc_model_pb":os.path.join('output','model','direct_bert_mrc_model','saved_model'), 49 | 'direct_test_data_token_type_ids_name': 'direct_test_split_data_token_type_ids.npy', 50 | 'direct_test_data_text_name':'direct_test_data_text.npy', 51 | 'direct_test_data_tag_name':'direct_test_data_tag.npy', 52 | 'direct_test_data_query_len_name':'direct_test_data_query_len.npy', 53 | 'direct_test_data_query_class':'direct_test_data_query_class.npy', 54 | 'direct_test_data_src_sample_id':'direct_test_data_src_sample_id.npy', 55 | "bert_mrc_dice_model_dir":os.path.join('output','model','bert_mrc_dice_model','checkpoint'), 56 | "bert_mrc_dice_model_pb":os.path.join('output','model','bert_mrc_dice_model','saved_model'), 57 | "bert_mrc_ratio_imb_model_dir":os.path.join('output','model','bert_mrc_ratio_imb_model','checkpoint'), 58 | "bert_mrc_ratio_imb_model_pb":os.path.join('output','model','bert_mrc_ratio_imb_model','saved_model'), 59 | "bert_mrc_ratio_imb_use_start_label_model_dir":os.path.join('output','model','bert_mrc_ratio_imb_use_start_label_model','checkpoint'), 60 | "bert_mrc_ratio_imb_use_start_label_model_pb":os.path.join('output','model','bert_mrc_ratio_imb_use_start_label_model','saved_model'), 61 | "bert_mrc_focal_loss_model_dir":os.path.join('output','model','bert_mrc_focal_loss_model','checkpoint'), 62 | "bert_mrc_focal_loss_model_pb":os.path.join('output','model','bert_mrc_focal_loss_model','saved_model'), 63 | } 64 | # print(os.path.join(config.get("train_valid_data_dir"),config.get("train_data_text_name"))) 65 | -------------------------------------------------------------------------------- /configs/event_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | BASE_DIR = Path('slot_extraction') 5 | event_config = { 6 | 'data_dir': 'data', 7 | 'vocab_file': "vocab.txt", 8 | 'slot_list_root_path': os.path.join('data', 'slot_pattern'), 9 | 'slot_file_name': "vocab_trigger_label_map.txt", 10 | 'bert_slot_file_name': "vocab_trigger_label_map.txt", 11 | 'bert_slot_complete_file_name': "vocab_trigger_label_map.txt", 12 | 'bert_slot_complete_file_name_role': "vocab_all_slot_label_noBI_map.txt", 13 | 'query_map_file': "slot_descrip", 14 | 'event_type_file': "vocab_all_event_type_label_map.txt", 15 | 'all_slot_file': "vocab_all_slot_label_map.txt", 16 | 'log_dir': os.path.join('output', 'log'), 17 | 'data_file_name': 'orig_data_train.txt', 18 | 'event_data_file_train': "train.json", 19 | 'event_data_file_eval': "dev.json", 20 | 'event_data_file_test': "test.json", 21 | 'train_valid_data_dir': 'train_valid_data_bert_event', 22 | 'train_data_text_name': 'train_split_data_text.npy', 23 | 'valid_data_text_name': 'valid_split_data_text.npy', 24 | 'train_data_tag_name': 'train_split_data_tag.npy', 25 | 'valid_data_tag_name': 'valid_split_data_tag.npy', 26 | 'test_data_text_name': 'test_data_text.npy', 27 | 'test_data_tag_name': 'test_data_tag.npy', 28 | # "bert_pretrained_model_path":os.path.join('data','chinese_roberta_wwm_ext_L-12_H-768_A-12'), 29 | "bert_pretrained_model_path": os.path.join('data', 'chinese_roberta_wwm_ext_L-12_H-1024_A-12_large'), 30 | # "bert_pretrained_model_path":os.path.join('data','albert_large_zh'), 31 | 32 | # "bert_pretrained_model_path":os.path.join('roberta_zh-master','finetune_roberta_large_wwm'), 33 | "bert_config_path": "bert_config.json", 34 | # 
'bert_init_checkpoints':'model.ckpt-51000', 35 | 36 | 'bert_init_checkpoints': 'bert_model.ckpt', 37 | "bert_model_dir": os.path.join('output', 'model', 'event_trigger_bert_model', 'checkpoint'), 38 | "bert_model_pb": os.path.join('output', 'model', 'event_trigger_bert_model', 'saved_model'), 39 | "role_bert_model_dir": os.path.join('output', 'model', 40 | 'wwm_lr_fold_{}_usingtype_roberta_large_traindev_event_role_bert_mrc_model_desmodified_lowercase', 41 | 'checkpoint'), 42 | "role_bert_model_pb": os.path.join('output', 'model', 43 | 'wwm_lr_fold_{}_usingtype_roberta_large_traindev_event_role_bert_mrc_model_desmodified_lowercase', 44 | 'saved_model'), 45 | "student_role_bert_model_dir": os.path.join('output', 'model', 46 | 'student_usingtype_roberta_large_traindev_event_role_bert_mrc_model_desmodified_lowercase', 47 | 'checkpoint'), 48 | "student_role_bert_model_pb": os.path.join('output', 'model', 49 | 'student_usingtype_roberta_large_traindev_event_role_bert_mrc_model_desmodified_lowercase', 50 | 'saved_model'), 51 | "merge_role_bert_model_dir": os.path.join('output', 'model', 52 | 'merge_usingtype_roberta_traindev_event_role_bert_mrc_model_desmodified_lowercase', 53 | 'checkpoint'), 54 | "merge_role_bert_model_pb": os.path.join('output', 'model', 55 | 'merge_usingtype_roberta_traindev_event_role_bert_mrc_model_desmodified_lowercase', 56 | 'saved_model'), 57 | "merge_continue_role_bert_model_dir": os.path.join('output', 'model', 58 | 'merge_continue_usingtype_roberta_large_traindev_event_role_bert_mrc_model_desmodified_lowercase', 59 | 'checkpoint'), 60 | "merge_continue_role_bert_model_pb": os.path.join('output', 'model', 61 | 'merge_continue_usingtype_roberta_large_traindev_event_role_bert_mrc_model_desmodified_lowercase', 62 | 'saved_model'), 63 | "role_verify_cls_bert_model_dir": os.path.join('output', 'model', 64 | 'final_verify_cls_fold_{}_usingtype_roberta_large_traindev_event_role_bert_mrc_model_desmodified_lowercase', 65 | 'checkpoint'), 66 | "role_verify_cls_bert_model_pb": os.path.join('output', 'model', 67 | 'final_verify_cls_fold_{}_usingtype_roberta_large_traindev_event_role_bert_mrc_model_desmodified_lowercase', 68 | 'saved_model'), 69 | "role_verify_avmrc_bert_model_dir": os.path.join('output', 'model', 70 | 'final_verify_avmrc_fold_{}_usingtype_roberta_large_traindev_event_role_bert_mrc_model_desmodified_lowercase', 71 | 'checkpoint'), 72 | "role_verify_avmrc_bert_model_pb": os.path.join('output', 'model', 73 | 'final_verify_avmrc_fold_{}_usingtype_roberta_large_traindev_event_role_bert_mrc_model_desmodified_lowercase', 74 | 'saved_model'), 75 | 76 | "datamodified_role_bert_model_dir": os.path.join('output', 'model', 77 | 'datamodified_usingtype_roberta_large_traindev_event_role_bert_mrc_model_desmodified_lowercase', 78 | 'checkpoint'), 79 | "datamodified_role_bert_model_pb": os.path.join('output', 'model', 80 | 'datamodified_usingtype_roberta_large_traindev_event_role_bert_mrc_model_desmodified_lowercase', 81 | 'saved_model'), 82 | "datamodified_small_role_bert_model_dir": os.path.join('output', 'model', 83 | 'datamodified_usingtype_roberta_traindev_event_role_bert_mrc_model_desmodified_lowercase', 84 | 'checkpoint'), 85 | "datamodified_small_role_bert_model_pb": os.path.join('output', 'model', 86 | 'datamodified_usingtype_roberta_traindev_event_role_bert_mrc_model_desmodified_lowercase', 87 | 'saved_model'), 88 | 89 | "event_schema": "event_schema.json", 90 | "multi_task_bert_model_dir": os.path.join('output', 'model', 'event_multask_bert_model', 'checkpoint'), 91 | 
"multi_task_bert_model_pb": os.path.join('output', 'model', 'event_multask_bert_model', 'saved_model'), 92 | "type_class_bert_model_dir": os.path.join('output', 'model', 93 | 'index_fold_{}_roberta_large_traindev_desmodified_lowercase_event_type_class_bert_model', 94 | 'checkpoint'), 95 | "type_class_bert_model_pb": os.path.join('output', 'model', 96 | 'index_fold_{}_roberta_large_traindev_desmodified_lowercase_event_type_class_bert_model', 97 | 'saved_model'), 98 | "type_role_class_bert_model_dir": os.path.join('output', 'model', 'event_type_role_class_bert_model', 'checkpoint'), 99 | "type_role_class_bert_model_pb": os.path.join('output', 'model', 'event_type_role_class_bert_model', 'saved_model'), 100 | 101 | } 102 | # print(os.path.join(config.get("train_valid_data_dir"),config.get("train_data_text_name"))) 103 | -------------------------------------------------------------------------------- /data/slot_pattern/slot_descrip: -------------------------------------------------------------------------------- 1 | 附有价格且用于买卖的商品,其价格下降 2 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 3 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 4 | 真实和虚构的人名或者代号 5 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 6 | 某一类团体中的人员 7 | 坍塌主体 8 | 降息幅度是多少? 9 | 开庭法院名称 10 | 袭击导致了多少人死亡? 11 | 跌停股票名称 12 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 13 | 真实和虚构的人名或者代号 14 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 15 | 真实和虚构的人名或者代号 16 | 真实和虚构的人名或者代号 17 | 个人或者团体 18 | 失联者 19 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 20 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 21 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 22 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 23 | 货币金额大小 24 | 公司,商业机构,社会组织等组织机构 25 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 26 | 国家,城市,山川等抽象或具体的地点 27 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 28 | 个人或者团体 29 | 公司,商业机构,社会组织等组织机构 30 | 真实和虚构的人名或者代号 31 | 国家,城市,山川等抽象或具体的地点 32 | 爆炸导致了多少人死亡? 33 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 34 | 真实和虚构的人名或者代号 35 | 公司,商业机构,社会组织等组织机构 36 | 真实和虚构的人名或者代号 37 | 降价幅度 38 | 公司,商业机构,社会组织等组织机构 39 | 公司,商业机构,社会组织等组织机构 40 | 罢工人数 41 | 货币金额大小 42 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 43 | 真实和虚构的人名或者代号 44 | 拥有相对独立的法律地位和组织结构的行政机构 45 | 个人或者团体 46 | 个人或者公司,商业机构,社会组织等组织机构 47 | 个人或者公司,商业机构,社会组织等组织机构 48 | 作为子女后代的个人人名或者代号 49 | 涨停股票 50 | 年龄计量大小 51 | 国家,城市,山川等抽象或具体的地点 52 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 53 | 涨价幅度是多少? 54 | 国家,城市,山川等抽象或具体的地点 55 | 地震导致了多少人死亡? 56 | 国家,城市,山川等抽象或具体的地点 57 | 真实和虚构的人名或者代号 58 | 袭击导致了多少人受伤? 59 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 60 | 真实和虚构的人名或者代号 61 | 真实和虚构的人名或者代号 62 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 63 | 地震导致了多少人受伤? 64 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 65 | 车祸导致了多少人受伤? 66 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 67 | 个人或者团体 68 | 人民法院、人民检察院和公安机关等行政机构 69 | 比赛名称 70 | 国家,城市,山川等抽象或具体的地点 71 | 生日方年龄 72 | 什么人或者什么组织发起了袭击? 73 | 真实和虚构的人名或者代号 74 | 自组织的团体或群体 75 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 76 | 公司,商业机构等组织机构 77 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 78 | 公司,商业机构等组织机构 79 | 国家,城市,山川等抽象或具体的地点 80 | 洪灾导致了多少人受伤? 81 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 82 | 活动名称 83 | 公司,商业机构等组织机构 84 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 85 | 真实和虚构的人名或者代号 86 | 个人或者公司,商业机构,社会组织等组织机构 87 | 交易物 88 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 89 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 90 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 91 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 92 | 公司,商业机构,社会组织等组织机构 93 | 影视剧名称 94 | 国家或地区的中央银行名称或代号 95 | 个人或者团体 96 | 震源深度,包含深度距离如公里数或者千米数 97 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 98 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 99 | 夺得冠军的个人或者团体 100 | 公司,商业机构等组织机构 101 | 个人或者公司,商业机构,社会组织等组织机构 102 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 103 | 真实和虚构的人名或者代号或者职称 104 | 车祸导致了多少人死亡? 
105 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 106 | 公司,商业机构等组织机构 107 | 公司,商业机构等组织机构 108 | 真实和虚构的人名或者团体代号 109 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 110 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 111 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 112 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 113 | 真实和虚构的人名或者代号 114 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 115 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 116 | 真实和虚构的人名或者代号 117 | 真实和虚构的人名或者代号 118 | 坍塌事故导致了多少人死亡? 119 | 个人或者公司,商业机构,社会组织等组织机构 120 | 国家,城市,山川等抽象或具体的地点 121 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 122 | 公司,商业机构等组织机构 123 | 真实和虚构的人名或者团体代号 124 | 洪灾导致了多少人死亡? 125 | 真实和虚构的人名或者代号 126 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 127 | 公司,商业机构,社会组织等组织机构 128 | 下架产品 129 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 130 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 131 | 失败的个人或者团队 132 | 个人或者公司,商业机构,社会组织等组织机构 133 | 游行人数 134 | 公司,商业机构等组织机构 135 | 火灾事故导致了多少人受伤? 136 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 137 | 爆炸事故导致了多少人受伤? 138 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 139 | 什么组织机构发起了约谈? 140 | 个人或者团体 141 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 142 | 胜利的个人或者团队 143 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 144 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 145 | 国家,城市,山川等抽象或具体的地点 146 | 融资轮次 147 | 法院判决宣告的剥夺自由或剥夺政治权利的刑罚应予执行的时长 148 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 149 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 150 | 国家,城市,山川等抽象或具体的地点 151 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 152 | 比赛名称 153 | 涨价物 154 | 公司,商业机构,社会组织等组织机构 155 | 个人或者团体名称 156 | 拥有相对独立的法律地位和组织结构的行政机构 157 | 国家,城市,山川等抽象或具体的地点 158 | 真实和虚构的人名或者代号 159 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 160 | 针对在某个领域中有特殊表现的人或事或物进行表彰而设立的项目 161 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 162 | 坠机事故导致了多少人死亡? 163 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 164 | 召回内容 165 | 震源在地表的投影点,震中所在地 166 | 个人或者公司,商业机构,社会组织等组织机构 167 | 比赛名称 168 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 169 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 170 | 真实和虚构的人名或者代号 171 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 172 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 173 | 个人或者团体 174 | 加息幅度 175 | 禁赛时长 176 | 公司,商业机构,社会组织等组织机构 177 | 活动名称 178 | 真实和虚构的人名或者代号 179 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 180 | 真实和虚构的人名或者代号 181 | 真实和虚构的人名或者代号,或者团体 182 | 火灾事故导致了多少人死亡? 183 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 184 | 坍塌事故导致了多少人受伤 185 | 真实和虚构的人名或者代号 186 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 187 | 公司,商业机构,社会组织等组织机构 188 | 公安、司法机关及其它行政执法机关 189 | 裁员裁了多少人? 190 | 真实和虚构的人名或者代号 191 | 真实和虚构的人名或者代号 192 | 真实和虚构的人名或者代号 193 | 真实和虚构的人名或者代号 194 | 货币金额大小 195 | 真实和虚构的人名或者代号,或者团体机构 196 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 197 | 坠机事故导致了多少人受伤? 
198 | 公司,商业机构,社会组织等组织机构 199 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 200 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 201 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 202 | 真实和虚构的人名或者代号 203 | 点赞对象 204 | 公司,商业机构,社会组织等组织机构 205 | 公司,商业机构等组织机构 206 | 国家,城市,山川等抽象或具体的地点 207 | 国家,城市,山川等抽象或具体的地点 208 | 货币金额大小 209 | 开庭案件 210 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 211 | 发布产品 212 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 213 | 国家,城市,山川等抽象或具体的地点 214 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 215 | 赛事名称 216 | 地震大小级别 217 | 个人或者团体 218 | -------------------------------------------------------------------------------- /data/slot_pattern/slot_descrip_old: -------------------------------------------------------------------------------- 1 | 附有价格且用于买卖的商品,其价格下降 2 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 3 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 4 | 真实和虚构的人名或者代号 5 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 6 | 某一类团体中的人员 7 | 坍塌主体 8 | 降息幅度 9 | 开庭法院名称 10 | 死亡人数 11 | 跌停股票名称 12 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 13 | 真实和虚构的人名或者代号 14 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 15 | 真实和虚构的人名或者代号 16 | 真实和虚构的人名或者代号 17 | 个人或者团体 18 | 失联者 19 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 20 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 21 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 22 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 23 | 货币金额大小 24 | 公司,商业机构,社会组织等组织机构 25 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 26 | 国家,城市,山川等抽象或具体的地点 27 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 28 | 个人或者团体 29 | 公司,商业机构,社会组织等组织机构 30 | 真实和虚构的人名或者代号 31 | 国家,城市,山川等抽象或具体的地点 32 | 死亡人数 33 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 34 | 真实和虚构的人名或者代号 35 | 公司,商业机构,社会组织等组织机构 36 | 真实和虚构的人名或者代号 37 | 降价幅度 38 | 公司,商业机构,社会组织等组织机构 39 | 公司,商业机构,社会组织等组织机构 40 | 罢工人数 41 | 货币金额大小 42 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 43 | 真实和虚构的人名或者代号 44 | 拥有相对独立的法律地位和组织结构的行政机构 45 | 个人或者团体 46 | 个人或者公司,商业机构,社会组织等组织机构 47 | 个人或者公司,商业机构,社会组织等组织机构 48 | 作为子女后代的个人人名或者代号 49 | 涨停股票 50 | 年龄计量大小 51 | 国家,城市,山川等抽象或具体的地点 52 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 53 | 涨价幅度 54 | 国家,城市,山川等抽象或具体的地点 55 | 死亡人数 56 | 国家,城市,山川等抽象或具体的地点 57 | 真实和虚构的人名或者代号 58 | 受伤人数 59 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 60 | 真实和虚构的人名或者代号 61 | 真实和虚构的人名或者代号 62 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 63 | 受伤人数 64 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 65 | 受伤人数 66 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 67 | 个人或者团体 68 | 人民法院、人民检察院和公安机关等行政机构 69 | 比赛名称 70 | 国家,城市,山川等抽象或具体的地点 71 | 生日方年龄 72 | 个人或者团体 73 | 真实和虚构的人名或者代号 74 | 自组织的团体或群体 75 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 76 | 公司,商业机构等组织机构 77 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 78 | 公司,商业机构等组织机构 79 | 国家,城市,山川等抽象或具体的地点 80 | 受伤人数 81 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 82 | 活动名称 83 | 公司,商业机构等组织机构 84 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 85 | 真实和虚构的人名或者代号 86 | 个人或者公司,商业机构,社会组织等组织机构 87 | 交易物 88 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 89 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 90 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 91 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 92 | 公司,商业机构,社会组织等组织机构 93 | 影视剧名称 94 | 国家或地区的中央银行名称或代号 95 | 个人或者团体 96 | 震源深度,包含深度距离如公里数或者千米数 97 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 98 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 99 | 夺得冠军的个人或者团体 100 | 公司,商业机构等组织机构 101 | 个人或者公司,商业机构,社会组织等组织机构 102 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 103 | 真实和虚构的人名或者代号或者职称 104 | 死亡人数 105 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 106 | 公司,商业机构等组织机构 107 | 公司,商业机构等组织机构 108 | 真实和虚构的人名或者团体代号 109 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 110 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 111 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 112 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 113 | 真实和虚构的人名或者代号 114 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 115 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 116 | 真实和虚构的人名或者代号 117 | 真实和虚构的人名或者代号 118 | 死亡人数 119 | 个人或者公司,商业机构,社会组织等组织机构 120 | 国家,城市,山川等抽象或具体的地点 121 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 122 | 公司,商业机构等组织机构 123 | 真实和虚构的人名或者团体代号 124 | 死亡人数 125 | 真实和虚构的人名或者代号 126 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 127 | 公司,商业机构,社会组织等组织机构 128 | 下架产品 129 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 130 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 131 | 失败的个人或者团队 132 | 个人或者公司,商业机构,社会组织等组织机构 133 | 游行人数 134 | 公司,商业机构等组织机构 135 | 
受伤人数 136 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 137 | 受伤人数 138 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 139 | 拥有具体行政职权的机关组织机构 140 | 个人或者团体 141 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 142 | 胜利的个人或者团队 143 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 144 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 145 | 国家,城市,山川等抽象或具体的地点 146 | 融资轮次 147 | 法院判决宣告的剥夺自由或剥夺政治权利的刑罚应予执行的时长 148 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 149 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 150 | 国家,城市,山川等抽象或具体的地点 151 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 152 | 比赛名称 153 | 涨价物 154 | 公司,商业机构,社会组织等组织机构 155 | 个人或者团体名称 156 | 拥有相对独立的法律地位和组织结构的行政机构 157 | 国家,城市,山川等抽象或具体的地点 158 | 真实和虚构的人名或者代号 159 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 160 | 针对在某个领域中有特殊表现的人或事或物进行表彰而设立的项目 161 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 162 | 死亡人数 163 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 164 | 召回内容 165 | 震源在地表的投影点,震中所在地 166 | 个人或者公司,商业机构,社会组织等组织机构 167 | 比赛名称 168 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 169 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 170 | 真实和虚构的人名或者代号 171 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 172 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 173 | 个人或者团体 174 | 加息幅度 175 | 禁赛时长 176 | 公司,商业机构,社会组织等组织机构 177 | 活动名称 178 | 真实和虚构的人名或者代号 179 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 180 | 真实和虚构的人名或者代号 181 | 真实和虚构的人名或者代号,或者团体 182 | 死亡人数 183 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 184 | 受伤人数 185 | 真实和虚构的人名或者代号 186 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 187 | 公司,商业机构,社会组织等组织机构 188 | 公安、司法机关及其它行政执法机关 189 | 裁员人数 190 | 真实和虚构的人名或者代号 191 | 真实和虚构的人名或者代号 192 | 真实和虚构的人名或者代号 193 | 真实和虚构的人名或者代号 194 | 货币金额大小 195 | 真实和虚构的人名或者代号,或者团体机构 196 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 197 | 受伤人数 198 | 公司,商业机构,社会组织等组织机构 199 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 200 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 201 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 202 | 真实和虚构的人名或者代号 203 | 点赞对象 204 | 公司,商业机构,社会组织等组织机构 205 | 公司,商业机构等组织机构 206 | 国家,城市,山川等抽象或具体的地点 207 | 国家,城市,山川等抽象或具体的地点 208 | 货币金额大小 209 | 开庭案件 210 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 211 | 发布产品 212 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 213 | 国家,城市,山川等抽象或具体的地点 214 | 事件发生的时间,包含年、月、日、天、周、时、分、秒等 215 | 赛事名称 216 | 地震大小 217 | 个人或者团体 218 | -------------------------------------------------------------------------------- /data/slot_pattern/vocab_all_event_type_label_map.txt: -------------------------------------------------------------------------------- 1 | 财经/交易-上市 0 2 | 产品行为-召回 1 3 | 司法行为-举报 2 4 | 产品行为-下架 3 5 | 交往-道歉 4 6 | 财经/交易-降息 5 7 | 灾害/意外-地震 6 8 | 组织关系-解散 7 9 | 竞赛行为-夺冠 8 10 | 组织关系-辞/离职 9 11 | 灾害/意外-坠机 10 12 | 组织关系-停职 11 13 | 财经/交易-降价 12 14 | 交往-感谢 13 15 | 灾害/意外-坍/垮塌 14 16 | 产品行为-上映 15 17 | 竞赛行为-胜负 16 18 | 灾害/意外-起火 17 19 | 组织关系-解约 18 20 | 产品行为-发布 19 21 | 司法行为-开庭 20 22 | 人生-婚礼 21 23 | 组织关系-解雇 22 24 | 财经/交易-跌停 23 25 | 组织行为-罢工 24 26 | 组织行为-游行 25 27 | 人生-产子/女 26 28 | 财经/交易-加息 27 29 | 交往-会见 28 30 | 财经/交易-出售/收购 29 31 | 财经/交易-涨价 30 32 | 交往-点赞 31 33 | 人生-订婚 32 34 | 人生-分手 33 35 | 司法行为-约谈 34 36 | 灾害/意外-车祸 35 37 | 组织关系-加盟 36 38 | 财经/交易-涨停 37 39 | 竞赛行为-禁赛 38 40 | 产品行为-获奖 39 41 | 组织行为-开幕 40 42 | 人生-庆生 41 43 | 人生-出轨 42 44 | 司法行为-入狱 43 45 | 灾害/意外-爆炸 44 46 | 司法行为-立案 45 47 | 竞赛行为-晋级 46 48 | 组织关系-退出 47 49 | 交往-探班 48 50 | 人生-怀孕 49 51 | 灾害/意外-袭击 50 52 | 司法行为-罚款 51 53 | 司法行为-起诉 52 54 | 人生-离婚 53 55 | 竞赛行为-退役 54 56 | 灾害/意外-洪灾 55 57 | 竞赛行为-退赛 56 58 | 人生-结婚 57 59 | 组织关系-裁员 58 60 | 组织行为-闭幕 59 61 | 人生-失联 60 62 | 司法行为-拘捕 61 63 | 人生-死亡 62 64 | 财经/交易-融资 63 65 | 人生-求婚 64 66 | -------------------------------------------------------------------------------- /data/slot_pattern/vocab_all_slot_label_noBI_map.txt: -------------------------------------------------------------------------------- 1 | 财经/交易-降价-降价物 0 2 | 产品行为-获奖-时间 1 3 | 财经/交易-出售/收购-时间 2 4 | 司法行为-约谈-约谈对象 3 5 | 人生-死亡-时间 4 6 | 组织行为-罢工-罢工人员 5 7 | 灾害/意外-坍/垮塌-坍塌主体 6 8 | 财经/交易-降息-降息幅度 7 9 | 司法行为-开庭-开庭法院 8 10 | 灾害/意外-袭击-死亡人数 9 11 | 财经/交易-跌停-跌停股票 10 12 | 交往-感谢-时间 11 13 | 
组织关系-解雇-被解雇人员 12 14 | 产品行为-发布-时间 13 15 | 司法行为-入狱-入狱者 14 16 | 组织关系-辞/离职-离职者 15 17 | 司法行为-举报-举报发起方 16 18 | 人生-失联-失联者 17 19 | 人生-失联-时间 18 20 | 司法行为-罚款-时间 19 21 | 人生-结婚-时间 20 22 | 财经/交易-降价-时间 21 23 | 财经/交易-融资-融资金额 22 24 | 产品行为-获奖-颁奖机构 23 25 | 组织关系-加盟-时间 24 26 | 人生-失联-地点 25 27 | 组织关系-解散-时间 26 28 | 司法行为-起诉-被告 27 29 | 产品行为-召回-召回方 28 30 | 人生-产子/女-产子者 29 31 | 灾害/意外-坠机-地点 30 32 | 灾害/意外-爆炸-死亡人数 31 33 | 组织关系-解雇-时间 32 34 | 交往-探班-探班对象 33 35 | 组织关系-停职-所属组织 34 36 | 竞赛行为-禁赛-被禁赛人员 35 37 | 财经/交易-降价-降价幅度 36 38 | 竞赛行为-禁赛-禁赛机构 37 39 | 组织关系-解散-解散方 38 40 | 组织行为-罢工-罢工人数 39 41 | 财经/交易-出售/收购-出售价格 40 42 | 人生-求婚-时间 41 43 | 人生-出轨-出轨对象 42 44 | 司法行为-罚款-执法机构 43 45 | 竞赛行为-晋级-晋级方 44 46 | 财经/交易-融资-跟投方 45 47 | 财经/交易-融资-融资方 46 48 | 人生-产子/女-出生者 47 49 | 财经/交易-涨停-涨停股票 48 50 | 人生-死亡-死者年龄 49 51 | 组织行为-开幕-地点 50 52 | 财经/交易-降息-时间 51 53 | 财经/交易-涨价-涨价幅度 52 54 | 人生-死亡-地点 53 55 | 灾害/意外-地震-死亡人数 54 56 | 交往-会见-地点 55 57 | 组织关系-加盟-加盟者 56 58 | 灾害/意外-袭击-受伤人数 57 59 | 灾害/意外-地震-时间 58 60 | 交往-道歉-道歉对象 59 61 | 人生-离婚-离婚双方 60 62 | 组织关系-解约-时间 61 63 | 灾害/意外-地震-受伤人数 62 64 | 人生-出轨-时间 63 65 | 灾害/意外-车祸-受伤人数 64 66 | 组织关系-退出-时间 65 67 | 司法行为-起诉-原告 66 68 | 司法行为-拘捕-拘捕者 67 69 | 竞赛行为-晋级-晋级赛事 68 70 | 组织行为-闭幕-地点 69 71 | 人生-庆生-生日方年龄 70 72 | 灾害/意外-袭击-袭击者 71 73 | 交往-道歉-道歉者 72 74 | 组织行为-游行-游行组织 73 75 | 灾害/意外-车祸-时间 74 76 | 产品行为-下架-被下架方 75 77 | 人生-婚礼-时间 76 78 | 产品行为-下架-下架方 77 79 | 灾害/意外-爆炸-地点 78 80 | 灾害/意外-洪灾-受伤人数 79 81 | 司法行为-开庭-时间 80 82 | 组织行为-闭幕-活动名称 81 83 | 财经/交易-出售/收购-收购方 82 84 | 人生-庆生-时间 83 85 | 竞赛行为-退役-退役者 84 86 | 组织关系-解约-解约方 85 87 | 财经/交易-出售/收购-交易物 86 88 | 司法行为-约谈-时间 87 89 | 财经/交易-涨停-时间 88 90 | 竞赛行为-夺冠-时间 89 91 | 人生-分手-时间 90 92 | 组织关系-解雇-解雇方 91 93 | 产品行为-上映-上映影视 92 94 | 财经/交易-加息-加息机构 93 95 | 交往-感谢-被感谢人 94 96 | 灾害/意外-地震-震源深度 95 97 | 竞赛行为-退赛-时间 96 98 | 组织行为-游行-时间 97 99 | 竞赛行为-夺冠-冠军 98 100 | 财经/交易-上市-上市企业 99 101 | 组织关系-解约-被解约方 100 102 | 组织行为-闭幕-时间 101 103 | 交往-会见-会见主体 102 104 | 灾害/意外-车祸-死亡人数 103 105 | 人生-离婚-时间 104 106 | 产品行为-发布-发布方 105 107 | 产品行为-上映-上映方 106 108 | 人生-婚礼-参礼人员 107 109 | 财经/交易-上市-时间 108 110 | 财经/交易-加息-时间 109 111 | 人生-怀孕-时间 110 112 | 产品行为-召回-时间 111 113 | 人生-订婚-订婚主体 112 114 | 竞赛行为-退役-时间 113 115 | 交往-探班-时间 114 116 | 人生-婚礼-结婚双方 115 117 | 人生-结婚-结婚双方 116 118 | 灾害/意外-坍/垮塌-死亡人数 117 119 | 组织关系-退出-退出方 118 120 | 组织行为-游行-地点 119 121 | 组织行为-罢工-时间 120 122 | 财经/交易-涨价-涨价方 121 123 | 人生-庆生-庆祝方 122 124 | 灾害/意外-洪灾-死亡人数 123 125 | 灾害/意外-袭击-袭击对象 124 126 | 交往-点赞-时间 125 127 | 组织关系-辞/离职-原所属组织 126 128 | 产品行为-下架-下架产品 127 129 | 财经/交易-融资-时间 128 130 | 产品行为-上映-时间 129 131 | 竞赛行为-胜负-败者 130 132 | 司法行为-立案-立案对象 131 133 | 组织行为-游行-游行人数 132 134 | 财经/交易-出售/收购-出售方 133 135 | 灾害/意外-起火-受伤人数 134 136 | 人生-产子/女-时间 135 137 | 灾害/意外-爆炸-受伤人数 136 138 | 组织关系-停职-时间 137 139 | 司法行为-约谈-约谈发起方 138 140 | 交往-探班-探班主体 139 141 | 产品行为-下架-时间 140 142 | 竞赛行为-胜负-胜者 141 143 | 人生-订婚-时间 142 144 | 司法行为-举报-时间 143 145 | 灾害/意外-袭击-地点 144 146 | 财经/交易-融资-融资轮次 145 147 | 司法行为-入狱-刑期 146 148 | 竞赛行为-胜负-时间 147 149 | 司法行为-拘捕-时间 148 150 | 灾害/意外-洪灾-地点 149 151 | 财经/交易-跌停-时间 150 152 | 竞赛行为-夺冠-夺冠赛事 151 153 | 财经/交易-涨价-涨价物 152 154 | 组织关系-退出-原所属组织 153 155 | 交往-会见-会见对象 154 156 | 财经/交易-降息-降息机构 155 157 | 人生-婚礼-地点 156 158 | 司法行为-拘捕-被拘捕者 157 159 | 灾害/意外-袭击-时间 158 160 | 产品行为-获奖-奖项 159 161 | 灾害/意外-起火-时间 160 162 | 灾害/意外-坠机-死亡人数 161 163 | 交往-道歉-时间 162 164 | 产品行为-召回-召回内容 163 165 | 灾害/意外-地震-震中 164 166 | 司法行为-举报-举报对象 165 167 | 竞赛行为-退赛-退赛赛事 166 168 | 司法行为-入狱-时间 167 169 | 司法行为-起诉-时间 168 170 | 人生-分手-分手双方 169 171 | 组织行为-开幕-时间 170 172 | 交往-会见-时间 171 173 | 交往-点赞-点赞方 172 174 | 财经/交易-加息-加息幅度 173 175 | 竞赛行为-禁赛-禁赛时长 174 176 | 组织行为-罢工-所属组织 175 177 | 组织行为-开幕-活动名称 176 178 | 人生-出轨-出轨方 177 179 | 灾害/意外-坠机-时间 178 180 | 人生-怀孕-怀孕者 179 181 | 人生-死亡-死者 180 182 | 灾害/意外-起火-死亡人数 181 183 
| 司法行为-立案-时间 182 184 | 灾害/意外-坍/垮塌-受伤人数 183 185 | 组织关系-停职-停职人员 184 186 | 灾害/意外-爆炸-时间 185 187 | 组织关系-加盟-所加盟组织 186 188 | 司法行为-立案-立案机构 187 189 | 组织关系-裁员-裁员人数 188 190 | 人生-庆生-生日方 189 191 | 人生-求婚-求婚对象 190 192 | 人生-求婚-求婚者 191 193 | 产品行为-获奖-获奖人 192 194 | 司法行为-罚款-罚款金额 193 195 | 司法行为-罚款-罚款对象 194 196 | 灾害/意外-洪灾-时间 195 197 | 灾害/意外-坠机-受伤人数 196 198 | 组织关系-裁员-裁员方 197 199 | 财经/交易-涨价-时间 198 200 | 竞赛行为-晋级-时间 199 201 | 组织关系-裁员-时间 200 202 | 交往-感谢-致谢人 201 203 | 交往-点赞-点赞对象 202 204 | 财经/交易-融资-领投方 203 205 | 财经/交易-降价-降价方 204 206 | 灾害/意外-车祸-地点 205 207 | 财经/交易-上市-地点 206 208 | 财经/交易-上市-融资金额 207 209 | 司法行为-开庭-开庭案件 208 210 | 组织关系-辞/离职-时间 209 211 | 产品行为-发布-发布产品 210 212 | 竞赛行为-禁赛-时间 211 213 | 灾害/意外-起火-地点 212 214 | 灾害/意外-坍/垮塌-时间 213 215 | 竞赛行为-胜负-赛事名称 214 216 | 灾害/意外-地震-震级 215 217 | 竞赛行为-退赛-退赛方 216 218 | -------------------------------------------------------------------------------- /data_processing/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'lenovo' 2 | -------------------------------------------------------------------------------- /data_processing/basic_prepare_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import codecs 3 | import re 4 | import numpy as np 5 | import tensorflow as tf 6 | from sklearn.model_selection import StratifiedKFold,train_test_split 7 | from data_processing.data_utils import gen_char_embedding 8 | from data_processing.tokenize import CustomTokenizer,WordTokenizer 9 | # from configs.base_config import config 10 | 11 | 12 | class BaseDataPreparing(object): 13 | def __init__(self,vocab_file,slot_file,config,pretrained_embedding_file=None,word_embedding_file=None,word_seq_embedding_file=None,load_w2v_embedding=True,load_word_embedding=True,gen_new_data=False,is_inference=False): 14 | self.gen_new_data = gen_new_data 15 | self.train_data_file = os.path.join(config.get("data_dir"),config.get("data_file_name")) 16 | self.dev_data_file = os.path.join(config.get("data_dir"),config.get("orig_dev")) 17 | self.test_data_file = os.path.join(config.get("data_dir"), config.get("orig_test")) 18 | self.train_valid_split_data_path = config.get("train_valid_data_dir") 19 | self.tokenizer = CustomTokenizer(vocab_file,slot_file) 20 | self.word_tokenizer = WordTokenizer() 21 | # self.train_X_path,self.valid_X_path,self.train_Y_path,self.valid_Y_path,self.test_X_path,self.test_Y_path=None,None,None,None,None,None 22 | # self.train_word_path,self.valid_word_path,self.test_word_path = None,None,None 23 | self.slot_list = [value for key,value in self.tokenizer.slot2id.items()] 24 | self.slot_label_size = len(self.tokenizer.slot2id) 25 | if load_w2v_embedding: 26 | self.word_embedding = gen_char_embedding(pretrained_embedding_file,self.tokenizer.vocab,output_file=word_embedding_file) 27 | self.init_final_data_path(config,load_word_embedding) 28 | self.train_samples_nums = 0 29 | self.eval_samples_nums = 0 30 | self.is_inference = is_inference 31 | print("preprocessing data....") 32 | if not is_inference: 33 | if load_word_embedding: 34 | self.load_word_char_from_orig_data(gen_new_data) 35 | else: 36 | self.gen_train_dev_from_orig_data(gen_new_data) 37 | else: 38 | self.trans_test_data() 39 | 40 | if load_word_embedding: 41 | self.word_seq_embedding = gen_char_embedding(pretrained_embedding_file,self.word_tokenizer.vocab,output_file=word_seq_embedding_file) 42 | 43 | def init_final_data_path(self,config,load_word_embedding): 44 | root_path = 
config.get("data_dir")+"/"+config.get("train_valid_data_dir") 45 | if not os.path.exists(root_path): 46 | os.mkdir(root_path) 47 | self.train_X_path = os.path.join(config.get("data_dir"), config.get("train_valid_data_dir"), config.get("train_data_text_name")) 48 | self.valid_X_path = os.path.join(config.get("data_dir"),config.get("train_valid_data_dir"), config.get("valid_data_text_name")) 49 | self.train_Y_path = os.path.join(config.get("data_dir"),config.get("train_valid_data_dir"), config.get("train_data_tag_name")) 50 | self.valid_Y_path = os.path.join(config.get("data_dir"),config.get("train_valid_data_dir"), config.get("valid_data_tag_name")) 51 | self.test_X_path = os.path.join(config.get("data_dir"), config.get("train_valid_data_dir"), config.get("test_data_text_name")) 52 | self.test_Y_path = os.path.join(config.get("data_dir"), config.get("train_valid_data_dir"), config.get("test_data_tag_name")) 53 | if load_word_embedding: 54 | self.train_word_path = os.path.join(config.get("data_dir"), config.get("train_valid_data_dir"), config.get("train_data_text_word_name")) 55 | self.valid_word_path = os.path.join(config.get("data_dir"), config.get("train_valid_data_dir"), config.get("valid_data_text_word_name")) 56 | self.test_word_path = os.path.join(config.get("data_dir"), config.get("train_valid_data_dir"), config.get("test_data_text_word_name")) 57 | 58 | def tranform_singlg_data_example(self,text): 59 | word_list = self.tokenizer.tokenize(text) 60 | word_id_list = self.tokenizer.convert_tokens_to_ids(word_list) 61 | return word_id_list 62 | 63 | 64 | def translate_id_2_slot(self,text,label_list): 65 | entity_list = [] 66 | text_list = [w for w in text] 67 | tmp_entity = "" 68 | tmp_entity_type = "" 69 | for char,label in zip(text_list,label_list): 70 | label_string = self.tokenizer.id2slot.get(label) 71 | if label_string == "O": 72 | if tmp_entity != "": 73 | entity_list.append({tmp_entity:tmp_entity_type}) 74 | tmp_entity = "" 75 | tmp_entity_type = "" 76 | elif label_string == "PAD": 77 | break 78 | else: 79 | tmp_entity += char 80 | tmp_entity_type = re.split("-",label_string)[-1] 81 | return entity_list 82 | 83 | def trans_test_data(self): 84 | print("--------------") 85 | test_data_X, test_data_Y = self.trans_orig_data_to_training_data(self.test_data_file) 86 | test_data_X_word = self.seg_word_for_data(self.test_data_file) 87 | np.save(self.test_X_path, test_data_X) 88 | np.save(self.test_Y_path, test_data_Y) 89 | np.save(self.test_word_path, test_data_X_word) 90 | 91 | def trans_orig_data_to_training_data(self,datas_file): 92 | data_X = [] 93 | data_Y = [] 94 | with codecs.open(datas_file,'r','utf-8') as fr: 95 | for index,line in enumerate(fr): 96 | line = line.strip("\n") 97 | if index % 2 == 0: 98 | data_X.append(self.tranform_singlg_data_example(line)) 99 | else: 100 | slot_list = self.tokenizer.tokenize(line) 101 | slot_list = [slots.upper() for slots in slot_list] 102 | slot_id_list = self.tokenizer.convert_slot_to_ids(slot_list) 103 | data_Y.append(slot_id_list) 104 | return data_X,data_Y 105 | 106 | 107 | # def data_preprocessing(self,random_seed=2): 108 | # """ 109 | # 数据预处理,包括原始数据转化和训练集验证机的拆分 110 | # :return: 111 | # """ 112 | # data_X,data_Y = self.trans_orig_data_to_training_data(self.train_data_file) 113 | # 114 | # X_train, X_valid, y_train, y_valid = train_test_split(data_X,data_Y,test_size=0.1, random_state=random_seed) 115 | # self.train_samples_nums = len(X_train) 116 | # self.eval_samples_nums = len(X_valid) 117 | # np.save(self.train_X_path,X_train) 118 | # 
np.save(self.valid_X_path, X_valid) 119 | # np.save(self.train_Y_path, y_train) 120 | # np.save(self.valid_Y_path, y_valid) 121 | 122 | def gen_train_dev_from_orig_data(self,gen_new): 123 | if gen_new: 124 | train_data_X,train_data_Y = self.trans_orig_data_to_training_data(self.train_data_file) 125 | dev_data_X,dev_data_Y = self.trans_orig_data_to_training_data(self.dev_data_file) 126 | test_data_X,test_data_Y = self.trans_orig_data_to_training_data(self.test_data_file) 127 | # dev_data_X = np.concatenate((dev_data_X,test_data_X),axis=0) 128 | # dev_data_Y = np.concatenate((dev_data_Y,test_data_Y),axis=0) 129 | self.train_samples_nums = len(train_data_X) 130 | self.eval_samples_nums = len(dev_data_X) 131 | np.save(self.train_X_path, train_data_X) 132 | np.save(self.valid_X_path, dev_data_X) 133 | np.save(self.train_Y_path,train_data_Y) 134 | np.save(self.valid_Y_path,dev_data_Y) 135 | 136 | np.save(self.test_X_path,test_data_X) 137 | np.save(self.test_Y_path,test_data_Y) 138 | else: 139 | train_data_X = np.load(self.train_X_path) 140 | dev_data_X = np.load(self.valid_X_path) 141 | self.train_samples_nums = len(train_data_X) 142 | self.eval_samples_nums = len(dev_data_X) 143 | 144 | def gen_one_sample_words_on_chars(self,text): 145 | print(text) 146 | word_ids_split, word_str_split = self.word_tokenizer.seg(text) 147 | print(word_ids_split) 148 | print(word_str_split) 149 | word_ids_seq_char_list = [] 150 | for word,word_ids in zip(word_str_split,word_ids_split): 151 | word_len = len(word) 152 | word_ids_seq_char_list.extend([word_ids]*word_len) 153 | print(word_ids_seq_char_list) 154 | assert(len(word_ids_seq_char_list)==len(text.split(" "))) 155 | return word_ids_seq_char_list 156 | 157 | def seg_word_for_data(self,data_file): 158 | all_word_ids_char_list = [] 159 | with codecs.open(data_file,'r','utf-8') as fr: 160 | for index,line in enumerate(fr): 161 | line = line.strip("\n") 162 | if index % 2 == 0: 163 | all_word_ids_char_list.append(self.gen_one_sample_words_on_chars(line)) 164 | return all_word_ids_char_list 165 | 166 | def load_word_char_from_orig_data(self,gen_new): 167 | if gen_new: 168 | # train_data_X = np.load(self.train_X_path) 169 | # dev_data_X = np.load(self.valid_X_path) 170 | train_data_X, train_data_Y = self.trans_orig_data_to_training_data(self.train_data_file) 171 | train_data_X_word = self.seg_word_for_data(self.train_data_file) 172 | dev_data_X, dev_data_Y = self.trans_orig_data_to_training_data(self.dev_data_file) 173 | dev_data_X_word = self.seg_word_for_data(self.dev_data_file) 174 | test_data_X, test_data_Y = self.trans_orig_data_to_training_data(self.test_data_file) 175 | test_data_X_word = self.seg_word_for_data(self.test_data_file) 176 | # dev_data_X = np.concatenate((dev_data_X, test_data_X), axis=0) 177 | # dev_data_Y = np.concatenate((dev_data_Y, test_data_Y), axis=0) 178 | # dev_data_word = np.concatenate((dev_data_X_word,test_data_X_word),axis=0) 179 | self.train_samples_nums = len(train_data_X) 180 | self.eval_samples_nums = len(dev_data_X) 181 | np.save(self.train_X_path, train_data_X) 182 | np.save(self.valid_X_path, dev_data_X) 183 | np.save(self.train_Y_path, train_data_Y) 184 | np.save(self.valid_Y_path, dev_data_Y) 185 | 186 | np.save(self.test_X_path, test_data_X) 187 | np.save(self.test_Y_path, test_data_Y) 188 | np.save(self.train_word_path,train_data_X_word) 189 | np.save(self.valid_word_path,dev_data_X_word) 190 | np.save(self.test_word_path,test_data_X_word) 191 | else: 192 | train_data_X = np.load(self.train_X_path) 193 | dev_data_X = 
np.load(self.valid_X_path) 194 | self.train_samples_nums = len(train_data_X) 195 | self.eval_samples_nums = len(dev_data_X) 196 | -------------------------------------------------------------------------------- /data_processing/bert_mrc_prepare_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import codecs 4 | import numpy as np 5 | from data_processing.basic_prepare_data import BaseDataPreparing 6 | from data_processing.mrc_query_map import ner_query_map 7 | 8 | 9 | class bertMRCPrepareData(BaseDataPreparing): 10 | def __init__(self,vocab_file,slot_file,config,bert_file,max_length,gen_new_data=False,is_inference=False,direct_cut=False): 11 | self.bert_file = bert_file 12 | self.max_length = max_length 13 | self.query_map_dict = self.gen_query_map_dict() 14 | self.direct = direct_cut 15 | super(bertMRCPrepareData,self).__init__(vocab_file,slot_file,config,pretrained_embedding_file=None,word_embedding_file=None,load_w2v_embedding=False,load_word_embedding=False,gen_new_data=gen_new_data,is_inference=is_inference) 16 | 17 | def init_final_data_path(self,config,load_word_embedding=False): 18 | root_path = config.get("data_dir") + "/" + config.get("train_valid_data_dir") 19 | if not os.path.exists(root_path): 20 | os.mkdir(root_path) 21 | self.train_X_path = os.path.join(config.get("data_dir"), config.get("train_valid_data_dir"), config.get("train_data_text_name")) 22 | self.valid_X_path = os.path.join(config.get("data_dir"),config.get("train_valid_data_dir"), config.get("valid_data_text_name")) 23 | self.train_start_Y_path = os.path.join(config.get("data_dir"),config.get("train_valid_data_dir"), config.get("train_data_start_tag_name")) 24 | self.train_end_Y_path = os.path.join(config.get("data_dir"), config.get("train_valid_data_dir"), 25 | config.get("train_data_end_tag_name")) 26 | self.valid_start_Y_path = os.path.join(config.get("data_dir"),config.get("train_valid_data_dir"), config.get("valid_data_start_tag_name")) 27 | self.valid_end_Y_path = os.path.join(config.get("data_dir"),config.get("train_valid_data_dir"), config.get("valid_data_end_tag_name")) 28 | self.test_X_path = os.path.join(config.get("data_dir"), config.get("train_valid_data_dir"), config.get("test_data_text_name")) 29 | self.train_token_type_ids_path = os.path.join(config.get("data_dir"), config.get("train_valid_data_dir"), config.get("train_data_token_type_ids_name")) 30 | self.valid_token_type_ids_path = os.path.join(config.get("data_dir"), config.get("train_valid_data_dir"), config.get("valid_data_token_type_ids_name")) 31 | self.test_token_type_ids_path = os.path.join(config.get("data_dir"), config.get("train_valid_data_dir"), config.get("test_data_token_type_ids_name")) 32 | 33 | self.train_query_len_path = os.path.join(config.get("data_dir"), config.get("train_valid_data_dir"), config.get("train_data_query_len_name")) 34 | self.valid_query_len_path = os.path.join(config.get("data_dir"), config.get("train_valid_data_dir"), config.get("valid_data_query_len_name")) 35 | self.test_query_len_path = os.path.join(config.get("data_dir"), config.get("train_valid_data_dir"), config.get("test_data_query_len_name")) 36 | self.test_query_class_path = os.path.join(config.get("data_dir"), config.get("train_valid_data_dir"), config.get("test_data_query_class")) 37 | self.src_test_sample_id_path = os.path.join(config.get("data_dir"), config.get("train_valid_data_dir"), config.get("test_data_src_sample_id")) 38 | 39 | # self.test_Y_path = 
os.path.join(config.get("data_dir"), config.get("train_valid_data_dir"), config.get("test_data_tag_name")) 40 | 41 | def split_one_sentence_based_on_length(self, texts, text_allow_length, labels=None): 42 | # 通用的截断 43 | data_list = [] 44 | data_label_list = [] 45 | if len(texts) > text_allow_length: 46 | left_length = 0 47 | while left_length+text_allow_length < len(texts): 48 | cur_cut_index = left_length + text_allow_length 49 | if labels != None: 50 | last_label_tmp = labels[cur_cut_index-1] 51 | if last_label_tmp.upper() != "O": 52 | while labels[cur_cut_index - 1].upper() != "O": 53 | cur_cut_index -= 1 54 | data_label_list.append(labels[left_length:cur_cut_index]) 55 | data_list.append(texts[left_length:cur_cut_index]) 56 | left_length = cur_cut_index 57 | 58 | # 别忘了最后还有余下的一小段没处理 59 | if labels != None: 60 | data_label_list.append(labels[left_length:]) 61 | data_list.append(texts[left_length:]) 62 | else: 63 | data_list.append(texts) 64 | data_label_list.append(labels) 65 | return data_list,data_label_list 66 | 67 | 68 | def split_one_sentence_based_on_entity_direct(self,texts,text_allow_length,labels=None): 69 | # 对于超过长度的直接截断 70 | data_list = [] 71 | data_label_list = [] 72 | if len(texts) > text_allow_length: 73 | pick_texts = texts[0:text_allow_length] 74 | data_list.append(pick_texts) 75 | if labels != None: 76 | data_label_list.append(labels[0:text_allow_length]) 77 | return data_list, data_label_list 78 | else: 79 | data_list.append(texts) 80 | data_label_list.append(labels) 81 | return data_list, data_label_list 82 | 83 | def token_data_text(self,text): 84 | word_list = [] 85 | if self.is_inference: 86 | word_list.extend([w for w in text if w !=" "]) 87 | else: 88 | word_list.extend(self.tokenizer.tokenize(text)) 89 | return word_list 90 | 91 | def gen_query_map_dict(self): 92 | slot_query_tag_dict = {} 93 | for slot_tag in ner_query_map.get("tags"): 94 | slot_query = ner_query_map.get("natural_query").get(slot_tag) 95 | slot_query_tokenize = [w for w in slot_query] 96 | slot_query_tokenize.insert(0, "[CLS]") 97 | slot_query_tokenize.append("[SEP]") 98 | slot_query_tag_dict.update({slot_tag:slot_query_tokenize}) 99 | return slot_query_tag_dict 100 | 101 | def find_tag_start_end_index(self,tag,label_list): 102 | start_index_tag = [0] * len(label_list) 103 | end_index_tag = [0] * len(label_list) 104 | start_tag = "B-"+tag 105 | end_tag = "I-"+tag 106 | for i in range(len(start_index_tag)): 107 | if label_list[i].upper() == start_tag: 108 | # begin 109 | start_index_tag[i] = 1 110 | elif label_list[i].upper() == end_tag: 111 | if i == len(start_index_tag)-1: 112 | # last tag 113 | end_index_tag[i] = 1 114 | else: 115 | if label_list[i+1].upper() != end_tag: 116 | end_index_tag[i] = 1 117 | return start_index_tag,end_index_tag 118 | 119 | 120 | def trans_orig_data_to_training_data(self,datas_file): 121 | data_X = [] 122 | data_start_Y = [] 123 | data_end_Y = [] 124 | token_type_ids_list = [] 125 | query_len_list = [] 126 | with codecs.open(datas_file,'r','utf-8') as fr: 127 | tmp_text_split = None 128 | for index,line in enumerate(fr): 129 | line = line.strip("\n") 130 | if index % 2 == 0: 131 | tmp_text_split = self.token_data_text(line) 132 | else: 133 | slot_label_list = self.tokenizer.tokenize(line) 134 | for slot_tag in self.query_map_dict: 135 | slot_query = self.query_map_dict.get(slot_tag) 136 | slot_query = [w for w in slot_query] 137 | query_len = len(slot_query) 138 | text_allow_max_len = self.max_length - query_len 139 | if not self.direct: 140 | 
gen_tmp_X_texts,gen_tmp_y_labels = self.split_one_sentence_based_on_length( 141 | tmp_text_split,text_allow_max_len,slot_label_list) 142 | else: 143 | gen_tmp_X_texts, gen_tmp_y_labels = self.split_one_sentence_based_on_entity_direct( 144 | tmp_text_split,text_allow_max_len,slot_label_list) 145 | for tmp_X,tmp_Y in zip(gen_tmp_X_texts,gen_tmp_y_labels): 146 | x_merge = slot_query + tmp_X 147 | token_type_ids = [0]*len(slot_query) + [1]*(len(tmp_X)) 148 | x_merge = self.tokenizer.convert_tokens_to_ids(x_merge) 149 | data_X.append(x_merge) 150 | start_index_tag,end_index_tag = self.find_tag_start_end_index(slot_tag,tmp_Y) 151 | # print(len(x_merge)) 152 | # print(len(start_index_tag)) 153 | start_index_tag = [0]*len(slot_query) + start_index_tag 154 | # print(len(start_index_tag)) 155 | end_index_tag = [0]*len(slot_query) + end_index_tag 156 | data_start_Y.append(start_index_tag) 157 | data_end_Y.append(end_index_tag) 158 | token_type_ids_list.append(token_type_ids) 159 | query_len_list.append(query_len) 160 | return data_X,data_start_Y,data_end_Y,token_type_ids_list,query_len_list 161 | 162 | def gen_train_dev_from_orig_data(self,gen_new): 163 | if gen_new: 164 | train_data_X,train_data_start_Y,train_data_end_Y,train_token_type_ids_list,train_query_len_list = self.trans_orig_data_to_training_data(self.train_data_file) 165 | dev_data_X,dev_data_start_Y,dev_data_end_Y,dev_token_type_ids_list,dev_query_len_list = self.trans_orig_data_to_training_data(self.dev_data_file) 166 | # test_data_X,test_data_start_Y,test_data_end_Y,test_token_type_ids_list,test_query_len_list = self.trans_orig_data_to_training_data(self.test_data_file) 167 | # dev_data_X = np.concatenate((dev_data_X,test_data_X),axis=0) 168 | # dev_data_Y = np.concatenate((dev_data_Y,test_data_Y),axis=0) 169 | self.train_samples_nums = len(train_data_X) 170 | self.eval_samples_nums = len(dev_data_X) 171 | np.save(self.train_X_path, train_data_X) 172 | np.save(self.valid_X_path, dev_data_X) 173 | np.save(self.train_start_Y_path,train_data_start_Y) 174 | np.save(self.train_end_Y_path, train_data_end_Y) 175 | np.save(self.valid_start_Y_path, dev_data_start_Y) 176 | np.save(self.train_start_Y_path, train_data_start_Y) 177 | np.save(self.valid_end_Y_path,dev_data_end_Y) 178 | np.save(self.train_token_type_ids_path, train_token_type_ids_list) 179 | np.save(self.valid_token_type_ids_path, dev_token_type_ids_list) 180 | np.save(self.train_query_len_path,train_query_len_list) 181 | np.save(self.valid_query_len_path,dev_query_len_list) 182 | # np.save(self.test_X_path,test_data_X) 183 | # np.save(self.test_token_type_ids_path, test_token_type_ids_list) 184 | # np.save(self.test_query_len_path, test_query_len_list) 185 | # np.save(self.test_Y_path,test_data_Y) 186 | else: 187 | train_data_X = np.load(self.train_X_path) 188 | dev_data_X = np.load(self.valid_X_path) 189 | self.train_samples_nums = len(train_data_X) 190 | self.eval_samples_nums = len(dev_data_X) 191 | 192 | def trans_test_data(self): 193 | self.gen_test_data_from_orig_data(self.test_data_file) 194 | 195 | def gen_test_data_from_orig_data(self,datas_file): 196 | # 相对于训练集来说,测试集构造数据要更复杂一点 197 | # 1、query要标明,2、因长度问题分割的句子最后要拼起来,因此同一个原样本的要标明 3、最后要根据query对应的实体类别根据start end 关系拼起来 198 | data_X = [] 199 | token_type_ids_list = [] 200 | query_len_list = [] 201 | query_class_list = [] 202 | src_test_sample_id = [] 203 | with codecs.open(datas_file, 'r', 'utf-8') as fr: 204 | tmp_text_split = None 205 | for index, line in enumerate(fr): 206 | line = line.strip("\n") 207 | if index % 2 == 0: 208 
| tmp_text_split = self.token_data_text(line) 209 | cur_sample_id = int(index / 2) 210 | for slot_tag in self.query_map_dict: 211 | slot_query = self.query_map_dict.get(slot_tag) 212 | slot_query = [w for w in slot_query] 213 | query_len = len(slot_query) 214 | text_allow_max_len = self.max_length - query_len 215 | gen_tmp_X_texts, _ = self.split_one_sentence_based_on_length( 216 | tmp_text_split, text_allow_max_len) 217 | for tmp_X in gen_tmp_X_texts: 218 | x_merge = slot_query + tmp_X 219 | token_type_ids = [0] * len(slot_query) + [1] * (len(tmp_X)) 220 | x_merge = self.tokenizer.convert_tokens_to_ids(x_merge) 221 | data_X.append(x_merge) 222 | src_test_sample_id.append(cur_sample_id) 223 | query_class_list.append(ner_query_map.get("tags").index(slot_tag)) 224 | token_type_ids_list.append(token_type_ids) 225 | query_len_list.append(query_len) 226 | np.save(self.test_X_path,data_X) 227 | np.save(self.test_token_type_ids_path, token_type_ids_list) 228 | np.save(self.test_query_len_path, query_len_list) 229 | np.save(self.test_query_class_path,query_class_list) 230 | np.save(self.src_test_sample_id_path,src_test_sample_id) 231 | -------------------------------------------------------------------------------- /data_processing/bert_prepare_data.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | from data_processing.basic_prepare_data import BaseDataPreparing 3 | 4 | class bertPrepareData(BaseDataPreparing): 5 | def __init__(self,vocab_file,slot_file,config,bert_file,max_length,gen_new_data=False,is_inference=False,label_less=True): 6 | self.bert_file = bert_file 7 | self.max_length = max_length 8 | self.label_less = label_less 9 | super(bertPrepareData,self).__init__(vocab_file,slot_file,config,pretrained_embedding_file=None,word_embedding_file=None,load_w2v_embedding=False,load_word_embedding=False,gen_new_data=gen_new_data,is_inference=is_inference) 10 | 11 | 12 | def tranform_singlg_data_example(self,text): 13 | # print(text) 14 | word_list = [] 15 | # if not self.label_less: 16 | # word_list.append("[CLS]") 17 | word_list.append("[CLS]") 18 | 19 | if self.is_inference: 20 | word_list.extend([w for w in text if w !=" "]) 21 | else: 22 | word_list.extend(self.tokenizer.tokenize(text)) 23 | if len(word_list)>=self.max_length: 24 | word_list = word_list[0:self.max_length-1] 25 | # if not self.label_less: 26 | # word_list.append("[SEP]") 27 | word_list.append("[SEP]") 28 | word_id_list = self.tokenizer.convert_tokens_to_ids(word_list) 29 | # print(len(word_id_list)) 30 | return word_id_list 31 | 32 | def trans_orig_data_to_training_data(self,datas_file): 33 | data_X = [] 34 | data_Y = [] 35 | with codecs.open(datas_file,'r','utf-8') as fr: 36 | for index,line in enumerate(fr): 37 | line = line.strip("\n") 38 | if index % 2 == 0: 39 | data_X.append(self.tranform_singlg_data_example(line)) 40 | else: 41 | # slot_list = ["[CLS]"] 42 | slot_list = [] 43 | # if not self.label_less: 44 | # slot_list.append("[CLS]") 45 | 46 | slot_list.append("O") 47 | 48 | slot_list.extend(self.tokenizer.tokenize(line)) 49 | if len(slot_list) >= self.max_length: 50 | slot_list = slot_list[0:self.max_length - 1] 51 | 52 | # if not self.label_less: 53 | # slot_list.append("[SEP]") 54 | 55 | slot_list.append("O") 56 | slot_list = [slots.upper() for slots in slot_list] 57 | slot_id_list = self.tokenizer.convert_slot_to_ids(slot_list) 58 | data_Y.append(slot_id_list) 59 | return data_X,data_Y 60 | 
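# A minimal, self-contained sketch (not part of the original file) of the data layout that
# trans_orig_data_to_training_data above assumes: even-numbered lines carry the already
# whitespace-tokenized text and each following odd-numbered line carries the matching slot labels.
# read_text_label_pairs is only an illustrative helper name, not an API of this repo.
import codecs

def read_text_label_pairs(data_file):
    """Yield (token_list, label_list) pairs from an alternating text/label file."""
    with codecs.open(data_file, 'r', 'utf-8') as fr:
        tokens = None
        for index, line in enumerate(fr):
            line = line.strip("\n")
            if index % 2 == 0:
                tokens = line.split()
            else:
                # labels are upper-cased before lookup, mirroring the class above
                yield tokens, [label.upper() for label in line.split()]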
-------------------------------------------------------------------------------- /data_processing/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import codecs 3 | import gensim 4 | import requests 5 | import numpy as np 6 | import json 7 | import tensorflow as tf 8 | SEGURL="" 9 | 10 | def seg_from_api(data): 11 | try: 12 | datas = {"text": data} 13 | headers = {'Content-Type': 'application/json'} 14 | res = requests.post(SEGURL, data=json.dumps(datas), headers=headers) 15 | text = res.text 16 | text_dict = json.loads(text) 17 | return text_dict 18 | except: 19 | print("dfdfdf") 20 | 21 | def seg(text): 22 | words = seg_from_api(text) 23 | word_list = [word.get("word") for word in words] 24 | return word_list 25 | 26 | def read_slots(slot_file_path=None,slot_source_type="file"): 27 | """ 28 | 根据不同的槽位模板文件,生成槽位的label 29 | :param slot_file_path: 30 | :param slot_source_type: 31 | :return: 32 | """ 33 | slot2id_dict = {} 34 | id2slot_dict = {} 35 | if slot_source_type == "file": 36 | with codecs.open(slot_file_path,'r','utf-8') as fr: 37 | for i,line in enumerate(fr): 38 | line=line.strip("\n") 39 | line = line.strip("\r") 40 | slot2id_dict[line] = i 41 | id2slot_dict[i] = line 42 | return slot2id_dict,id2slot_dict 43 | 44 | def gen_char_embedding(pretrained_char_embedding_file=None,gram_dict=None,embedding_dim=300,output_file=None): 45 | if not os.path.exists(output_file): 46 | word2vec = gensim.models.KeyedVectors.load_word2vec_format(pretrained_char_embedding_file, binary=False,unicode_errors='ignore') 47 | text_wordvec = np.zeros((len(gram_dict), embedding_dim)) 48 | print("gen_word2vec.....") 49 | count = 0 50 | for word, word_index in gram_dict.items(): 51 | count += 1 52 | if count % 500 == 0: 53 | print("count:{}.......".format(count)) 54 | try: 55 | word_vec = word2vec[word] 56 | text_wordvec[word_index] = word_vec 57 | except: 58 | print("exception:{}".format(word)) 59 | continue 60 | 61 | np.save(output_file,text_wordvec) 62 | return text_wordvec 63 | else: 64 | text_wordvec = np.load(output_file,allow_pickle=True) 65 | return text_wordvec 66 | 67 | 68 | 69 | def data_generator(input_X,input_X_word,label_Y): 70 | # input_X = np.load(text_path,allow_pickle=True) 71 | # label_Y = np.load(label_path,allow_pickle=True) 72 | for index in range(len(input_X)): 73 | text_x = input_X[index] 74 | text_x_word = input_X_word[index] 75 | label = label_Y[index] 76 | yield (text_x,text_x_word, len(text_x)),label 77 | 78 | def input_fn(input_X,input_X_word, label_Y, is_training, args): 79 | _shapes = (([None],[None],()), [None]) 80 | _types = ((tf.int32,tf.int32, tf.int32), tf.int32) 81 | _pads = ((0,0,0), 0) 82 | ds = tf.data.Dataset.from_generator( 83 | lambda: data_generator(input_X,input_X_word, label_Y), 84 | output_shapes=_shapes, 85 | output_types=_types, ) 86 | if is_training: 87 | # input_X = np.load(data_loader.train_X_path,allow_pickle=True) 88 | # label_Y = np.load(data_loader.train_Y_path,allow_pickle=True) 89 | ds = ds.shuffle(args.shuffle_buffer).repeat() 90 | 91 | ds = ds.padded_batch(args.train_batch_size, _shapes, _pads) 92 | ds = ds.prefetch(args.pre_buffer_size) 93 | 94 | return ds 95 | 96 | def data_generator_bert(input_X,label_Y): 97 | # input_X = np.load(text_path,allow_pickle=True) 98 | # label_Y = np.load(label_path,allow_pickle=True) 99 | for index in range(len(input_X)): 100 | text_x = input_X[index] 101 | label = label_Y[index] 102 | yield (text_x, len(text_x)),label 103 | 104 | def input_bert_fn(input_X, 
label_Y, is_training, args): 105 | _shapes = (([None], ()), [None]) 106 | _types = ((tf.int32, tf.int32), tf.int32) 107 | _pads = ((0,0), 0) 108 | ds = tf.data.Dataset.from_generator( 109 | lambda: data_generator_bert(input_X, label_Y), 110 | output_shapes=_shapes, 111 | output_types=_types, ) 112 | if is_training: 113 | # input_X = np.load(data_loader.train_X_path,allow_pickle=True) 114 | # label_Y = np.load(data_loader.train_Y_path,allow_pickle=True) 115 | ds = ds.shuffle(args.shuffle_buffer).repeat() 116 | 117 | ds = ds.padded_batch(args.train_batch_size, _shapes, _pads) 118 | ds = ds.prefetch(args.pre_buffer_size) 119 | 120 | return ds 121 | 122 | def data_generator_bert_mrc(input_Xs,start_Ys,end_Ys,token_type_ids,query_lens): 123 | for index in range(len(input_Xs)): 124 | input_x = input_Xs[index] 125 | start_y = start_Ys[index] 126 | end_y = end_Ys[index] 127 | token_type_id = token_type_ids[index] 128 | query_len = query_lens[index] 129 | yield (input_x,len(input_x),query_len,token_type_id),(start_y,end_y) 130 | 131 | def input_bert_mrc_fn(input_Xs,start_Ys,end_Ys,token_type_ids,query_lens,is_training,args): 132 | _shapes = (([None], (),(),[None]), ([None],[None])) 133 | _types = ((tf.int32,tf.int32,tf.int32,tf.int32),(tf.int32,tf.int32)) 134 | _pads = ((0,0,0,0),(0,0)) 135 | ds = tf.data.Dataset.from_generator( 136 | lambda: data_generator_bert_mrc(input_Xs, start_Ys,end_Ys,token_type_ids,query_lens), 137 | output_shapes=_shapes, 138 | output_types=_types, ) 139 | if is_training: 140 | ds = ds.shuffle(args.shuffle_buffer).repeat() 141 | 142 | ds = ds.padded_batch(args.train_batch_size, _shapes, _pads) 143 | ds = ds.prefetch(args.pre_buffer_size) 144 | 145 | return ds 146 | -------------------------------------------------------------------------------- /data_processing/mrc_query_map.py: -------------------------------------------------------------------------------- 1 | ner_query_map = { 2 | "tags":["ORG","PER","LOC"], 3 | "natural_query":{ 4 | "ORG":"找出公司,商业机构,社会组织等组织机构", 5 | "LOC":"找出国家,城市,山川等抽象或具体的地点", 6 | "PER":"找出真实和虚构的人名" 7 | } 8 | } -------------------------------------------------------------------------------- /data_processing/tokenize.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | import collections 18 | import unicodedata 19 | import six 20 | import codecs 21 | import tensorflow as tf 22 | import re 23 | from data_processing.data_utils import read_slots 24 | from data_processing.data_utils import seg 25 | 26 | def convert_to_unicode(text): 27 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 28 | if six.PY3: 29 | if isinstance(text, str): 30 | return text 31 | elif isinstance(text, bytes): 32 | return text.decode("utf-8", "ignore") 33 | else: 34 | raise ValueError("Unsupported string type: %s" % (type(text))) 35 | elif six.PY2: 36 | if isinstance(text, str): 37 | return text.decode("utf-8", "ignore") 38 | elif isinstance(text, unicode): 39 | return text 40 | else: 41 | raise ValueError("Unsupported string type: %s" % (type(text))) 42 | else: 43 | raise ValueError("Not running on Python2 or Python 3?") 44 | 45 | 46 | def printable_text(text): 47 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 48 | 49 | # These functions want `str` for both Python2 and Python3, but in one case 50 | # it's a Unicode string and in the other it's a byte string. 51 | if six.PY3: 52 | if isinstance(text, str): 53 | return text 54 | elif isinstance(text, bytes): 55 | return text.decode("utf-8", "ignore") 56 | else: 57 | raise ValueError("Unsupported string type: %s" % (type(text))) 58 | elif six.PY2: 59 | if isinstance(text, str): 60 | return text 61 | elif isinstance(text, unicode): 62 | return text.encode("utf-8") 63 | else: 64 | raise ValueError("Unsupported string type: %s" % (type(text))) 65 | else: 66 | raise ValueError("Not running on Python2 or Python 3?") 67 | 68 | 69 | def load_vocab(vocab_file): 70 | """Loads a vocabulary file into a dictionary.""" 71 | vocab = collections.OrderedDict() 72 | index = 0 73 | with tf.gfile.GFile(vocab_file, "r") as reader: 74 | while True: 75 | token = convert_to_unicode(reader.readline()) 76 | if not token: 77 | break 78 | token = token.strip() 79 | vocab[token] = index 80 | index += 1 81 | return vocab 82 | 83 | 84 | def convert_tokens_to_ids(vocab, tokens, unk_token="[UNK]"): 85 | """Converts a sequence of tokens into ids using the vocab.""" 86 | ids = [] 87 | for token in tokens: 88 | if token in vocab: 89 | ids.append(vocab[token]) 90 | else: 91 | ids.append(vocab[unk_token]) 92 | return ids 93 | 94 | 95 | def whitespace_tokenize(text): 96 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 97 | text = text.strip() 98 | if not text: 99 | return [] 100 | tokens = text.split() 101 | return tokens 102 | 103 | class CustomTokenizer(object): 104 | def __init__(self,vocab_file,slot_file,do_lower_case=True): 105 | self.vocab = load_vocab(vocab_file) 106 | self.slot2id,self.id2slot = read_slots(slot_file) 107 | self.do_lower_case = do_lower_case 108 | 109 | def _run_strip_accents(self, text): 110 | """Strips accents from a piece of text.""" 111 | text = unicodedata.normalize("NFD", text) 112 | output = [] 113 | for char in text: 114 | cat = unicodedata.category(char) 115 | if cat == "Mn": 116 | continue 117 | output.append(char) 118 | return "".join(output) 119 | 120 | def tokenize(self,text): 121 | orig_tokens = whitespace_tokenize(text) 122 | split_tokens = [] 123 | for token in orig_tokens: 124 | if self.do_lower_case: 125 | token = token.lower() 126 | split_tokens.append(token) 127 | return split_tokens 128 | 129 | def convert_tokens_to_ids(self, tokens): 130 | return convert_tokens_to_ids(self.vocab, tokens) 131 | 132 | def 
convert_slot_to_ids(self, tokens): 133 | print(self.slot2id) 134 | return convert_tokens_to_ids(self.slot2id, tokens,"O") 135 | 136 | 137 | class WordTokenizer(object): 138 | def __init__(self): 139 | self.vocab = {} 140 | self.count = 1 141 | 142 | def seg(self,text): 143 | word_ids_list = [] 144 | text = re.sub(r" ","",text) 145 | text_split = seg(text) 146 | for word in text_split: 147 | if word in self.vocab: 148 | word_ids_list.append(self.vocab.get(word)) 149 | else: 150 | self.count += 1 151 | self.vocab[word] = self.count 152 | word_ids_list.append(self.count) 153 | return word_ids_list,text_split 154 | 155 | # def seg_text_list(self,texts): 156 | # result_list = [self.seg(text) for text in texts] 157 | # return result_list 158 | 159 | class FullTokenizer(object): 160 | """Runs end-to-end tokenziation.""" 161 | 162 | def __init__(self, vocab_file, do_lower_case=True): 163 | self.vocab = load_vocab(vocab_file) 164 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 165 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 166 | 167 | def tokenize(self, text): 168 | split_tokens = [] 169 | for token in self.basic_tokenizer.tokenize(text): 170 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 171 | split_tokens.append(sub_token) 172 | 173 | return split_tokens 174 | 175 | def convert_tokens_to_ids(self, tokens): 176 | return convert_tokens_to_ids(self.vocab, tokens) 177 | 178 | 179 | class CharTokenizer(object): 180 | """Runs end-to-end tokenziation.""" 181 | 182 | def __init__(self, vocab_file, do_lower_case=True): 183 | self.vocab = load_vocab(vocab_file) 184 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 185 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 186 | 187 | def tokenize(self, text): 188 | split_tokens = [] 189 | for token in self.basic_tokenizer.tokenize(text): 190 | for sub_token in token: 191 | split_tokens.append(sub_token) 192 | 193 | return split_tokens 194 | 195 | def convert_tokens_to_ids(self, tokens): 196 | return convert_tokens_to_ids(self.vocab, tokens) 197 | 198 | 199 | class BasicTokenizer(object): 200 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 201 | 202 | def __init__(self, do_lower_case=True): 203 | """Constructs a BasicTokenizer. 204 | 205 | Args: 206 | do_lower_case: Whether to lower case the input. 207 | """ 208 | self.do_lower_case = do_lower_case 209 | 210 | def tokenize(self, text): 211 | """Tokenizes a piece of text.""" 212 | text = convert_to_unicode(text) 213 | text = self._clean_text(text) 214 | 215 | # This was added on November 1st, 2018 for the multilingual and Chinese 216 | # models. This is also applied to the English models now, but it doesn't 217 | # matter since the English models were not trained on any Chinese data 218 | # and generally don't have any Chinese data in them (there are Chinese 219 | # characters in the vocabulary because Wikipedia does have some Chinese 220 | # words in the English Wikipedia.). 
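# For example, _tokenize_chinese_chars(u"深度learning模型") returns u" 深  度 learning 模  型 ",
# so the whitespace_tokenize call below yields ["深", "度", "learning", "模", "型"].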
221 | text = self._tokenize_chinese_chars(text) 222 | 223 | orig_tokens = whitespace_tokenize(text) 224 | split_tokens = [] 225 | for token in orig_tokens: 226 | if self.do_lower_case: 227 | token = token.lower() 228 | token = self._run_strip_accents(token) 229 | split_tokens.extend(self._run_split_on_punc(token)) 230 | 231 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 232 | return output_tokens 233 | 234 | def _run_strip_accents(self, text): 235 | """Strips accents from a piece of text.""" 236 | text = unicodedata.normalize("NFD", text) 237 | output = [] 238 | for char in text: 239 | cat = unicodedata.category(char) 240 | if cat == "Mn": 241 | continue 242 | output.append(char) 243 | return "".join(output) 244 | 245 | def _run_split_on_punc(self, text): 246 | """Splits punctuation on a piece of text.""" 247 | chars = list(text) 248 | i = 0 249 | start_new_word = True 250 | output = [] 251 | while i < len(chars): 252 | char = chars[i] 253 | if _is_punctuation(char): 254 | output.append([char]) 255 | start_new_word = True 256 | else: 257 | if start_new_word: 258 | output.append([]) 259 | start_new_word = False 260 | output[-1].append(char) 261 | i += 1 262 | 263 | return ["".join(x) for x in output] 264 | 265 | def _tokenize_chinese_chars(self, text): 266 | """Adds whitespace around any CJK character.""" 267 | output = [] 268 | for char in text: 269 | cp = ord(char) 270 | if self._is_chinese_char(cp): 271 | output.append(" ") 272 | output.append(char) 273 | output.append(" ") 274 | else: 275 | output.append(char) 276 | return "".join(output) 277 | 278 | def _is_chinese_char(self, cp): 279 | """Checks whether CP is the codepoint of a CJK character.""" 280 | # This defines a "chinese character" as anything in the CJK Unicode block: 281 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 282 | # 283 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 284 | # despite its name. The modern Korean Hangul alphabet is a different block, 285 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 286 | # space-separated words, so they are not treated specially and handled 287 | # like the all of the other languages. 288 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 289 | (cp >= 0x3400 and cp <= 0x4DBF) or # 290 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 291 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 292 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 293 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 294 | (cp >= 0xF900 and cp <= 0xFAFF) or # 295 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 296 | return True 297 | 298 | return False 299 | 300 | def _clean_text(self, text): 301 | """Performs invalid character removal and whitespace cleanup on text.""" 302 | output = [] 303 | for char in text: 304 | cp = ord(char) 305 | if cp == 0 or cp == 0xfffd or _is_control(char): 306 | continue 307 | if _is_whitespace(char): 308 | output.append(" ") 309 | else: 310 | output.append(char) 311 | return "".join(output) 312 | 313 | 314 | class WordpieceTokenizer(object): 315 | """Runs WordPiece tokenziation.""" 316 | 317 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 318 | self.vocab = vocab 319 | self.unk_token = unk_token 320 | self.max_input_chars_per_word = max_input_chars_per_word 321 | 322 | def tokenize(self, text): 323 | """Tokenizes a piece of text into its word pieces. 324 | 325 | This uses a greedy longest-match-first algorithm to perform tokenization 326 | using the given vocabulary. 
327 | 328 | For example: 329 | input = "unaffable" 330 | output = ["un", "##aff", "##able"] 331 | 332 | Args: 333 | text: A single token or whitespace separated tokens. This should have 334 | already been passed through `BasicTokenizer. 335 | 336 | Returns: 337 | A list of wordpiece tokens. 338 | """ 339 | 340 | text = convert_to_unicode(text) 341 | 342 | output_tokens = [] 343 | for token in whitespace_tokenize(text): 344 | chars = list(token) 345 | if len(chars) > self.max_input_chars_per_word: 346 | output_tokens.append(self.unk_token) 347 | continue 348 | 349 | is_bad = False 350 | start = 0 351 | sub_tokens = [] 352 | while start < len(chars): 353 | end = len(chars) 354 | cur_substr = None 355 | while start < end: 356 | substr = "".join(chars[start:end]) 357 | if start > 0: 358 | substr = "##" + substr 359 | if substr in self.vocab: 360 | cur_substr = substr 361 | break 362 | end -= 1 363 | if cur_substr is None: 364 | is_bad = True 365 | break 366 | sub_tokens.append(cur_substr) 367 | start = end 368 | 369 | if is_bad: 370 | output_tokens.append(self.unk_token) 371 | else: 372 | output_tokens.extend(sub_tokens) 373 | return output_tokens 374 | 375 | 376 | def _is_whitespace(char): 377 | """Checks whether `chars` is a whitespace character.""" 378 | # \t, \n, and \r are technically contorl characters but we treat them 379 | # as whitespace since they are generally considered as such. 380 | if char == " " or char == "\t" or char == "\n" or char == "\r": 381 | return True 382 | cat = unicodedata.category(char) 383 | if cat == "Zs": 384 | return True 385 | return False 386 | 387 | 388 | def _is_control(char): 389 | """Checks whether `chars` is a control character.""" 390 | # These are technically control characters but we count them as whitespace 391 | # characters. 392 | if char == "\t" or char == "\n" or char == "\r": 393 | return False 394 | cat = unicodedata.category(char) 395 | if cat.startswith("C"): 396 | return True 397 | return False 398 | 399 | 400 | def _is_punctuation(char): 401 | """Checks whether `chars` is a punctuation character.""" 402 | cp = ord(char) 403 | # We treat all non-letter/number ASCII as punctuation. 404 | # Characters such as "^", "$", and "`" are not in the Unicode 405 | # Punctuation class but we treat them as punctuation anyways, for 406 | # consistency. 
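# For example, ord("$") == 36 falls inside the 33-47 range checked below, so "$" is treated as
# punctuation here even though its Unicode category is "Sc" rather than a "P*" class.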
407 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 408 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 409 | return True 410 | cat = unicodedata.category(char) 411 | if cat.startswith("P"): 412 | return True 413 | return False 414 | -------------------------------------------------------------------------------- /gen_kfold_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from data_processing.event_prepare_data import EventTypeClassificationPrepare, EventRolePrepareMRC 4 | from configs.event_config import event_config 5 | 6 | 7 | def gen_type_classification_data(): 8 | """ 9 | generate event type classification data of index_type_fold_data_{} 10 | """ 11 | # bert vocab file path 12 | vocab_file_path = os.path.join(event_config.get("bert_pretrained_model_path"), event_config.get("vocab_file")) 13 | # bert config file path 14 | bert_config_file = os.path.join(event_config.get("bert_pretrained_model_path"), event_config.get("bert_config_path")) 15 | # event type list file path 16 | event_type_file = os.path.join(event_config.get("slot_list_root_path"), event_config.get("event_type_file")) 17 | data_loader =EventTypeClassificationPrepare(vocab_file_path,512,event_type_file) 18 | # train file 19 | train_file = os.path.join(event_config.get("data_dir"),event_config.get("event_data_file_train")) 20 | # eval file 21 | eval_file = os.path.join(event_config.get("data_dir"),event_config.get("event_data_file_eval")) 22 | data_loader.k_fold_split_data(train_file,eval_file,True) 23 | 24 | def gen_role_class_data(): 25 | """ 26 | generate role mrc data for verify_neg_fold_data_{} 27 | """ 28 | # bert vocab file path 29 | vocab_file_path = os.path.join(event_config.get("bert_pretrained_model_path"), event_config.get("vocab_file")) 30 | # event role slot list file path 31 | slot_file = os.path.join(event_config.get("slot_list_root_path"),event_config.get("bert_slot_complete_file_name_role")) 32 | # schema file path 33 | schema_file = os.path.join(event_config.get("data_dir"), event_config.get("event_schema")) 34 | # query map file path 35 | query_file = os.path.join(event_config.get("slot_list_root_path"),event_config.get("query_map_file")) 36 | data_loader = EventRolePrepareMRC(vocab_file_path,512,slot_file,schema_file,query_file) 37 | train_file = os.path.join(event_config.get("data_dir"),event_config.get("event_data_file_train")) 38 | eval_file = os.path.join(event_config.get("data_dir"),event_config.get("event_data_file_eval")) 39 | data_loader.k_fold_split_data(train_file,eval_file,True) 40 | 41 | if __name__ == "__main__": 42 | gen_type_classification_data() 43 | gen_role_class_data() 44 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'lenovo' 2 | -------------------------------------------------------------------------------- /models/bert_mrc.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import common_utils 3 | import optimization 4 | from models.tf_metrics import precision, recall, f1 5 | from bert import modeling 6 | from models.utils import ce_loss,cal_binary_dsc_loss,dl_dsc_loss,vanilla_dsc_loss,focal_loss 7 | 8 | logger = common_utils.set_logger('NER Training...') 9 | 10 | class bertMRC(object): 11 | def __init__(self,params,bert_config): 12 | # 丢弃概率 13 | self.dropout_rate 
= params["dropout_prob"] 14 | self.num_labels = 2 15 | self.rnn_size = params["rnn_size"] 16 | self.num_layers = params["num_layers"] 17 | self.hidden_units = params["hidden_units"] 18 | self.bert_config = bert_config 19 | 20 | def __call__(self,input_ids,start_labels,end_labels,token_type_ids_list,query_len_list,text_length_list,is_training,is_testing=False): 21 | bert_model = modeling.BertModel( 22 | config=self.bert_config, 23 | is_training=is_training, 24 | input_ids=input_ids, 25 | text_length=text_length_list, 26 | token_type_ids=token_type_ids_list, 27 | use_one_hot_embeddings=False 28 | ) 29 | bert_seq_output = bert_model.get_sequence_output() 30 | 31 | # bert_project = tf.layers.dense(bert_seq_output, self.hidden_units, activation=tf.nn.relu) 32 | # bert_project = tf.layers.dropout(bert_project, rate=self.dropout_rate, training=is_training) 33 | start_logits = tf.layers.dense(bert_seq_output,self.num_labels) 34 | end_logits = tf.layers.dense(bert_seq_output, self.num_labels) 35 | query_span_mask = tf.cast(tf.sequence_mask(query_len_list),tf.int32) 36 | total_seq_mask = tf.cast(tf.sequence_mask(text_length_list),tf.int32) 37 | query_span_mask = query_span_mask * -1 38 | query_len_max = tf.shape(query_span_mask)[1] 39 | left_query_len_max = tf.shape(total_seq_mask)[1] - query_len_max 40 | zero_mask_left_span = tf.zeros((tf.shape(query_span_mask)[0],left_query_len_max),dtype=tf.int32) 41 | final_mask = tf.concat((query_span_mask,zero_mask_left_span),axis=-1) 42 | final_mask = final_mask + total_seq_mask 43 | predict_start_ids = tf.argmax(start_logits, axis=-1, name="pred_start_ids") 44 | predict_end_ids = tf.argmax(end_logits, axis=-1, name="pred_end_ids") 45 | if not is_testing: 46 | # one_hot_labels = tf.one_hot(labels, depth=self.num_labels, dtype=tf.float32) 47 | # start_loss = ce_loss(start_logits,start_labels,final_mask,self.num_labels,True) 48 | # end_loss = ce_loss(end_logits,end_labels,final_mask,self.num_labels,True) 49 | 50 | # focal loss 51 | start_loss = focal_loss(start_logits,start_labels,final_mask,self.num_labels,True) 52 | end_loss = focal_loss(end_logits,end_labels,final_mask,self.num_labels,True) 53 | 54 | final_loss = start_loss + end_loss 55 | return final_loss,predict_start_ids,predict_end_ids,final_mask 56 | else: 57 | return predict_start_ids,predict_end_ids,final_mask 58 | 59 | 60 | def bert_mrc_model_fn_builder(bert_config_file,init_checkpoints,args): 61 | def model_fn(features, labels, mode, params): 62 | logger.info("*** Features ***") 63 | if isinstance(features, dict): 64 | features = features['words'],features['text_length'],features['query_length'],features['token_type_ids'] 65 | print(features) 66 | input_ids,text_length_list,query_length_list,token_type_id_list = features 67 | if labels is not None: 68 | start_labels,end_labels = labels 69 | else: 70 | start_labels, end_labels = None,None 71 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 72 | is_testing = (mode == tf.estimator.ModeKeys.PREDICT) 73 | bert_config = modeling.BertConfig.from_json_file(bert_config_file) 74 | tag_model = bertMRC(params,bert_config) 75 | # input_ids,labels,token_type_ids_list,query_len_list,text_length_list,is_training,is_testing=False 76 | if is_testing: 77 | pred_start_ids,pred_end_ids,weight = tag_model(input_ids,start_labels,end_labels, token_type_id_list,query_length_list, text_length_list, is_training,is_testing) 78 | else: 79 | loss,pred_start_ids,pred_end_ids,weight = 
tag_model(input_ids,start_labels,end_labels,token_type_id_list,query_length_list,text_length_list,is_training) 80 | 81 | # def metric_fn(label_ids, pred_ids): 82 | # return { 83 | # 'precision': precision(label_ids, pred_ids, params["num_labels"]), 84 | # 'recall': recall(label_ids, pred_ids, params["num_labels"]), 85 | # 'f1': f1(label_ids, pred_ids, params["num_labels"]) 86 | # } 87 | # 88 | # eval_metrics = metric_fn(labels, pred_ids) 89 | tvars = tf.trainable_variables() 90 | # 加载BERT模型 91 | if init_checkpoints: 92 | (assignment_map, initialized_variable_names) = \ 93 | modeling.get_assignment_map_from_checkpoint(tvars, 94 | init_checkpoints) 95 | tf.train.init_from_checkpoint(init_checkpoints, assignment_map) 96 | output_spec = None 97 | # f1_score_val, f1_update_op_val = f1(labels=labels, predictions=pred_ids, num_classes=params["num_labels"], 98 | # weights=weight) 99 | 100 | if mode == tf.estimator.ModeKeys.TRAIN: 101 | train_op = optimization.create_optimizer(loss,args.lr, params["decay_steps"],args.clip_norm) 102 | hook_dict = {} 103 | # precision_score, precision_update_op = precision(labels=labels, predictions=pred_ids, 104 | # num_classes=params["num_labels"], weights=weight) 105 | # 106 | # recall_score, recall_update_op = recall(labels=labels, 107 | # predictions=pred_ids, num_classes=params["num_labels"], 108 | # weights=weight) 109 | hook_dict['loss'] = loss 110 | hook_dict['global_steps'] = tf.train.get_or_create_global_step() 111 | logging_hook = tf.train.LoggingTensorHook( 112 | hook_dict, every_n_iter=args.print_log_steps) 113 | 114 | output_spec = tf.estimator.EstimatorSpec( 115 | mode=mode, 116 | loss=loss, 117 | train_op=train_op, 118 | training_hooks=[logging_hook]) 119 | 120 | elif mode == tf.estimator.ModeKeys.EVAL: 121 | # pred_ids = tf.argmax(logits, axis=-1, output_type=tf.int32) 122 | # weight = tf.sequence_mask(text_length_list) 123 | # precision_score, precision_update_op = precision(labels=labels,predictions=pred_ids,num_classes=params["num_labels"],weights=weight) 124 | # 125 | # recall_score, recall_update_op =recall(labels=labels, 126 | # predictions=pred_ids,num_classes=params["num_labels"],weights=weight) 127 | f1_start_val,f1_update_op_val = f1(labels=start_labels,predictions=pred_start_ids,num_classes=2,weights=weight,average="macro") 128 | f1_end_val,f1_end_update_op_val = f1(labels=end_labels,predictions=pred_end_ids,num_classes=2,weights=weight,average="macro") 129 | 130 | # f1_score_val_micro,f1_update_op_val_micro = f1(labels=labels,predictions=pred_ids,num_classes=params["num_labels"],weights=weight,average="micro") 131 | 132 | # acc_score_val,acc_score_op_val = tf.metrics.accuracy(labels=labels,predictions=pred_ids,weights=weight) 133 | # eval_loss = tf.metrics.mean_squared_error(labels=labels, predictions=pred_ids,weights=weight) 134 | eval_metric_ops = { 135 | "f1_start_micro":(f1_start_val,f1_update_op_val), 136 | "f1_end_micro":(f1_end_val,f1_end_update_op_val)} 137 | 138 | # eval_hook_dict = {"f1":f1_score_val,"loss":loss} 139 | 140 | # eval_logging_hook = tf.train.LoggingTensorHook( 141 | # at_end=True,every_n_iter=args.print_log_steps) 142 | output_spec = tf.estimator.EstimatorSpec( 143 | eval_metric_ops=eval_metric_ops, 144 | mode=mode, 145 | loss=loss 146 | ) 147 | else: 148 | output_spec = tf.estimator.EstimatorSpec( 149 | mode=mode, 150 | predictions={"start_ids":pred_start_ids,"end_ids":pred_end_ids} 151 | ) 152 | return output_spec 153 | return model_fn 154 | 
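# A minimal NumPy sketch (not part of the original model file) of the mask arithmetic in
# bertMRC.__call__ above: the query span mask is multiplied by -1, zero-padded to the full
# width, and added to the total sequence mask, so only passage tokens keep weight 1 while
# query tokens and padding get weight 0. build_final_mask is only an illustrative name.
import numpy as np

def build_final_mask(query_lens, text_lens, max_len):
    """Query positions -> 0, passage positions -> 1, padding -> 0 (mirrors final_mask above)."""
    batch = len(query_lens)
    query_mask = np.zeros((batch, max_len), dtype=np.int32)
    total_mask = np.zeros((batch, max_len), dtype=np.int32)
    for i, (q, t) in enumerate(zip(query_lens, text_lens)):
        query_mask[i, :q] = -1   # tf.sequence_mask(query_len_list) * -1, zero-padded on the right
        total_mask[i, :t] = 1    # tf.sequence_mask(text_length_list); t is the full query+passage length
    return query_mask + total_mask

# build_final_mask([3], [7], 9) -> [[0 0 0 1 1 1 1 0 0]]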
-------------------------------------------------------------------------------- /models/event_verify_av.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import common_utils 3 | import optimization 4 | from models.tf_metrics import precision, recall, f1 5 | # from albert import modeling,modeling_google 6 | from bert import modeling 7 | # from bert import modeling_theseus 8 | from models.utils import focal_loss 9 | from tensorflow.python.ops import metrics as metrics_lib 10 | 11 | logger = common_utils.set_logger('NER Training...') 12 | 13 | 14 | class VerfiyMRC(object): 15 | def __init__(self, params, bert_config): 16 | # 丢弃概率 17 | self.dropout_rate = params["dropout_prob"] 18 | self.num_labels = 2 19 | self.bert_config = bert_config 20 | 21 | def __call__(self, input_ids, start_labels, end_labels, token_type_ids_list, query_len_list, text_length_list, 22 | has_answer_label, is_training, is_testing=False): 23 | bert_model = modeling.BertModel( 24 | config=self.bert_config, 25 | is_training=is_training, 26 | input_ids=input_ids, 27 | text_length=text_length_list, 28 | token_type_ids=token_type_ids_list, 29 | use_one_hot_embeddings=False 30 | ) 31 | bert_seq_output = bert_model.get_sequence_output() 32 | first_seq_hidden = bert_model.get_pooled_output() 33 | # bert_project = tf.layers.dense(bert_seq_output, self.hidden_units, activation=tf.nn.relu) 34 | # bert_project = tf.layers.dropout(bert_project, rate=self.dropout_rate, training=is_training) 35 | start_logits = tf.layers.dense(bert_seq_output, self.num_labels) 36 | end_logits = tf.layers.dense(bert_seq_output, self.num_labels) 37 | query_span_mask = tf.cast(tf.sequence_mask(query_len_list), tf.int32) 38 | total_seq_mask = tf.cast(tf.sequence_mask(text_length_list), tf.int32) 39 | query_span_mask = query_span_mask * -1 40 | query_len_max = tf.shape(query_span_mask)[1] 41 | left_query_len_max = tf.shape(total_seq_mask)[1] - query_len_max 42 | zero_mask_left_span = tf.zeros((tf.shape(query_span_mask)[0], left_query_len_max), dtype=tf.int32) 43 | final_mask = tf.concat((query_span_mask, zero_mask_left_span), axis=-1) 44 | final_mask = final_mask + total_seq_mask 45 | predict_start_ids = tf.argmax(start_logits, axis=-1, name="pred_start_ids") 46 | predict_start_prob = tf.nn.softmax(start_logits, axis=-1) 47 | predict_end_prob = tf.nn.softmax(end_logits, axis=-1) 48 | predict_end_ids = tf.argmax(end_logits, axis=-1, name="pred_end_ids") 49 | # has_answer_logits = tf.layers.dropout(first_seq_hidden,rate=self.dropout_rate,training=is_training) 50 | has_answer_logits = tf.layers.dense(first_seq_hidden, 1) 51 | predict_has_answer_probs = tf.nn.sigmoid(has_answer_logits) 52 | if not is_testing: 53 | # one_hot_labels = tf.one_hot(labels, depth=self.num_labels, dtype=tf.float32) 54 | # start_loss = ce_loss(start_logits,start_labels,final_mask,self.num_labels,True) 55 | # end_loss = ce_loss(end_logits,end_labels,final_mask,self.num_labels,True) 56 | 57 | # focal loss 58 | start_loss = focal_loss(start_logits, start_labels, final_mask, self.num_labels, True, 1.8) 59 | end_loss = focal_loss(end_logits, end_labels, final_mask, self.num_labels, True, 1.8) 60 | has_answer_label = tf.cast(has_answer_label, tf.float32) 61 | per_example_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=has_answer_label, 62 | logits=has_answer_logits) 63 | has_answer_loss = tf.reduce_mean(per_example_loss) 64 | # has_answer_loss = 
tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=one_hot_labels,logits=has_answer_logits)) 65 | final_loss = (1.5 * start_loss + end_loss + has_answer_loss) / 3.0 66 | return final_loss, predict_start_ids, predict_end_ids, final_mask, predict_start_prob, predict_end_prob, predict_has_answer_probs 67 | else: 68 | return predict_start_ids, predict_end_ids, final_mask, predict_start_prob, predict_end_prob, predict_has_answer_probs 69 | 70 | 71 | def event_verify_mrc_model_fn_builder(bert_config_file, init_checkpoints, args): 72 | def model_fn(features, labels, mode, params): 73 | logger.info("*** Features ***") 74 | if isinstance(features, dict): 75 | features = features['words'], features['text_length'], features['query_length'], features['token_type_ids'] 76 | print(features) 77 | input_ids, text_length_list, query_length_list, token_type_id_list = features 78 | if labels is not None: 79 | start_labels, end_labels, has_answer_label = labels 80 | else: 81 | start_labels, end_labels, has_answer_label = None, None, None 82 | 83 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 84 | is_testing = (mode == tf.estimator.ModeKeys.PREDICT) 85 | bert_config = modeling.BertConfig.from_json_file(bert_config_file) 86 | tag_model = VerfiyMRC(params, bert_config) 87 | 88 | # input_ids,labels,token_type_ids_list,query_len_list,text_length_list,is_training,is_testing=False 89 | if is_testing: 90 | pred_start_ids, pred_end_ids, weight, predict_start_prob, predict_end_prob, has_answer_prob = tag_model( 91 | input_ids, start_labels, end_labels, token_type_id_list, query_length_list, text_length_list, 92 | has_answer_label, is_training, is_testing) 93 | # predict_ids,weight,predict_prob = tag_model(input_ids,labels,token_type_id_list,query_length_list,text_length_list,is_training,is_testing) 94 | else: 95 | loss, pred_start_ids, pred_end_ids, weight, predict_start_prob, predict_end_prob, has_answer_prob = tag_model( 96 | input_ids, start_labels, end_labels, token_type_id_list, query_length_list, text_length_list, 97 | has_answer_label, is_training) 98 | # loss,predict_ids,weight,predict_prob = tag_model(input_ids,labels,token_type_id_list,query_length_list,text_length_list,is_training,is_testing) 99 | 100 | # def metric_fn(label_ids, pred_ids): 101 | # return { 102 | # 'precision': precision(label_ids, pred_ids, params["num_labels"]), 103 | # 'recall': recall(label_ids, pred_ids, params["num_labels"]), 104 | # 'f1': f1(label_ids, pred_ids, params["num_labels"]) 105 | # } 106 | # 107 | # eval_metrics = metric_fn(labels, pred_ids) 108 | tvars = tf.trainable_variables() 109 | # 加载BERT模型 110 | if init_checkpoints: 111 | (assignment_map, initialized_variable_names) = \ 112 | modeling.get_assignment_map_from_checkpoint(tvars, 113 | init_checkpoints) 114 | tf.train.init_from_checkpoint(init_checkpoints, assignment_map) 115 | output_spec = None 116 | # f1_score_val, f1_update_op_val = f1(labels=labels, predictions=pred_ids, num_classes=params["num_labels"], 117 | # weights=weight) 118 | 119 | if mode == tf.estimator.ModeKeys.TRAIN: 120 | train_op = optimization.create_optimizer(loss, args.lr, params["train_steps"], params["num_warmup_steps"], 121 | args.clip_norm) 122 | hook_dict = {} 123 | # precision_score, precision_update_op = precision(labels=labels, predictions=pred_ids, 124 | # num_classes=params["num_labels"], weights=weight) 125 | # 126 | # recall_score, recall_update_op = recall(labels=labels, 127 | # predictions=pred_ids, num_classes=params["num_labels"], 128 | # weights=weight) 129 | 
hook_dict['loss'] = loss 130 | hook_dict['global_steps'] = tf.train.get_or_create_global_step() 131 | logging_hook = tf.train.LoggingTensorHook( 132 | hook_dict, every_n_iter=args.print_log_steps) 133 | 134 | output_spec = tf.estimator.EstimatorSpec( 135 | mode=mode, 136 | loss=loss, 137 | train_op=train_op, 138 | training_hooks=[logging_hook]) 139 | 140 | elif mode == tf.estimator.ModeKeys.EVAL: 141 | has_answer_pred = tf.where(has_answer_prob > 0.5, tf.ones_like(has_answer_prob), 142 | tf.zeros_like(has_answer_prob)) 143 | 144 | # pred_ids = tf.argmax(logits, axis=-1, output_type=tf.int32) 145 | # weight = tf.sequence_mask(text_length_list) 146 | # precision_score, precision_update_op = precision(labels=labels,predictions=pred_ids,num_classes=params["num_labels"],weights=weight) 147 | # 148 | # recall_score, recall_update_op =recall(labels=labels, 149 | # predictions=pred_ids,num_classes=params["num_labels"],weights=weight) 150 | # def metric_fn(per_example_loss, label_ids, probabilities): 151 | 152 | # logits_split = tf.split(probabilities, params["num_labels"], axis=-1) 153 | # label_ids_split = tf.split(label_ids, params["num_labels"], axis=-1) 154 | # # metrics change to auc of every class 155 | # eval_dict = {} 156 | # for j, logits in enumerate(logits_split): 157 | # label_id_ = tf.cast(label_ids_split[j], dtype=tf.int32) 158 | # current_auc, update_op_auc = tf.metrics.auc(label_id_, logits) 159 | # eval_dict[str(j)] = (current_auc, update_op_auc) 160 | # eval_dict['eval_loss'] = tf.metrics.mean(values=per_example_loss) 161 | # return eval_dict 162 | # eval_metrics = metric_fn(per_example_loss, labels, pred_ids) 163 | f1_start_val, f1_update_op_val = f1(labels=start_labels, predictions=pred_start_ids, num_classes=2, 164 | weights=weight, average="macro") 165 | f1_end_val, f1_end_update_op_val = f1(labels=end_labels, predictions=pred_end_ids, num_classes=2, 166 | weights=weight, average="macro") 167 | # f1_val,f1_update_op_val = f1(labels=labels,predictions=predict_ids,num_classes=3,weights=weight,average="macro") 168 | has_answer_label = tf.cast(has_answer_label, tf.float32) 169 | f1_has_val, f1_has_update_op_val = f1(labels=has_answer_label, predictions=has_answer_pred, num_classes=2) 170 | 171 | # f1_score_val_micro,f1_update_op_val_micro = f1(labels=labels,predictions=pred_ids,num_classes=params["num_labels"],weights=weight,average="micro") 172 | 173 | # acc_score_val,acc_score_op_val = tf.metrics.accuracy(labels=labels,predictions=pred_ids,weights=weight) 174 | # eval_loss = tf.metrics.mean_squared_error(labels=labels, predictions=pred_ids,weights=weight) 175 | 176 | eval_metric_ops = { 177 | "f1_start_macro": (f1_start_val, f1_update_op_val), 178 | "f1_end_macro": (f1_end_val, f1_end_update_op_val), 179 | "f1_has_answer_macro": (f1_has_val, f1_has_update_op_val), 180 | "eval_loss": tf.metrics.mean(values=loss)} 181 | 182 | # eval_metric_ops = { 183 | # "f1_macro":(f1_val,f1_update_op_val), 184 | # "eval_loss":tf.metrics.mean(values=loss)} 185 | 186 | # eval_hook_dict = {"f1":f1_score_val,"loss":loss} 187 | 188 | # eval_logging_hook = tf.train.LoggingTensorHook( 189 | # at_end=True,every_n_iter=args.print_log_steps) 190 | output_spec = tf.estimator.EstimatorSpec( 191 | eval_metric_ops=eval_metric_ops, 192 | mode=mode, 193 | loss=loss 194 | ) 195 | else: 196 | output_spec = tf.estimator.EstimatorSpec( 197 | mode=mode, 198 | predictions={"start_ids": pred_start_ids, "end_ids": pred_end_ids, "start_probs": predict_start_prob, 199 | "end_probs": predict_end_prob, 
"has_answer_probs": has_answer_prob} 200 | ) 201 | # output_spec = tf.estimator.EstimatorSpec( 202 | # mode=mode, 203 | # predictions={"pred_ids":predict_ids,"pred_probs":predict_prob} 204 | # ) 205 | return output_spec 206 | 207 | return model_fn 208 | -------------------------------------------------------------------------------- /models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'lenovo' 2 | -------------------------------------------------------------------------------- /models/tf_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.python.ops.metrics_impl import _streaming_confusion_matrix 4 | 5 | 6 | def precision(labels, predictions, num_classes, pos_indices=None, 7 | weights=None, average='micro'): 8 | """Multi-class precision metric for Tensorflow 9 | 10 | Parameters 11 | ---------- 12 | labels : Tensor of tf.int32 or tf.int64 13 | The true labels 14 | predictions : Tensor of tf.int32 or tf.int64 15 | The predictions, same shape as labels 16 | num_classes : int 17 | The number of classes 18 | pos_indices : list of int, optional 19 | The indices of the positive classes, default is all 20 | weights : Tensor of tf.int32, optional 21 | Mask, must be of compatible shape with labels 22 | average : str, optional 23 | 'micro': counts the total number of true positives, false 24 | positives, and false negatives for the classes in 25 | `pos_indices` and infer the metric from it. 26 | 'macro': will compute the metric separately for each class in 27 | `pos_indices` and average. Will not account for class 28 | imbalance. 29 | 'weighted': will compute the metric separately for each class in 30 | `pos_indices` and perform a weighted average by the total 31 | number of true labels for each class. 32 | 33 | Returns 34 | ------- 35 | tuple of (scalar float Tensor, update_op) 36 | """ 37 | cm, op = _streaming_confusion_matrix( 38 | labels, predictions, num_classes, weights) 39 | pr, _, _ = metrics_from_confusion_matrix( 40 | cm, pos_indices, average=average) 41 | op, _, _ = metrics_from_confusion_matrix( 42 | op, pos_indices, average=average) 43 | return (pr, op) 44 | 45 | 46 | def recall(labels, predictions, num_classes, pos_indices=None, weights=None, 47 | average='micro'): 48 | """Multi-class recall metric for Tensorflow 49 | 50 | Parameters 51 | ---------- 52 | labels : Tensor of tf.int32 or tf.int64 53 | The true labels 54 | predictions : Tensor of tf.int32 or tf.int64 55 | The predictions, same shape as labels 56 | num_classes : int 57 | The number of classes 58 | pos_indices : list of int, optional 59 | The indices of the positive classes, default is all 60 | weights : Tensor of tf.int32, optional 61 | Mask, must be of compatible shape with labels 62 | average : str, optional 63 | 'micro': counts the total number of true positives, false 64 | positives, and false negatives for the classes in 65 | `pos_indices` and infer the metric from it. 66 | 'macro': will compute the metric separately for each class in 67 | `pos_indices` and average. Will not account for class 68 | imbalance. 69 | 'weighted': will compute the metric separately for each class in 70 | `pos_indices` and perform a weighted average by the total 71 | number of true labels for each class. 
72 | 73 | Returns 74 | ------- 75 | tuple of (scalar float Tensor, update_op) 76 | """ 77 | cm, op = _streaming_confusion_matrix( 78 | labels, predictions, num_classes, weights) 79 | _, re, _ = metrics_from_confusion_matrix( 80 | cm, pos_indices, average=average) 81 | _, op, _ = metrics_from_confusion_matrix( 82 | op, pos_indices, average=average) 83 | return (re, op) 84 | 85 | 86 | def f1(labels, predictions, num_classes, pos_indices=None, weights=None, 87 | average='micro'): 88 | return fbeta(labels, predictions, num_classes, pos_indices, weights, 89 | average) 90 | 91 | 92 | def fbeta(labels, predictions, num_classes, pos_indices=None, weights=None, 93 | average='micro', beta=1): 94 | """Multi-class fbeta metric for Tensorflow 95 | 96 | Parameters 97 | ---------- 98 | labels : Tensor of tf.int32 or tf.int64 99 | The true labels 100 | predictions : Tensor of tf.int32 or tf.int64 101 | The predictions, same shape as labels 102 | num_classes : int 103 | The number of classes 104 | pos_indices : list of int, optional 105 | The indices of the positive classes, default is all 106 | weights : Tensor of tf.int32, optional 107 | Mask, must be of compatible shape with labels 108 | average : str, optional 109 | 'micro': counts the total number of true positives, false 110 | positives, and false negatives for the classes in 111 | `pos_indices` and infer the metric from it. 112 | 'macro': will compute the metric separately for each class in 113 | `pos_indices` and average. Will not account for class 114 | imbalance. 115 | 'weighted': will compute the metric separately for each class in 116 | `pos_indices` and perform a weighted average by the total 117 | number of true labels for each class. 118 | beta : int, optional 119 | Weight of precision in harmonic mean 120 | 121 | Returns 122 | ------- 123 | tuple of (scalar float Tensor, update_op) 124 | """ 125 | cm, op = _streaming_confusion_matrix( 126 | labels, predictions, num_classes, weights) 127 | _, _, fbeta = metrics_from_confusion_matrix( 128 | cm, pos_indices, average=average, beta=beta) 129 | _, _, op = metrics_from_confusion_matrix( 130 | op, pos_indices, average=average, beta=beta) 131 | return (fbeta, op) 132 | 133 | 134 | def safe_div(numerator, denominator): 135 | """Safe division, return 0 if denominator is 0""" 136 | numerator, denominator = tf.to_float(numerator), tf.to_float(denominator) 137 | zeros = tf.zeros_like(numerator, dtype=numerator.dtype) 138 | denominator_is_zero = tf.equal(denominator, zeros) 139 | return tf.where(denominator_is_zero, zeros, numerator / denominator) 140 | 141 | 142 | def pr_re_fbeta(cm, pos_indices, beta=1): 143 | """Uses a confusion matrix to compute precision, recall and fbeta""" 144 | num_classes = cm.shape[0] 145 | neg_indices = [i for i in range(num_classes) if i not in pos_indices] 146 | cm_mask = np.ones([num_classes, num_classes]) 147 | cm_mask[neg_indices, neg_indices] = 0 148 | diag_sum = tf.reduce_sum(tf.diag_part(cm * cm_mask)) 149 | 150 | cm_mask = np.ones([num_classes, num_classes]) 151 | cm_mask[:, neg_indices] = 0 152 | tot_pred = tf.reduce_sum(cm * cm_mask) 153 | 154 | cm_mask = np.ones([num_classes, num_classes]) 155 | cm_mask[neg_indices, :] = 0 156 | tot_gold = tf.reduce_sum(cm * cm_mask) 157 | 158 | pr = safe_div(diag_sum, tot_pred) 159 | re = safe_div(diag_sum, tot_gold) 160 | fbeta = safe_div((1. 
+ beta**2) * pr * re, beta**2 * pr + re) 161 | 162 | return pr, re, fbeta 163 | 164 | 165 | def metrics_from_confusion_matrix(cm, pos_indices=None, average='micro', 166 | beta=1): 167 | """Precision, Recall and F1 from the confusion matrix 168 | 169 | Parameters 170 | ---------- 171 | cm : tf.Tensor of type tf.int32, of shape (num_classes, num_classes) 172 | The streaming confusion matrix. 173 | pos_indices : list of int, optional 174 | The indices of the positive classes 175 | beta : int, optional 176 | Weight of precision in harmonic mean 177 | average : str, optional 178 | 'micro', 'macro' or 'weighted' 179 | """ 180 | num_classes = cm.shape[0] 181 | if pos_indices is None: 182 | pos_indices = [i for i in range(num_classes)] 183 | 184 | if average == 'micro': 185 | return pr_re_fbeta(cm, pos_indices, beta) 186 | elif average in {'macro', 'weighted'}: 187 | precisions, recalls, fbetas, n_golds = [], [], [], [] 188 | for idx in pos_indices: 189 | pr, re, fbeta = pr_re_fbeta(cm, [idx], beta) 190 | precisions.append(pr) 191 | recalls.append(re) 192 | fbetas.append(fbeta) 193 | cm_mask = np.zeros([num_classes, num_classes]) 194 | cm_mask[idx, :] = 1 195 | n_golds.append(tf.to_float(tf.reduce_sum(cm * cm_mask))) 196 | 197 | if average == 'macro': 198 | pr = tf.reduce_mean(precisions) 199 | re = tf.reduce_mean(recalls) 200 | fbeta = tf.reduce_mean(fbetas) 201 | return pr, re, fbeta 202 | if average == 'weighted': 203 | n_gold = tf.reduce_sum(n_golds) 204 | pr_sum = sum(p * n for p, n in zip(precisions, n_golds)) 205 | pr = safe_div(pr_sum, n_gold) 206 | re_sum = sum(r * n for r, n in zip(recalls, n_golds)) 207 | re = safe_div(re_sum, n_gold) 208 | fbeta_sum = sum(f * n for f, n in zip(fbetas, n_golds)) 209 | fbeta = safe_div(fbeta_sum, n_gold) 210 | return pr, re, fbeta 211 | 212 | else: 213 | raise NotImplementedError() 214 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_probability as tfp 3 | import torch.nn as nn 4 | nn.KLDivLoss() 5 | # tf.enable_eager_execution() 6 | 7 | # batch_size_tensor = tf.Variable(tf.range(0,4)) 8 | # seq_max_len_tensor = tf.range(0,3) 9 | # batch_zeros = tf.Variable(tf.zeros((4,3,3),dtype=tf.int32)) 10 | # print(batch_zeros.get_shape().as_list()) 11 | # batch_size_tensor = tf.expand_dims(batch_size_tensor,1) 12 | # batch_size_tensor = tf.tile(batch_size_tensor,[1,3]) 13 | # # batch_size_tensor = tf.expand_dims(batch_size_tensor,1) 14 | # # res = tf.add(batch_size_tensor,batch_zeros) 15 | # indices = tf.Variable([[[0]], [[0]], [[0]], [[0]]]) 16 | # # res = tf.scatter_nd_update(batch_zeros, indices, batch_size_tensor) 17 | # seq_max_len_tensor = tf.expand_dims(seq_max_len_tensor,0) 18 | # seq_max_len_tensor = tf.tile(seq_max_len_tensor,[4,1]) 19 | # # batch_size_tensor = tf.expand_dims(batch_size_tensor,1) 20 | # # res = tf.add(seq_max_len_tensor,batch_size_tensor) 21 | # const_var = tf.constant([[1,0,2],[0,0,1],[1,1,2],[2,1,0]]) 22 | # # print(batch_zeros[:,:,0]) 23 | # batch_zeros[:,:,0].assign(batch_size_tensor) 24 | # batch_zeros[:,:,1].assign(seq_max_len_tensor) 25 | # batch_zeros[:,:,2].assign(const_var) 26 | # print(batch_zeros) 27 | # seq_zeros = tf.zeros((4,3),dtype=tf.int32) 28 | # batch_size_tensor = tf.expand_dims(batch_size_tensor,1) 29 | # res = tf.concat([batch_size_tensor,seq_zeros],axis=1) 30 | # print(res) 31 | # # res = tf.tile(res,[1,3]) 32 | # # # 
seq_max_len_tensor = tf.expand_dims() 33 | # # res = tf.add(res,seq_zeros) 34 | # with tf.Session() as sess: 35 | # print(sess.run(res)) 36 | # # print(res) 37 | 38 | # tf.enable_eager_execution() 39 | # 40 | # data = tf.Variable([[2], 41 | # [3], 42 | # [4], 43 | # [5], 44 | # [6]]) 45 | # 46 | # cond = tf.where(tf.less(data, 5)) # update value less than 5 47 | # match_data = tf.gather_nd(data, cond) 48 | # square_data = tf.square(match_data) # square value less than 5 49 | # 50 | # data = tf.scatter_nd_update(data, cond, square_data) 51 | # 52 | # print(data) 53 | 54 | def cal_binary_dsc_loss(logits,labels,seq_mask,num_labels,one_hot=True,smoothing_lambda=1.0): 55 | # 这里暂时不用mask,因为mask的地方,label都是0,会被忽略掉 56 | if one_hot: 57 | labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) 58 | else: 59 | labels = tf.expand_dims(labels,axis=-1) 60 | # seq_mask = tf.cast(seq_mask, tf.float32) 61 | predict_prob = tf.nn.softmax(logits, axis=-1, name="predict_prob") 62 | pos_prob = predict_prob[:, :, 1] 63 | neg_prob = predict_prob[:, :, 0] 64 | pos_label = labels[:, :, 1] 65 | nominator = neg_prob * pos_prob * pos_label 66 | denominator = neg_prob * pos_prob + pos_label 67 | loss = (nominator + smoothing_lambda)/(denominator + smoothing_lambda) 68 | loss = 1. - loss 69 | loss = tf.reduce_sum(loss,axis=-1) 70 | loss = tf.reduce_mean(loss) 71 | return loss 72 | 73 | def dice_dsc_loss(logits,labels,text_length_list,seq_mask,slot_label_num,smoothing_lambda=1.0): 74 | """ 75 | dice loss dsc 76 | :param logits: [batch_size,time_step,num_class] 77 | :param labels: [batch_size,time_step] 78 | :param seq_length:[batch_size] 79 | :return: 80 | """ 81 | 82 | predict_prob = tf.nn.softmax(logits,axis=-1,name="predict_prob") 83 | label_one_hot = tf.one_hot(labels, depth=slot_label_num, axis=-1) 84 | # seq_mask = tf.sequence_mask(seq_mask) 85 | # seq_mask = tf.cast(seq_mask,dtype=tf.float32) 86 | # batch_size_tensor = tf.range(0,tf.shape(logits)[0]) 87 | # seq_max_len_tensor = tf.range(0,tf.shape(logits)[1]) 88 | # # batch_size_tensor = tf.expand_dims(batch_size_tensor,1) 89 | # seq_max_len_tensor = tf.expand_dims(seq_max_len_tensor,axis=0) 90 | # seq_max_len_tensor = tf.tile(seq_max_len_tensor,[tf.shape(logits)[0],1]) 91 | # seq_max_len_tensor = tf.expand_dims(seq_max_len_tensor,axis=-1) 92 | # batch_size_tensor = tf.expand_dims(batch_size_tensor, 1) 93 | # batch_size_tensor = tf.tile(batch_size_tensor, [1, tf.shape(logits)[1]]) 94 | # batch_size_tensor = tf.expand_dims(batch_size_tensor, -1) 95 | # # batch_zeros_result = tf.zeros((tf.shape(logits)[0],tf.shape(logits)[1],3), dtype=tf.int32) 96 | # labels = tf.expand_dims(labels,axis=-1) 97 | # gather_idx = tf.concat([batch_size_tensor,seq_max_len_tensor,labels],axis=-1) 98 | # gather_result = tf.gather_nd(predict_prob,gather_idx) 99 | # # gather_result = gather_result 100 | # neg_prob = 1. - gather_result 101 | # neg_prob = neg_prob 102 | # # gather_result = gather_result 103 | # cost = 1. - neg_prob*gather_result/(neg_prob*gather_result+1.) 104 | # cost = cost * seq_mask 105 | # cost = tf.reduce_sum(cost,axis=-1) 106 | # cost = tf.reduce_mean(cost) 107 | # return cost 108 | # neg_prob = 1.- predict_prob 109 | nominator = 2*predict_prob*label_one_hot+smoothing_lambda 110 | denomiator = predict_prob*predict_prob+label_one_hot*label_one_hot+smoothing_lambda 111 | result = nominator/denomiator 112 | result = 1. 
- result 113 | result = tf.reduce_sum(result,axis=-1) 114 | result = result * seq_mask 115 | result = tf.reduce_sum(result,axis=-1,keep_dims=True) 116 | result = result/tf.cast(text_length_list,tf.float32) 117 | result = tf.reduce_mean(result) 118 | return result 119 | # cost = cal_binary_dsc_loss(predict_prob[:, :, 0],label_one_hot[:, :, 0],seq_mask) 120 | # for i in range(1,slot_label_num): 121 | # cost += cal_binary_dsc_loss(predict_prob[:, :, i],label_one_hot[:, :, i],seq_mask) 122 | # # print(denominator) 123 | # cost = cost/float(slot_label_num) 124 | # cost = tf.reduce_mean(cost) 125 | # return cost 126 | 127 | def vanilla_dsc_loss(logits,labels,seq_mask,num_labels,smoothing_lambda=1.0,one_hot=True): 128 | if one_hot: 129 | labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) 130 | else: 131 | labels = tf.expand_dims(labels,axis=-1) 132 | 133 | predict_prob = tf.nn.softmax(logits, axis=-1, name="predict_prob") 134 | pos_prob = predict_prob[:, :, 1] 135 | neg_prob = predict_prob[:,:,0] 136 | pos_label = labels[:, :, 1] 137 | neg_label = labels[:,:,0] 138 | denominator = 2 * pos_prob * pos_label + neg_prob * pos_label + pos_prob * neg_label + smoothing_lambda 139 | nominator = 2 * pos_prob * pos_label + smoothing_lambda 140 | loss = 1. - nominator / denominator 141 | loss = loss * tf.cast(seq_mask,tf.float32) 142 | loss = tf.reduce_sum(loss,axis=-1) 143 | loss = tf.reduce_mean(loss,axis=0) 144 | return loss 145 | 146 | def dl_dsc_loss(logits,labels,text_length_list,seq_mask,slot_label_num,smoothing_lambda=1.0,gamma=2.0): 147 | predict_prob = tf.nn.softmax(logits, axis=-1, name="predict_prob") 148 | label_one_hot = tf.one_hot(labels, depth=slot_label_num, axis=-1) 149 | # neg_prob = 1.- predict_prob 150 | # neg_prob = tf.pow(neg_prob,gamma) 151 | pos_prob = predict_prob[:,:,1] 152 | pos_prob_squre = tf.pow(pos_prob,2) 153 | pos_label = label_one_hot[:,:,1] 154 | pos_label_squre = tf.pow(pos_label,2) 155 | nominator = 2*pos_prob_squre*pos_label_squre+smoothing_lambda 156 | denominator = pos_label_squre+pos_label_squre+smoothing_lambda 157 | result = nominator/denominator 158 | result = 1.-result 159 | result = result * tf.cast(seq_mask,tf.float32) 160 | result = tf.reduce_sum(result, axis=-1, keep_dims=True) 161 | result = result / tf.cast(text_length_list, tf.float32) 162 | result = tf.reduce_mean(result) 163 | return result 164 | 165 | def ce_loss(logits,labels,mask,num_labels,one_hot=True,imbalanced_ratio=2): 166 | if one_hot: 167 | labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) 168 | else: 169 | labels = tf.expand_dims(labels,axis=-1) 170 | probs = tf.nn.softmax(logits,axis=-1) 171 | pos_probs = probs[:,:,1] 172 | pos_probs = tf.pow(pos_probs,imbalanced_ratio) 173 | pos_probs = tf.expand_dims(pos_probs, axis=-1) 174 | neg_probs = 1. 
- pos_probs 175 | probs = tf.concat([neg_probs,pos_probs],axis=-1) 176 | print(probs) 177 | log_probs = tf.log(probs+1e-7) 178 | per_example_loss = -tf.reduce_sum(tf.cast(labels,tf.float32) * log_probs, axis=-1) 179 | per_example_loss = per_example_loss * tf.cast(mask, tf.float32) 180 | loss = tf.reduce_sum(per_example_loss, axis=-1) 181 | loss = tf.reduce_mean(loss) 182 | return loss 183 | 184 | def focal_loss(logits,labels,mask,num_labels,one_hot=True,lambda_param=1.5): 185 | probs = tf.nn.softmax(logits,axis=-1) 186 | pos_probs = probs[:,:,1] 187 | prob_label_pos = tf.where(tf.equal(labels,1),pos_probs,tf.ones_like(pos_probs)) 188 | prob_label_neg = tf.where(tf.equal(labels,0),pos_probs,tf.zeros_like(pos_probs)) 189 | loss = tf.pow(1. - prob_label_pos,lambda_param)*tf.log(prob_label_pos + 1e-7) + \ 190 | tf.pow(prob_label_neg,lambda_param)*tf.log(1. - prob_label_neg + 1e-7) 191 | loss = -loss * tf.cast(mask,tf.float32) 192 | loss = tf.reduce_sum(loss,axis=-1,keepdims=True) 193 | # loss = loss/tf.cast(tf.reduce_sum(mask,axis=-1),tf.float32) 194 | loss = tf.reduce_mean(loss) 195 | return loss 196 | 197 | def span_loss(logits,labels,mask): 198 | probs = tf.nn.softmax(logits,axis=1) 199 | arg_max_label = tf.cast(tf.where(probs > 0.5,tf.ones_like(labels),tf.zeros_like(labels)),tf.int32) 200 | arg_max_label *= mask 201 | 202 | 203 | def test(seq_length,all_length): 204 | mask = tf.cast(tf.sequence_mask(seq_length),tf.int32) 205 | mask = mask * -1 206 | left = tf.zeros((4,3),dtype=tf.int32) 207 | all = tf.concat((mask,left),axis=-1) 208 | all_mask = tf.cast(tf.sequence_mask(all_length),tf.int32) 209 | all = all + all_mask 210 | return all 211 | def test2(): 212 | const_var = tf.constant([[0,0,0,1,0,1], [0,0, 0, 1,0,0], [0,0,1,0,0,0], [0, 1, 0,1,1,0]]) 213 | a = tf.expand_dims(const_var,axis=-1) 214 | b = tf.where(a) 215 | return b 216 | 217 | if __name__ == "__main__": 218 | # const_var = tf.constant([[0,0,0,1,0,1], [0,0, 0, 1,0,0], [0,0,1,0,0,0], [0, 1, 0,1,1,0]]) 219 | # logits_tensor = tf.Variable([[[0.2,0.4],[0.1,0.6],[0.88,0.4],[0.2,0.4],[0.1,0.6],[0.88,0.4]],[[0.2,0.4],[0.4,0.6],[0.4,23],[0.2,0.4],[0.4,0.6],[0.4,23]],[[0.1,0.2],[0.1,0.4],[0.88,0.4],[0.1,0.2],[0.1,0.4],[0.88,0.4]],[[0.1,0.4],[0.4,0.6],[0.88,23],[0.1,0.4],[0.4,0.6],[0.88,23]]]) 220 | # seq_length_s = tf.Variable([3,1,2,1]) 221 | # all_length_s = tf.Variable([6,4,3,5]) 222 | # mask_all = test(seq_length_s,all_length_s) 223 | # loss,labels = ce_loss(logits_tensor,const_var,mask_all,2) 224 | 225 | result = test2() 226 | 227 | 228 | # cost,b = focal_dsc_loss(logits_tensor,const_var,seq_length_s,"",3) 229 | # print(selected_probs) 230 | # print(idx) 231 | # gathered_prob = tf.gather_nd(predict_prob) 232 | with tf.Session() as sess: 233 | sess.run(tf.global_variables_initializer()) 234 | # print(sess.run(orig_prob)) 235 | # print(sess.run(idx)) 236 | # print(sess.run(selected_probs)) 237 | print(sess.run(result)) 238 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow-gpu==1.12.0 2 | bert4keras 3 | numpy 4 | gensim 5 | -------------------------------------------------------------------------------- /run_event.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # ERROR 4 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0" 5 | import tensorflow as tf 6 | from argparse import ArgumentParser 7 | from 
train_helper import run_event_role_mrc, run_event_classification 8 | from train_helper import run_event_binclassification, run_event_verify_role_mrc 9 | import numpy as np 10 | 11 | np.set_printoptions(threshold=np.inf) 12 | tf.logging.set_verbosity(tf.logging.INFO) 13 | 14 | 15 | def main(): 16 | parser = ArgumentParser() 17 | parser.add_argument("--model_type", default="role", type=str) 18 | parser.add_argument("--dropout_prob", default=0.2, type=float) 19 | parser.add_argument("--rnn_units", default=256, type=int) 20 | parser.add_argument("--epochs", default=15, type=int) 21 | # bert lr 22 | parser.add_argument("--lr", default=1e-5, type=float) 23 | # parser.add_argument("--lr", default=0.001, type=float) 24 | parser.add_argument("--clip_norm", default=5.0, type=float) 25 | parser.add_argument("--train_batch_size", default=16, type=int) 26 | parser.add_argument("--valid_batch_size", default=32, type=int) 27 | parser.add_argument("--shuffle_buffer", default=128, type=int) 28 | parser.add_argument("--do_train", action='store_true', default=True) 29 | parser.add_argument("--do_test", action='store_true', default=True) 30 | parser.add_argument("--gen_new_data", action='store_true', default=False) 31 | parser.add_argument("--tolerant_steps", default=200, type=int) 32 | parser.add_argument("--run_hook_steps", default=100, type=int) 33 | parser.add_argument("--num_layers", default=3, type=int) 34 | parser.add_argument("--hidden_units", default=128, type=int) 35 | parser.add_argument("--print_log_steps", default=50, type=int) 36 | parser.add_argument("--decay_epoch", default=12, type=int) 37 | parser.add_argument("--pre_buffer_size", default=1, type=int) 38 | parser.add_argument("--bert_used", default=False, action='store_true') 39 | parser.add_argument("--gpu_nums", default=1, type=int) 40 | parser.add_argument("--model_checkpoint_dir", type=str, default="role_bert_model_dir") 41 | parser.add_argument("--model_pb_dir", type=str, default="role_bert_model_pb") 42 | parser.add_argument("--fold_index", type=int) 43 | 44 | args = parser.parse_args() 45 | if args.model_type == "role": 46 | run_event_role_mrc(args) 47 | elif args.model_type == "classification": 48 | run_event_classification(args) 49 | elif args.model_type == "binary": 50 | run_event_binclassification(args) 51 | elif args.model_type == "avmrc": 52 | run_event_verify_role_mrc(args) 53 | # if args.bert_used: 54 | # if args.model_type == "bert_mrc": 55 | # if args.theseus_compressed: 56 | # print(args.model_type) 57 | # run_bert_mrc_theseus(args) 58 | # else: 59 | # run_bert_mrc(args) 60 | # else: 61 | # run_bert(args) 62 | # else: 63 | # if args.model_type == "lstm_crf" or args.model_type == "lstm_only": 64 | # run_train(args) 65 | # elif args.model_type=="lstm_cnn_crf": 66 | # run_train_cnn(args) 67 | # else: 68 | # run_lan(args) 69 | # run_event_trigger_bert(args) 70 | 71 | 72 | if __name__ == '__main__': 73 | main() 74 | -------------------------------------------------------------------------------- /run_event_classification.sh: -------------------------------------------------------------------------------- 1 | python run_event.py --model_type classification --fold_index 0 --dropout_prob 0.2 --epochs 5 --lr 2e-6 --clip_norm 5.0 --train_batch_size 2 --valid_batch_size 4 --shuffle_buffer 128 --do_train --do_test --tolerant_steps 500 --run_hook_steps 50 --print_log_steps 50 --decay_epoch 10 --pre_buffer_size 16 --bert_used --gpu_nums 1 --model_checkpoint_dir type_class_bert_model_dir --model_pb_dir type_class_bert_model_pb 
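# NOTE: run_event.py declares --do_train and --do_test with action='store_true' and
# default=True, so training and prediction both run even if these two flags are
# omitted from the command above; passing them here is harmless but redundant.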
-------------------------------------------------------------------------------- /run_event_role.sh: -------------------------------------------------------------------------------- 1 | python run_event.py --fold_index 0 --model_type role --dropout_prob 0.2 --epochs 6 --lr 9e-6 --clip_norm 5.0 --train_batch_size 4 --valid_batch_size 8 --shuffle_buffer 128 --do_train --do_test --tolerant_steps 500 --run_hook_steps 50 --print_log_steps 500 --decay_epoch 10 --pre_buffer_size 16 --bert_used --gpu_nums 1 --model_checkpoint_dir role_bert_model_dir --model_pb_dir role_bert_model_pb 2 | -------------------------------------------------------------------------------- /run_retro_eav.sh: -------------------------------------------------------------------------------- 1 | python run_event.py --fold_index 0 --model_type binary --dropout_prob 0.2 --epochs 5 --lr 6e-6 --clip_norm 5.0 --train_batch_size 4 --valid_batch_size 8 --shuffle_buffer 128 --do_train --do_test --tolerant_steps 500 --run_hook_steps 50 --print_log_steps 500 --decay_epoch 10 --pre_buffer_size 16 --bert_used --gpu_nums 1 --model_checkpoint_dir role_verify_cls_bert_model_dir --model_pb_dir role_verify_cls_bert_model_pb 2 | -------------------------------------------------------------------------------- /run_retro_rolemrc.sh: -------------------------------------------------------------------------------- 1 | python run_event.py --fold_index 0 --model_type avmrc --dropout_prob 0.2 --epochs 8 --lr 8e-6 --clip_norm 5.0 --train_batch_size 4 --valid_batch_size 8 --shuffle_buffer 128 --do_train --do_test --tolerant_steps 500 --run_hook_steps 50 --print_log_steps 500 --decay_epoch 10 --pre_buffer_size 16 --bert_used --gpu_nums 1 --model_checkpoint_dir role_verify_avmrc_bert_model_dir --model_pb_dir role_verify_avmrc_bert_model_pb 2 | --------------------------------------------------------------------------------
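A note on the loss used by the verifier above: `VerfiyMRC` in `models/event_verify_av.py` computes its start/end span losses with `focal_loss` from `models/utils.py` (passing `1.8` as the focusing exponent) and combines them as `(1.5 * start_loss + end_loss + has_answer_loss) / 3.0`. The snippet below is only a minimal usage sketch on invented toy tensors, not part of the project: it assumes TensorFlow 1.12 per requirements.txt, and since `models/utils.py` also imports `tensorflow_probability` and `torch` at module level (neither is listed in requirements.txt), both packages must be installed, or those imports removed, before the import succeeds.

```python
# Illustrative check of models/utils.focal_loss on toy data (values are made up).
import tensorflow as tf
from models.utils import focal_loss  # module-level imports also pull in tensorflow_probability / torch

# Two sequences of four tokens with binary per-token labels (e.g. span-start positions).
logits = tf.constant([[[2.0, -1.0], [0.5, 0.5], [-1.0, 2.0], [0.0, 0.0]],
                      [[1.0,  0.0], [0.0, 1.0], [2.0, -2.0], [0.0, 0.0]]])
labels = tf.constant([[0, 0, 1, 0],
                      [0, 1, 0, 0]], dtype=tf.int32)
mask = tf.constant([[1, 1, 1, 0],   # last position is padding and is masked out
                    [1, 1, 1, 0]], dtype=tf.int32)

# num_labels / one_hot are accepted by the signature but unused inside focal_loss.
loss = focal_loss(logits, labels, mask, num_labels=2, one_hot=True, lambda_param=1.8)

with tf.Session() as sess:
    print(sess.run(loss))  # scalar loss; confidently correct tokens contribute little
```

The `1.8` exponent mirrors the value hard-coded in `VerfiyMRC`; the default in `models/utils.py` is `1.5`.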