├── a1.png
├── a2.png
├── .vscode
│   └── settings.json
├── .dockerignore
├── code
│   ├── simple_run.sh
│   ├── utils.py
│   ├── optimization.py
│   ├── tokenization.py
│   ├── run_biaffine_relation.py
│   └── modeling.py
├── .gitignore
└── README.md

/a1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xueyouluo/biaffine-bert-relation-extract/HEAD/a1.png
--------------------------------------------------------------------------------
/a2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xueyouluo/biaffine-bert-relation-extract/HEAD/a2.png
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 |   "python.pythonPath": "/home/xueyou/.conda/envs/jason_py3/bin/python"
3 | }
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | .git/
2 | code/__pycache__
3 | __pycache__/
4 | user_data/models/
5 | user_data/pretrain_tfrecords/
6 | user_data/texts/
7 | user_data/tcdata/
8 | user_data/emb/
9 | user_data/chinese_roberta_wwm_ext_L-12_H-768_A-12/
--------------------------------------------------------------------------------
/code/simple_run.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | export BERT_DIR=/nfs/users/xueyou/data/bert_pretrain/electra_180g_base
4 | export CONFIG_FILE=${BERT_DIR}/base_discriminator_config.json
5 | export INIT_CHECKPOINT=${BERT_DIR}/electra_180g_base.ckpt
6 | export DATA_DIR=/data/xueyou/data/corpus/task_data/LIC2019
7 | export SEED=20190525
8 | export OUTPUT_DIR=${DATA_DIR}/baseline
9 | export SPATIAL_DROPOUT=0.
10 | export EMBEDDING_DROPOUT=0.
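# NOTE: with the flags below the script only runs evaluation (do_train=false, do_eval=true).
# Set do_train=true to reproduce the fine-tuning run described in the README
# (ELECTRA-base weights, 5 epochs, head_lr_ratio=20, biaffine_size=768).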
11 | 12 | python run_biaffine_relation.py \ 13 | --vocab_file=vocab.txt \ 14 | --bert_config_file=${CONFIG_FILE} \ 15 | --init_checkpoint=${INIT_CHECKPOINT} \ 16 | --do_lower_case=True \ 17 | --max_seq_length=128 \ 18 | --train_batch_size=32 \ 19 | --learning_rate=4e-5 \ 20 | --num_train_epochs=5.0 \ 21 | --save_checkpoints_steps=1000 \ 22 | --do_train=false \ 23 | --do_eval=true \ 24 | --use_fgm=false \ 25 | --fgm_epsilon=0.8 \ 26 | --fgm_loss_ratio=1.0 \ 27 | --spatial_dropout=${SPATIAL_DROPOUT} \ 28 | --embedding_dropout=${EMBEDDING_DROPOUT} \ 29 | --head_lr_ratio=20.0 \ 30 | --biaffine_size=768 \ 31 | --electra=true \ 32 | --amp=true \ 33 | --seed=${SEED} \ 34 | --data_dir=${DATA_DIR} \ 35 | --output_dir=${OUTPUT_DIR} -------------------------------------------------------------------------------- /code/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import re 4 | import random 5 | import json 6 | import glob 7 | import codecs 8 | import os 9 | from tqdm import tqdm 10 | from collections import defaultdict 11 | 12 | np.random.seed(20190525) 13 | random.seed(20190525) 14 | 15 | def get_biaffine_predicate(pred_text,scores,label_list,predicate_labels,threshold=0): 16 | l = len(pred_text) 17 | size = l * (l+1) // 2 18 | def get_position(n,k): 19 | def prefix_sum(i): 20 | return (n + n - i) * (i + 1) // 2 21 | 22 | left,right=0,n 23 | while left < right: 24 | mid = (left + right) // 2 25 | if prefix_sum(mid) < k: 26 | left = mid + 1 27 | else: 28 | right = mid 29 | 30 | s = left 31 | e = k - s * n + s * (s + 1) // 2 32 | return (s,e) 33 | 34 | tags = [] 35 | entities = defaultdict(set) 36 | for pos, lpos in np.argwhere(scores > threshold): 37 | lb = label_list[lpos] 38 | s,e = get_position(l,pos) 39 | if 'EH2ET' in lb: 40 | if s <= e: 41 | entities[s].add((s,e,lb[:-5])) 42 | else: 43 | tags.append((s,e,lb)) 44 | 45 | results = [] 46 | for p in predicate_labels: 47 | Hs = [] 48 | Ho = [] 49 | T = set() 50 | for s,e,t in tags: 51 | tp = t[-5:] 52 | tt = t[:-5] 53 | if tt != p: 54 | continue 55 | 56 | if tp == 'SH2OH': 57 | Hs.extend(entities.get(s,[])) 58 | Ho.extend(entities.get(e,[])) 59 | if tp == 'OH2SH': 60 | Ho.extend(entities.get(s,[])) 61 | Hs.extend(entities.get(e,[])) 62 | if tp == 'ST2OT': 63 | T.add((s,e)) 64 | if tp == 'OT2ST': 65 | T.add((e,s)) 66 | 67 | for s in set(Hs): 68 | for o in set(Ho): 69 | if (s[1],o[1]) in T: 70 | results.append((s,p,o)) 71 | 72 | return results 73 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
131 | tcdata/
132 | user_data/*
133 | !user_data/extra_data
134 | !user_data/track3
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A Biaffine-Based Relation Extraction Model
2 | 
3 | I have spent the last week or two digging into relation extraction; this is a short write-up.
4 | 
5 | The main approaches fall into two camps, pipeline and joint:
6 | 
7 | - Pipeline methods first extract the entities and then classify the relation between each pair of entities.
8 | - Joint methods extract both at once, which avoids the error accumulation of the pipeline setup and makes fuller use of the interaction between entities and relations.
9 | 
10 | I have built pipeline extractors before, so here I mainly look at joint extraction.
11 | 
12 | ## Results
13 | 
14 | Compared with Su Jianlin's GPLinker, the dev-set results on the LIC2019 task are:
15 | 
16 | | Model | F1 |
17 | | ------------------ | ------ |
18 | | CasRel | 0.8220 |
19 | | GPLinker(Standard) | 0.8272 |
20 | | Biaffine | 0.8297 |
21 | 
22 | Detailed metrics for the biaffine model:
23 | 
24 | - precision 0.8124273240449063
25 | - recall 0.8477354452999187
26 | - f1 0.829705920456018
27 | 
28 | ## Method
29 | 
30 | ### Modeling
31 | 
32 | If you have seen biaffine NER, you know that the biaffine layer also models token pairs. For NER we end up with a tensor of shape [B,L,L,N] which, for every text in the batch, holds the label information of the L*L entity head (EH)-entity tail (ET) pairs.
33 | 
34 | To recover relations we additionally need the links between subject head (SH)-object head (OH) and subject tail (ST)-object tail (OT) in each SPO triple. Since the object O may appear before the subject S, OH-SH and OT-ST links have to be covered as well. With the SH-OH, OH-SH, ST-OT, OT-ST and EH-ET information we can then decide which relations a sentence contains; see [TPLinker](https://github.com/131250208/TPlinker-joint-extraction) for the details of this formulation.
35 | 
36 | I first tried the original TPLinker idea of modeling the entities and the SPO heads/tails separately, with two designs:
37 | 
38 | 1. The first is essentially the NER setup with an extra biaffine classifier for the predicates.
39 | 
40 | ![Architecture 1](a1.png)
41 | 
42 | 2. The second variant changes the biaffine output to [B,L,L,N], where N is a hidden size, and then feeds it through separate linear layers to produce the NER and predicate label predictions.
43 | 
44 | ![Architecture 2](a2.png)
45 | 
46 | Both variants work reasonably well for entity recognition but very poorly for predicate recognition; they learn almost nothing there. Analyzing the data shows that the predicate labels are extremely sparse: many labels barely occur in most of the data, so the parameters of those predicates hardly receive any useful signal.
47 | 
48 | After a few unsuccessful attempts along this line, I switched to the tplinker_plus formulation.
49 | 
50 | This formulation is quite interesting: it turns the multi-class problem into a multi-label classification problem. Concretely, for every token pair we decide whether it is:
51 | 
52 | - the EH-ET pair of some entity type
53 | - the SH-OH/OH-SH pair of some predicate P
54 | - the ST-OT/OT-ST pair of some predicate P
55 | 
56 | This is a multi-label classification over N + M\*4 classes, where N is the number of entity types and M the number of predicates. For relation extraction we do not care about the concrete entity type, only whether a span is an entity at all, which shrinks the label set to 1+M\*4 classes.
57 | 
58 | ### Loss function
59 | 
60 | Multi-label classification is usually trained with binary cross-entropy. Here I instead use the loss from Su Jianlin's [《将“softmax+交叉熵”推广到多标签分类问题》](https://kexue.fm/archives/7359), which copes with the class imbalance. It takes the form below; please refer to Su's write-up for the derivation.
61 | 
62 | $$\log\Bigl(1+\sum_{i\in\Omega_{\mathrm{neg}}}e^{s_i}\Bigr)+\log\Bigl(1+\sum_{j\in\Omega_{\mathrm{pos}}}e^{-s_j}\Bigr)$$
63 | 
64 | Here $s_i$ is the predicted score of label $i$ for a token pair, and $\Omega_{\mathrm{pos}}$/$\Omega_{\mathrm{neg}}$ are the sets of positive/negative labels of that pair.
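65 | 
66 | To make the formula concrete, here is a minimal NumPy sketch of this loss (an illustrative helper, not part of the repo's code; the training code implements the same thing in TensorFlow as `multilabel_categorical_crossentropy` in `run_biaffine_relation.py`):
67 | 
68 | ```python
69 | import numpy as np
70 | 
71 | def multilabel_ce(scores, labels):
72 |     """scores/labels: [num_pairs, num_labels]; labels contain 0/1."""
73 |     scores = np.asarray(scores, dtype=float)
74 |     labels = np.asarray(labels, dtype=float)
75 |     s = (1 - 2 * labels) * scores             # flip the sign of the positive-label scores
76 |     s_neg = np.where(labels == 1, -1e12, s)   # mask positives out of the negative term
77 |     s_pos = np.where(labels == 0, -1e12, s)   # mask negatives out of the positive term
78 |     zeros = np.zeros_like(s[..., :1])         # appended 0 gives the "1 +" inside each log
79 |     neg = np.logaddexp.reduce(np.concatenate([s_neg, zeros], axis=-1), axis=-1)
80 |     pos = np.logaddexp.reduce(np.concatenate([s_pos, zeros], axis=-1), axis=-1)
81 |     return neg + pos                          # one loss value per token pair
82 | ```
83 | 
84 | A label is then predicted as positive at inference time whenever its score is above 0, which is the threshold the decoding code uses.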
85 | 
86 | ### Decoding
87 | 
88 | With this formulation, the changes needed on top of the original NER-biaffine code are small: adjust the labels and swap in the new loss, and that is about it.
89 | 
90 | Decoding follows the same procedure as tplinker:
91 | 
92 | - Collect the positions of all entities: entities = set((s,e)).
93 | - For each predicate type P, first use the SH2OH or OH2SH links together with entities to find the entities whose start positions qualify as subject and object, and then check whether the end positions of those subject and object entities are linked by ST2OT or OT2ST.
94 | - Keep only the SPO pairs that pass both checks.
95 | 
96 | ## Experiments
97 | 
98 | I use the Chinese [electra-base](https://github.com/ymcui/Chinese-ELECTRA) model open-sourced by HIT as the pretrained weights, set the head learning rate ratio to 20 (i.e. the learning rate of the non-BERT parameters is multiplied by this factor), enlarge the biaffine size from 150 to 768 (it has to carry more information now), and fine-tune for 5 epochs.
99 | 
100 | The results above were obtained without any hyper-parameter tuning.
101 | 
102 | ## Summary
103 | 
104 | The biaffine structure itself is fairly simple: just two dense layers and one weight matrix W, with no elaborate tricks stacked on top of BERT. That is much the same spirit as the BERT paper, where downstream tasks ideally add only a simple linear transformation. It is another sign of how strong these pretrained models are and how much potential is still left to exploit.
105 | 
106 | The drawback of token-pair models is the space complexity of their inputs, $O(BL^2N)$, so they are not well suited to very long texts, and relations that cross document structure are even harder. The labels are also extremely sparse: even when only the upper triangle is used there are still plenty of zero labels, and how to deal with such sparse labels is worth further study.
107 | 
108 | 
109 | 
--------------------------------------------------------------------------------
/code/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Functions and classes related to optimization (weight updates)."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | import re
22 | import tensorflow as tf
23 | 
24 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, hvd=None, amp=False, accumulation_step=1,freeze_bert=False, head_lr_ratio=1.0):
25 | """Creates an optimizer training op.
26 | 27 | Args: 28 | loss: training loss 29 | init_lr: initial learning rate 30 | num_train_steps: total training steps 31 | num_warmup_steps: warmup steps 32 | hvd: whether use hvd for distribute training 33 | amp: whether use auto-mix-precision to speed up training 34 | accumulation_step: gradient accumulation steps 35 | freeze_bert: whether to freeze bert variables 36 | head_lr_ratio: bert and head should have different learning rate 37 | """ 38 | global_step = tf.train.get_or_create_global_step() 39 | 40 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 41 | 42 | # Implements linear decay of the learning rate. 43 | learning_rate = tf.train.polynomial_decay( 44 | learning_rate, 45 | global_step, 46 | num_train_steps, 47 | end_learning_rate=0.0,#if not use_swa else init_lr/2, 48 | power=1.0, 49 | cycle=False) 50 | 51 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 52 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 53 | if num_warmup_steps: 54 | global_steps_int = tf.cast(global_step, tf.int32) 55 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 56 | 57 | global_steps_float = tf.cast(global_steps_int, tf.float32) 58 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 59 | 60 | warmup_percent_done = global_steps_float / warmup_steps_float 61 | warmup_learning_rate = init_lr * warmup_percent_done 62 | 63 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 64 | learning_rate = ( 65 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 66 | 67 | # It is recommended that you use this optimizer for fine tuning, since this 68 | # is how the model was trained (note that the Adam m/v variables are NOT 69 | # loaded from init_checkpoint.) 
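# The schedule above is linear warmup followed by linear decay: e.g. with init_lr=4e-5,
# num_train_steps=10000 and num_warmup_steps=1000, the learning rate rises linearly from 0
# to 4e-5 over the first 1000 steps and then decays linearly back to 0 at step 10000.
# The optimizer below additionally multiplies the learning rate of non-BERT ("head")
# variables by head_lr_ratio inside apply_gradients.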
70 | optimizer = AdamWeightDecayOptimizer( 71 | learning_rate=learning_rate, 72 | head_lr_ratio=head_lr_ratio, 73 | weight_decay_rate=0.01, 74 | beta_1=0.9, 75 | beta_2=0.999, 76 | epsilon=1e-6, 77 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 78 | 79 | if hvd is not None: 80 | from horovod.tensorflow.compression import Compression 81 | optimizer = hvd.DistributedOptimizer(optimizer, sparse_as_dense = True, compression=Compression.fp16 if amp else Compression.none) 82 | 83 | if amp: 84 | loss_scaler = tf.train.experimental.DynamicLossScale(initial_loss_scale=2**32, increment_period=1000, multiplier=2.0) 85 | optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer, loss_scaler) 86 | loss_scale_value = tf.identity(loss_scaler(), name="loss_scale") 87 | 88 | tvars = tf.trainable_variables() 89 | if freeze_bert: 90 | tvars = [var for var in tvars if 'bert' not in var.name] 91 | grads_and_vars = optimizer.compute_gradients(loss, tvars) 92 | 93 | if accumulation_step > 1: 94 | tf.logging.info('### Using Gradient Accumulation with {} ###'.format(accumulation_step)) 95 | 96 | local_step = tf.get_variable(name="local_step", shape=[], dtype=tf.int32, trainable=False, 97 | initializer=tf.zeros_initializer) 98 | batch_finite = tf.get_variable(name="batch_finite", shape=[], dtype=tf.bool, trainable=False, 99 | initializer=tf.ones_initializer) 100 | accum_vars = [tf.get_variable( 101 | name=tvar.name.split(":")[0] + "/accum", 102 | shape=tvar.shape.as_list(), 103 | dtype=tf.float32, 104 | trainable=False, 105 | initializer=tf.zeros_initializer()) for tvar in tf.trainable_variables()] 106 | 107 | reset_step = tf.cast(tf.math.equal(local_step % accumulation_step, 0), dtype=tf.bool) 108 | local_step = tf.cond(reset_step, lambda:local_step.assign(tf.ones_like(local_step)), lambda:local_step.assign_add(1)) 109 | 110 | grads_and_vars_and_accums = [(gv[0],gv[1],accum_vars[i]) for i, gv in enumerate(grads_and_vars) if gv[0] is not None] 111 | grads, tvars, accum_vars = list(zip(*grads_and_vars_and_accums)) 112 | 113 | all_are_finite = tf.reduce_all([tf.reduce_all(tf.is_finite(g)) for g in grads]) if amp else tf.constant(True, dtype=tf.bool) 114 | batch_finite = tf.cond(reset_step, 115 | lambda: batch_finite.assign(tf.math.logical_and(tf.constant(True, dtype=tf.bool), all_are_finite)), 116 | lambda: batch_finite.assign(tf.math.logical_and(batch_finite, all_are_finite))) 117 | 118 | # This is how the model was pre-trained. 119 | # ensure global norm is a finite number 120 | # to prevent clip_by_global_norm from having a hizzy fit. 
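# Accumulation mechanics: each micro-batch's clipped gradients are added into the accum_vars
# buffers; apply_gradients only runs every `accumulation_step` steps (when update_step is True),
# and global_step advances only when such an update is applied with all-finite gradients.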
121 | (clipped_grads, _) = tf.clip_by_global_norm( 122 | grads, clip_norm=1.0, 123 | use_norm=tf.cond( 124 | all_are_finite, 125 | lambda: tf.global_norm(grads), 126 | lambda: tf.constant(1.0))) 127 | 128 | accum_vars = tf.cond(reset_step, 129 | lambda: [accum_vars[i].assign(grad) for i, grad in enumerate(clipped_grads)], 130 | lambda: [accum_vars[i].assign_add(grad) for i, grad in enumerate(clipped_grads)]) 131 | 132 | def update(accum_vars): 133 | return optimizer.apply_gradients(list(zip(accum_vars, tvars))) 134 | 135 | update_step = tf.identity(tf.cast(tf.math.equal(local_step % accumulation_step, 0), dtype=tf.bool), name="update_step") 136 | update_op = tf.cond(update_step, 137 | lambda: update(accum_vars), lambda: tf.no_op()) 138 | 139 | new_global_step = tf.cond(tf.math.logical_and(update_step, 140 | tf.cast(hvd.allreduce(tf.cast(batch_finite, tf.int32)), tf.bool) if hvd is not None else batch_finite), 141 | lambda: global_step+1, 142 | lambda: global_step) 143 | new_global_step = tf.identity(new_global_step, name='step_update') 144 | train_op = tf.group(update_op, [global_step.assign(new_global_step)]) 145 | else: 146 | grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None] 147 | grads, tvars = list(zip(*grads_and_vars)) 148 | all_are_finite = tf.reduce_all( 149 | [tf.reduce_all(tf.is_finite(g)) for g in grads]) if amp else tf.constant(True, dtype=tf.bool) 150 | 151 | # This is how the model was pre-trained. 152 | # ensure global norm is a finite number 153 | # to prevent clip_by_global_norm from having a hizzy fit. 154 | (clipped_grads, _) = tf.clip_by_global_norm( 155 | grads, clip_norm=1.0, 156 | use_norm=tf.cond( 157 | all_are_finite, 158 | lambda: tf.global_norm(grads), 159 | lambda: tf.constant(1.0))) 160 | 161 | train_op = optimizer.apply_gradients( 162 | list(zip(clipped_grads, tvars))) 163 | 164 | new_global_step = tf.cond(all_are_finite, lambda: global_step + 1, lambda: global_step) 165 | new_global_step = tf.identity(new_global_step, name='step_update') 166 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 167 | return train_op, learning_rate 168 | 169 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 170 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 171 | 172 | def __init__(self, 173 | learning_rate, 174 | head_lr_ratio=1.0, 175 | weight_decay_rate=0.0, 176 | beta_1=0.9, 177 | beta_2=0.999, 178 | epsilon=1e-6, 179 | exclude_from_weight_decay=None, 180 | name="AdamWeightDecayOptimizer"): 181 | """Constructs a AdamWeightDecayOptimizer.""" 182 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 183 | 184 | self.learning_rate = learning_rate 185 | self.weight_decay_rate = weight_decay_rate 186 | self.beta_1 = beta_1 187 | self.beta_2 = beta_2 188 | self.epsilon = epsilon 189 | self.exclude_from_weight_decay = exclude_from_weight_decay 190 | self.head_lr_ratio = head_lr_ratio 191 | 192 | def _apply_gradients(self, grads_and_vars, learning_rate): 193 | """See base class.""" 194 | assignments = [] 195 | for (grad, param) in grads_and_vars: 196 | if grad is None or param is None: 197 | continue 198 | 199 | param_name = self._get_variable_name(param.name) 200 | 201 | m = tf.get_variable( 202 | name=param_name + "/adam_m", 203 | shape=param.shape.as_list(), 204 | dtype=tf.float32, 205 | trainable=False, 206 | initializer=tf.zeros_initializer()) 207 | v = tf.get_variable( 208 | name=param_name + "/adam_v", 209 | shape=param.shape.as_list(), 210 | dtype=tf.float32, 211 | trainable=False, 212 | 
initializer=tf.zeros_initializer()) 213 | 214 | # Standard Adam update. 215 | next_m = ( 216 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 217 | next_v = ( 218 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 219 | tf.square(grad))) 220 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 221 | 222 | # Just adding the square of the weights to the loss function is *not* 223 | # the correct way of using L2 regularization/weight decay with Adam, 224 | # since that will interact with the m and v parameters in strange ways. 225 | # 226 | # Instead we want ot decay the weights in a manner that doesn't interact 227 | # with the m/v parameters. This is equivalent to adding the square 228 | # of the weights to the loss with plain (non-momentum) SGD. 229 | if self.weight_decay_rate > 0: 230 | if self._do_use_weight_decay(param_name): 231 | update += self.weight_decay_rate * param 232 | 233 | update_with_lr = learning_rate * update 234 | next_param = param - update_with_lr 235 | 236 | assignments.extend( 237 | [param.assign(next_param), 238 | m.assign(next_m), 239 | v.assign(next_v)]) 240 | 241 | return assignments 242 | 243 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 244 | """See base class.""" 245 | if self.head_lr_ratio > 1.0: 246 | def is_backbone(n): 247 | return 'bert' in n 248 | assignments = [] 249 | backbone_gvs = [] 250 | head_gvs = [] 251 | for grad,var in grads_and_vars: 252 | if is_backbone(var.name): 253 | backbone_gvs.append((grad,var)) 254 | else: 255 | head_gvs.append((grad,var)) 256 | assignments += self._apply_gradients(backbone_gvs,self.learning_rate) 257 | assignments += self._apply_gradients(head_gvs,self.learning_rate * self.head_lr_ratio) 258 | else: 259 | assignments = self._apply_gradients(grads_and_vars, self.learning_rate) 260 | return tf.group(*assignments,name=name) 261 | 262 | def _do_use_weight_decay(self, param_name): 263 | """Whether to use L2 weight decay for `param_name`.""" 264 | if not self.weight_decay_rate: 265 | return False 266 | if self.exclude_from_weight_decay: 267 | for r in self.exclude_from_weight_decay: 268 | if re.search(r, param_name) is not None: 269 | return False 270 | return True 271 | 272 | def _get_variable_name(self, param_name): 273 | """Get the variable name from the tensor name.""" 274 | m = re.match("^(.*):\\d+$", param_name) 275 | if m is not None: 276 | param_name = m.group(1) 277 | return param_name 278 | -------------------------------------------------------------------------------- /code/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | 29 | def convert_to_unicode(text): 30 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 31 | if six.PY3: 32 | if isinstance(text, str): 33 | return text 34 | elif isinstance(text, bytes): 35 | return text.decode("utf-8", "ignore") 36 | else: 37 | raise ValueError("Unsupported string type: %s" % (type(text))) 38 | elif six.PY2: 39 | if isinstance(text, str): 40 | return text.decode("utf-8", "ignore") 41 | elif isinstance(text, unicode): 42 | return text 43 | else: 44 | raise ValueError("Unsupported string type: %s" % (type(text))) 45 | else: 46 | raise ValueError("Not running on Python2 or Python 3?") 47 | 48 | 49 | def printable_text(text): 50 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 51 | 52 | # These functions want `str` for both Python2 and Python3, but in one case 53 | # it's a Unicode string and in the other it's a byte string. 54 | if six.PY3: 55 | if isinstance(text, str): 56 | return text 57 | elif isinstance(text, bytes): 58 | return text.decode("utf-8", "ignore") 59 | else: 60 | raise ValueError("Unsupported string type: %s" % (type(text))) 61 | elif six.PY2: 62 | if isinstance(text, str): 63 | return text 64 | elif isinstance(text, unicode): 65 | return text.encode("utf-8") 66 | else: 67 | raise ValueError("Unsupported string type: %s" % (type(text))) 68 | else: 69 | raise ValueError("Not running on Python2 or Python 3?") 70 | 71 | 72 | def load_vocab(vocab_file): 73 | """Loads a vocabulary file into a dictionary.""" 74 | vocab = collections.OrderedDict() 75 | index = 0 76 | with tf.gfile.GFile(vocab_file, "r") as reader: 77 | while True: 78 | token = convert_to_unicode(reader.readline()) 79 | if not token: 80 | break 81 | token = token.strip() 82 | vocab[token] = index 83 | index += 1 84 | return vocab 85 | 86 | 87 | def convert_by_vocab(vocab, items): 88 | """Converts a sequence of [tokens|ids] using the vocab.""" 89 | output = [] 90 | for item in items: 91 | output.append(vocab[item]) 92 | return output 93 | 94 | 95 | def convert_tokens_to_ids(vocab, tokens): 96 | return convert_by_vocab(vocab, tokens) 97 | 98 | 99 | def convert_ids_to_tokens(inv_vocab, ids): 100 | return convert_by_vocab(inv_vocab, ids) 101 | 102 | 103 | def whitespace_tokenize(text): 104 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 105 | text = text.strip() 106 | if not text: 107 | return [] 108 | tokens = text.split() 109 | return tokens 110 | 111 | class SimpleTokenizer(object): 112 | def __init__(self, vocab_file, do_lower_case=True): 113 | self.vocab = load_vocab(vocab_file) 114 | self.do_lower_case = do_lower_case 115 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 116 | 117 | def tokenize(self, text): 118 | if self.do_lower_case: 119 | text = text.lower() 120 | tokens = [] 121 | for token in text.strip(): 122 | if token == ' ': 123 | tokens.append('[SPACE]') 124 | elif token in self.vocab: 125 | tokens.append(token) 126 | else: 127 | tokens.append('[UNK]') 128 | return tokens 129 | 130 | def convert_tokens_to_ids(self, tokens): 131 | return convert_by_vocab(self.vocab, tokens) 132 | 133 | def convert_ids_to_tokens(self, ids): 134 | return convert_by_vocab(self.inv_vocab, ids) 135 | 136 | 137 | class 
FullTokenizer(object): 138 | """Runs end-to-end tokenziation.""" 139 | 140 | def __init__(self, vocab_file, do_lower_case=True): 141 | self.vocab = load_vocab(vocab_file) 142 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 143 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 144 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 145 | 146 | def tokenize(self, text): 147 | split_tokens = [] 148 | for token in self.basic_tokenizer.tokenize(text): 149 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 150 | split_tokens.append(sub_token) 151 | 152 | return split_tokens 153 | 154 | def convert_tokens_to_ids(self, tokens): 155 | return convert_by_vocab(self.vocab, tokens) 156 | 157 | def convert_ids_to_tokens(self, ids): 158 | return convert_by_vocab(self.inv_vocab, ids) 159 | 160 | 161 | class BasicTokenizer(object): 162 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 163 | 164 | def __init__(self, do_lower_case=True): 165 | """Constructs a BasicTokenizer. 166 | 167 | Args: 168 | do_lower_case: Whether to lower case the input. 169 | """ 170 | self.do_lower_case = do_lower_case 171 | 172 | def tokenize(self, text): 173 | """Tokenizes a piece of text.""" 174 | text = convert_to_unicode(text) 175 | text = self._clean_text(text) 176 | 177 | # This was added on November 1st, 2018 for the multilingual and Chinese 178 | # models. This is also applied to the English models now, but it doesn't 179 | # matter since the English models were not trained on any Chinese data 180 | # and generally don't have any Chinese data in them (there are Chinese 181 | # characters in the vocabulary because Wikipedia does have some Chinese 182 | # words in the English Wikipedia.). 183 | text = self._tokenize_chinese_chars(text) 184 | 185 | orig_tokens = whitespace_tokenize(text) 186 | split_tokens = [] 187 | for token in orig_tokens: 188 | if self.do_lower_case: 189 | token = token.lower() 190 | token = self._run_strip_accents(token) 191 | split_tokens.extend(self._run_split_on_punc(token)) 192 | 193 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 194 | return output_tokens 195 | 196 | def _run_strip_accents(self, text): 197 | """Strips accents from a piece of text.""" 198 | text = unicodedata.normalize("NFD", text) 199 | output = [] 200 | for char in text: 201 | cat = unicodedata.category(char) 202 | if cat == "Mn": 203 | continue 204 | output.append(char) 205 | return "".join(output) 206 | 207 | def _run_split_on_punc(self, text): 208 | """Splits punctuation on a piece of text.""" 209 | chars = list(text) 210 | i = 0 211 | start_new_word = True 212 | output = [] 213 | while i < len(chars): 214 | char = chars[i] 215 | if _is_punctuation(char): 216 | output.append([char]) 217 | start_new_word = True 218 | else: 219 | if start_new_word: 220 | output.append([]) 221 | start_new_word = False 222 | output[-1].append(char) 223 | i += 1 224 | 225 | return ["".join(x) for x in output] 226 | 227 | def _tokenize_chinese_chars(self, text): 228 | """Adds whitespace around any CJK character.""" 229 | output = [] 230 | for char in text: 231 | cp = ord(char) 232 | if self._is_chinese_char(cp): 233 | output.append(" ") 234 | output.append(char) 235 | output.append(" ") 236 | else: 237 | output.append(char) 238 | return "".join(output) 239 | 240 | def _is_chinese_char(self, cp): 241 | """Checks whether CP is the codepoint of a CJK character.""" 242 | # This defines a "chinese character" as anything in the CJK Unicode block: 243 | # 
https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 244 | # 245 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 246 | # despite its name. The modern Korean Hangul alphabet is a different block, 247 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 248 | # space-separated words, so they are not treated specially and handled 249 | # like the all of the other languages. 250 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 251 | (cp >= 0x3400 and cp <= 0x4DBF) or # 252 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 253 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 254 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 255 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 256 | (cp >= 0xF900 and cp <= 0xFAFF) or # 257 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 258 | return True 259 | 260 | return False 261 | 262 | def _clean_text(self, text): 263 | """Performs invalid character removal and whitespace cleanup on text.""" 264 | output = [] 265 | for char in text: 266 | cp = ord(char) 267 | if cp == 0 or cp == 0xfffd or _is_control(char): 268 | continue 269 | if _is_whitespace(char): 270 | output.append(" ") 271 | else: 272 | output.append(char) 273 | return "".join(output) 274 | 275 | 276 | class WordpieceTokenizer(object): 277 | """Runs WordPiece tokenziation.""" 278 | 279 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 280 | self.vocab = vocab 281 | self.unk_token = unk_token 282 | self.max_input_chars_per_word = max_input_chars_per_word 283 | 284 | def tokenize(self, text): 285 | """Tokenizes a piece of text into its word pieces. 286 | 287 | This uses a greedy longest-match-first algorithm to perform tokenization 288 | using the given vocabulary. 289 | 290 | For example: 291 | input = "unaffable" 292 | output = ["un", "##aff", "##able"] 293 | 294 | Args: 295 | text: A single token or whitespace separated tokens. This should have 296 | already been passed through `BasicTokenizer. 297 | 298 | Returns: 299 | A list of wordpiece tokens. 300 | """ 301 | 302 | text = convert_to_unicode(text) 303 | 304 | output_tokens = [] 305 | for token in whitespace_tokenize(text): 306 | chars = list(token) 307 | if len(chars) > self.max_input_chars_per_word: 308 | output_tokens.append(self.unk_token) 309 | continue 310 | 311 | is_bad = False 312 | start = 0 313 | sub_tokens = [] 314 | while start < len(chars): 315 | end = len(chars) 316 | cur_substr = None 317 | while start < end: 318 | substr = "".join(chars[start:end]) 319 | if start > 0: 320 | substr = "##" + substr 321 | if substr in self.vocab: 322 | cur_substr = substr 323 | break 324 | end -= 1 325 | if cur_substr is None: 326 | is_bad = True 327 | break 328 | sub_tokens.append(cur_substr) 329 | start = end 330 | 331 | if is_bad: 332 | output_tokens.append(self.unk_token) 333 | else: 334 | output_tokens.extend(sub_tokens) 335 | return output_tokens 336 | 337 | 338 | def _is_whitespace(char): 339 | """Checks whether `chars` is a whitespace character.""" 340 | # \t, \n, and \r are technically contorl characters but we treat them 341 | # as whitespace since they are generally considered as such. 342 | if char == " " or char == "\t" or char == "\n" or char == "\r": 343 | return True 344 | cat = unicodedata.category(char) 345 | if cat == "Zs": 346 | return True 347 | return False 348 | 349 | 350 | def _is_control(char): 351 | """Checks whether `chars` is a control character.""" 352 | # These are technically control characters but we count them as whitespace 353 | # characters. 
354 | if char == "\t" or char == "\n" or char == "\r": 355 | return False 356 | cat = unicodedata.category(char) 357 | if cat.startswith("C"): 358 | return True 359 | return False 360 | 361 | 362 | def _is_punctuation(char): 363 | """Checks whether `chars` is a punctuation character.""" 364 | cp = ord(char) 365 | # We treat all non-letter/number ASCII as punctuation. 366 | # Characters such as "^", "$", and "`" are not in the Unicode 367 | # Punctuation class but we treat them as punctuation anyways, for 368 | # consistency. 369 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 370 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 371 | return True 372 | cat = unicodedata.category(char) 373 | if cat.startswith("P"): 374 | return True 375 | return False 376 | -------------------------------------------------------------------------------- /code/run_biaffine_relation.py: -------------------------------------------------------------------------------- 1 | #! usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | """ 4 | Copyright 2022 Xueyou.Luo 5 | BASED ON Google_BERT. 6 | """ 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import collections 10 | import json 11 | import os 12 | import pdb 13 | import pickle 14 | import random 15 | import sys 16 | import time 17 | from collections import defaultdict 18 | 19 | import numpy as np 20 | import tensorflow as tf 21 | from tqdm import tqdm 22 | 23 | import modeling 24 | import optimization 25 | import tokenization 26 | from utils import get_biaffine_predicate 27 | 28 | # 这里为了避免打印重复的日志信息 29 | tf.get_logger().propagate = False 30 | 31 | flags = tf.flags 32 | 33 | FLAGS = flags.FLAGS 34 | 35 | ## K-fold 36 | flags.DEFINE_integer("fold_id", 0, "which fold") 37 | flags.DEFINE_integer("fold_num", 1, "total fold number") 38 | 39 | flags.DEFINE_integer("seed", 20190525, "random seed") 40 | 41 | flags.DEFINE_string( 42 | "task_name", "spo", "The name of the task to train." 43 | ) 44 | 45 | flags.DEFINE_string( 46 | "data_dir", None, 47 | "The input datadir.", 48 | ) 49 | 50 | flags.DEFINE_string( 51 | "output_dir", None, 52 | "The output directory where the model checkpoints will be written." 53 | ) 54 | 55 | flags.DEFINE_string( 56 | "bert_config_file", None, 57 | "The config json file corresponding to the pre-trained BERT model." 58 | ) 59 | 60 | flags.DEFINE_string( 61 | "vocab_file", None, 62 | "The vocabulary file that the BERT model was trained on.") 63 | 64 | flags.DEFINE_string( 65 | "init_checkpoint", None, 66 | "Initial checkpoint (usually from a pre-trained BERT model)." 67 | ) 68 | 69 | flags.DEFINE_bool( 70 | "do_lower_case", True, 71 | "Whether to lower case the input text." 72 | ) 73 | 74 | flags.DEFINE_integer( 75 | "max_seq_length", 128, 76 | "The maximum total input sequence length after WordPiece tokenization." 77 | ) 78 | 79 | flags.DEFINE_bool( 80 | "do_train", False, 81 | "Whether to run training." 82 | ) 83 | 84 | flags.DEFINE_bool( 85 | "do_eval", False, 86 | "Whether to run eval on the dev set.") 87 | 88 | flags.DEFINE_bool( 89 | "do_train_and_eval", False, 90 | "Whether to run training and evaluation." 
91 | ) 92 | flags.DEFINE_bool( 93 | "do_predict", False, 94 | "Whether to run the model in inference mode on the test set.") 95 | 96 | flags.DEFINE_integer( 97 | "train_batch_size", 64, 98 | "Total batch size for training.") 99 | 100 | flags.DEFINE_integer( 101 | "eval_batch_size", 32, 102 | "Total batch size for eval.") 103 | 104 | flags.DEFINE_integer( 105 | "predict_batch_size", 32, 106 | "Total batch size for predict.") 107 | 108 | flags.DEFINE_float( 109 | "learning_rate", 5e-6, 110 | "The initial learning rate for Adam.") 111 | 112 | flags.DEFINE_float( 113 | "num_train_epochs", 10.0, 114 | "Total number of training epochs to perform.") 115 | 116 | flags.DEFINE_float( 117 | "warmup_proportion", 0.1, 118 | "Proportion of training to perform linear learning rate warmup for. " 119 | "E.g., 0.1 = 10% of training.") 120 | 121 | flags.DEFINE_integer( 122 | "save_checkpoints_steps", 1000, 123 | "How often to save the model checkpoint.") 124 | 125 | flags.DEFINE_bool("horovod", False, 126 | "Whether to use Horovod for multi-gpu runs") 127 | flags.DEFINE_bool( 128 | "amp", False, "Whether to enable AMP ops. When false, uses TF32 on A100 and FP32 on V100 GPUS.") 129 | flags.DEFINE_bool("use_xla", False, "Whether to enable XLA JIT compilation.") 130 | flags.DEFINE_string( 131 | "pooling_type", 'last', "last | first_last " 132 | ) 133 | # Dropout 134 | flags.DEFINE_float("embedding_dropout", 0.0, "dropout ratio of embedding") 135 | flags.DEFINE_float("spatial_dropout", 0.0, 136 | "dropout ratio of embedding, in channel") 137 | flags.DEFINE_float("bert_dropout", 0.0, "dropout ratio of bert") 138 | # FGM 139 | flags.DEFINE_bool( 140 | "use_fgm", False, 141 | "Whether to use FGM to train model.") 142 | flags.DEFINE_float("fgm_epsilon", 0.3, "The epsilon value for FGM") 143 | flags.DEFINE_float("fgm_loss_ratio", 1.0, "The ratio of fgm loss") 144 | flags.DEFINE_float("head_lr_ratio", 1.0, "The ratio of header learning rate") 145 | flags.DEFINE_bool("use_bilstm", False, 146 | "Whether to use Bi-LSTM in the last layer.") 147 | flags.DEFINE_bool("electra", False, "Whether to use electra") 148 | flags.DEFINE_bool("dp_decode", False, "Whether to use dp to decode") 149 | flags.DEFINE_integer("biaffine_size", 150, "biaffine size") 150 | 151 | 152 | class InputExample(object): 153 | """A single training/test example for simple sequence classification.""" 154 | 155 | def __init__(self, guid, text, label=None, raw_text=None): 156 | """Constructs a InputExample. 157 | 158 | Args: 159 | guid: Unique id for the example. 160 | text_a: string. The untokenized text of the first sequence. For single 161 | sequence tasks, only this sequence must be specified. 162 | label: (Optional) string. The label of the example. This should be 163 | specified for train and dev examples, but not for test examples. 
164 | """ 165 | self.guid = guid 166 | self.text = text 167 | self.label = label 168 | self.raw_text = raw_text 169 | 170 | class InputFeatures(object): 171 | """A single set of features of data.""" 172 | 173 | def __init__(self, input_ids, input_mask, segment_ids, span_mask, gold_labels): 174 | self.input_ids = input_ids 175 | self.input_mask = input_mask 176 | self.segment_ids = segment_ids 177 | self.span_mask = span_mask 178 | self.gold_labels = gold_labels 179 | 180 | def to_dict(self): 181 | return { 182 | "input_ids":self.input_ids, 183 | "input_mask":self.input_mask, 184 | "segment_ids":self.segment_ids, 185 | "span_mask":self.span_mask, 186 | "gold_labels":self.gold_labels, 187 | 188 | } 189 | 190 | class DataProcessor(object): 191 | """Base class for data converters for sequence classification data sets.""" 192 | 193 | def get_train_examples(self, data_dir): 194 | """Gets a collection of `InputExample`s for the train set.""" 195 | raise NotImplementedError() 196 | 197 | def get_dev_examples(self, data_dir): 198 | """Gets a collection of `InputExample`s for the dev set.""" 199 | raise NotImplementedError() 200 | 201 | def get_labels(self): 202 | """Gets the list of labels for this data set.""" 203 | raise NotImplementedError() 204 | 205 | class SPOObject: 206 | def __init__(self,s,p,o,st,ss,se,ot,os,oe): 207 | self.s = s 208 | self.o = o 209 | self.st = st 210 | self.p = p 211 | self.ot = ot 212 | self.ss = ss 213 | self.se = se 214 | self.os = os 215 | self.oe = oe 216 | 217 | def __str__(self): 218 | return f"{self.s} | {self.st} - {self.p} - {self.o} | {self.ot}" 219 | 220 | @classmethod 221 | def from_item(cls,text,item): 222 | s = item['subject'] 223 | p = item['predicate'] 224 | o = item['object'] 225 | st = 'entity' #item['subject_type'] 226 | ot = 'entity' #item['object_type'] 227 | ss = text.find(s) 228 | if ss == -1: 229 | return None 230 | se = ss + len(s) - 1 231 | 232 | os = text.find(o) 233 | if os == -1: 234 | return None 235 | 236 | oe = os + len(o) - 1 237 | return cls(s,p,o,st,ss,se,ot,os,oe) 238 | 239 | class SPOProcessor(DataProcessor): 240 | def __init__(self, fold_id=0, fold_num=0, max_seq_length=128): 241 | self.fold_id = fold_id 242 | self.fold_num = fold_num 243 | self.max_seq_length = max_seq_length 244 | 245 | def get_train_examples(self, data_dir, file_name='train_data.json'): 246 | examples = [] 247 | 248 | for i, line in enumerate(open(os.path.join(data_dir, file_name))): 249 | item = json.loads(line) 250 | guid = "%s-%s" % ('train', i) 251 | text = item['text'].strip() 252 | if len(text) > self.max_seq_length: 253 | for step in range(0,len(text),self.max_seq_length): 254 | _text = text[step:step+self.max_seq_length].strip() 255 | if len(_text) < 5: 256 | continue 257 | label = item['spo_list'] 258 | label = self.spo_convert(_text, label) 259 | examples.append(InputExample(guid=guid, text=_text, label=label)) 260 | else: 261 | label = item['spo_list'] 262 | label = self.spo_convert(text, label) 263 | examples.append(InputExample(guid=guid, text=text, label=label)) 264 | 265 | 266 | random.shuffle(examples) 267 | return examples 268 | 269 | def get_dev_examples(self, data_dir, file_name="dev_data.json"): 270 | examples = [] 271 | for i, line in enumerate(open(os.path.join(data_dir, file_name))): 272 | item = json.loads(line) 273 | guid = '%s-%s' % ('dev', i) 274 | text = item['text'].strip() 275 | if len(text) > self.max_seq_length: 276 | for step in range(0,len(text),self.max_seq_length): 277 | _text = text[step:step+self.max_seq_length].strip() 278 | 
if len(_text) < 5: 279 | continue 280 | label = item['spo_list'] 281 | label = self.spo_convert(_text, label) 282 | examples.append(InputExample(guid=guid, text=_text, label=label)) 283 | else: 284 | label = item['spo_list'] 285 | label = self.spo_convert(text, label) 286 | examples.append(InputExample(guid=guid, text=text, label=label)) 287 | 288 | return examples 289 | 290 | def get_test_examples(self, data_dir, file_name="final_test.txt"): 291 | examples = [] 292 | return examples 293 | 294 | def get_ner_labels(self): 295 | # labels = ["景点", "作品", "书籍", "歌曲", "气候", "生物", "出版社", "目", "Number", "地点", "网络小说", "历史人物", "网站", "音乐专辑", "图书作品", "城市", "人物", "Text", "学校", "影视作品", "企业", "Date", "学科专业", "语言", "电视综艺", "机构", "行政区", "国家"] 296 | labels = ['entity'] 297 | return labels 298 | 299 | def get_predicate_labels(self): 300 | labels = ["祖籍", "父亲", "总部地点", "出生地", "目", "面积", "简称", "上映时间", "妻子", "所属专辑", "注册资本", "首都", "导演", "字", "身高", "出品公司", "修业年限", "出生日期", "制片人", "母亲", "编剧", "国籍", "海拔", "连载网站", "丈夫", "朝代", "民族", "号", "出版社", "主持人", "专业代码", "歌手", "作词", "主角", "董事长", "成立日期", "毕业院校", "占地面积", "官方语言", "邮政编码", "人口数量", "所在城市", "作者", "作曲", "气候", "嘉宾", "主演", "改编自", "创始人"] 301 | return labels 302 | 303 | def get_all_labels(self): 304 | link_types = { 305 | "SH2OH", # subject head to object head 306 | "OH2SH", # object head to subject head 307 | "ST2OT", # subject tail to object tail 308 | "OT2ST", # object tail to subject tail 309 | } 310 | tags = {''.join([ent, "EH2ET"]) for ent in self.get_ner_labels()} # EH2ET: entity head to entity tail 311 | tags |= {''.join([rel, lt]) for rel in self.get_predicate_labels() for lt in link_types} 312 | 313 | return sorted(tags) 314 | 315 | def spo_convert(self, text, label): 316 | new_labels = [] 317 | for x in label: 318 | spo = SPOObject.from_item(text,x) 319 | if spo: 320 | new_labels.append(spo) 321 | return new_labels 322 | 323 | def convert_single_example(ex_index, example, label_map, max_seq_length, tokenizer, is_training): 324 | tokens = tokenizer.tokenize(example.text) 325 | text = example.text 326 | if len(tokens) > max_seq_length: 327 | tokens = tokens[0:max_seq_length] 328 | text = text[0:max_seq_length] 329 | try: 330 | assert len(text) == len(tokens) 331 | except: 332 | print(text) 333 | print(tokens) 334 | print(example.guid) 335 | raise 336 | 337 | ntokens = [] 338 | segment_ids = [] 339 | span_mask = [] 340 | 341 | for i, token in enumerate(tokens): 342 | ntokens.append(token) 343 | segment_ids.append(0) 344 | span_mask.append(1) 345 | 346 | input_ids = tokenizer.convert_tokens_to_ids(ntokens) 347 | input_mask = [1] * len(input_ids) 348 | while len(input_ids) < max_seq_length: 349 | input_ids.append(0) 350 | input_mask.append(0) 351 | segment_ids.append(0) 352 | span_mask.append(0) 353 | 354 | assert len(input_ids) == max_seq_length 355 | assert len(input_mask) == max_seq_length 356 | assert len(segment_ids) == max_seq_length 357 | assert len(span_mask) == max_seq_length 358 | 359 | if is_training: 360 | size = len(text) 361 | n = size * (size + 1) // 2 362 | gold_labels = [[0] * len(label_map) for _ in range(n)] 363 | 364 | def get_position(s,e): 365 | return s * size + e - s * (s + 1) // 2 366 | 367 | for spo in example.label: 368 | if spo.ss >= size or spo.se >= size or spo.os >= size or spo.oe >= size: 369 | continue 370 | gold_labels[get_position(spo.ss,spo.se)][label_map[''.join([spo.st,'EH2ET'])]] = 1 371 | gold_labels[get_position(spo.os,spo.oe)][label_map[''.join([spo.ot,'EH2ET'])]] = 1 372 | if spo.ss > spo.os: 373 | 
gold_labels[get_position(spo.os,spo.ss)][label_map[''.join([spo.p,'OH2SH'])]] = 1 374 | else: 375 | gold_labels[get_position(spo.ss,spo.os)][label_map[''.join([spo.p,'SH2OH'])]] = 1 376 | 377 | if spo.se > spo.oe: 378 | gold_labels[get_position(spo.oe,spo.se)][label_map[''.join([spo.p,'OT2ST'])]] = 1 379 | else: 380 | gold_labels[get_position(spo.se,spo.oe)][label_map[''.join([spo.p,'ST2OT'])]] = 1 381 | 382 | else: 383 | gold_labels = [[0] * len(label_map)] 384 | 385 | feature = InputFeatures( 386 | input_ids=input_ids, 387 | input_mask=input_mask, 388 | segment_ids=segment_ids, 389 | span_mask=span_mask, 390 | gold_labels=gold_labels, 391 | ) 392 | return feature 393 | 394 | def generator_based_input_fn_builder(examples, label_list, max_seq_length, tokenizer, is_training, batch_size): 395 | label_map = {} 396 | for (i, label) in enumerate(label_list): 397 | label_map[label] = i 398 | 399 | def generator(): 400 | for (ex_index, example) in enumerate(examples): 401 | feature = convert_single_example(ex_index, example, label_map, max_seq_length, tokenizer, 402 | is_training) 403 | yield feature.to_dict() 404 | 405 | def input_fn(params): 406 | d = tf.data.Dataset.from_generator( 407 | generator, 408 | output_types={ 409 | "input_ids": tf.int32, 410 | "input_mask": tf.int32, 411 | 'segment_ids': tf.int32, 412 | 'span_mask': tf.int32, 413 | 'gold_labels': tf.int32 414 | }, 415 | output_shapes={ 416 | "input_ids": tf.TensorShape([max_seq_length]), 417 | "input_mask": tf.TensorShape([max_seq_length]), 418 | 'segment_ids': tf.TensorShape([max_seq_length]), 419 | 'span_mask': tf.TensorShape([max_seq_length]), 420 | 'gold_labels': tf.TensorShape([None,len(label_map)]) 421 | } 422 | ) 423 | 424 | if is_training: 425 | d = d.repeat() 426 | d = d.shuffle(buffer_size=100) 427 | 428 | d = d.padded_batch( 429 | batch_size, 430 | padded_shapes={ 431 | "input_ids": (tf.TensorShape([max_seq_length])), 432 | "input_mask": tf.TensorShape([max_seq_length]), 433 | "segment_ids": tf.TensorShape([max_seq_length]), 434 | "span_mask": tf.TensorShape([max_seq_length]), 435 | "gold_labels": tf.TensorShape([None,len(label_map)]) 436 | }, 437 | padding_values={ 438 | 'input_ids': 0, 439 | "input_mask": 0, 440 | "segment_ids": 0, 441 | 'span_mask': 0, 442 | 'gold_labels': -1 # -1是为了boolen_mask方便而设置 443 | }, 444 | drop_remainder=False 445 | ).prefetch(20) 446 | return d 447 | return input_fn 448 | 449 | 450 | def biaffine_mapping(vector_set_1, 451 | vector_set_2, 452 | output_size, 453 | add_bias_1=True, 454 | add_bias_2=True, 455 | initializer=None, 456 | name='Bilinear'): 457 | """Bilinear mapping: maps two vector spaces to a third vector space. 458 | The input vector spaces are two 3d matrices: batch size x bucket size x values 459 | A typical application of the function is to compute a square matrix 460 | representing a dependency tree. The output is for each bucket a square 461 | matrix of the form [bucket size, output size, bucket size]. If the output size 462 | is set to 1 then results is [bucket size, 1, bucket size] equivalent to 463 | a square matrix where the bucket for instance represent the tokens on 464 | the x-axis and y-axis. In this way represent the adjacency matrix of a 465 | dependency graph (see https://arxiv.org/abs/1611.01734). 466 | Args: 467 | vector_set_1: vectors of space one 468 | vector_set_2: vectors of space two 469 | output_size: number of output labels (e.g. 
edge labels) 470 | add_bias_1: Whether to add a bias for input one 471 | add_bias_2: Whether to add a bias for input two 472 | initializer: Initializer for the bilinear weight map 473 | Returns: 474 | Output vector space as 4d matrix: 475 | batch size x bucket size x output size x bucket size 476 | The output could represent an unlabeled dependency tree when 477 | the output size is 1 or a labeled tree otherwise. 478 | """ 479 | with tf.variable_scope(name, reuse=tf.AUTO_REUSE): 480 | # Dynamic shape info 481 | batch_size = tf.shape(vector_set_1)[0] 482 | bucket_size = tf.shape(vector_set_1)[1] 483 | 484 | if add_bias_1: 485 | vector_set_1 = tf.concat( 486 | [vector_set_1, tf.ones([batch_size, bucket_size, 1])], axis=2) 487 | if add_bias_2: 488 | vector_set_2 = tf.concat( 489 | [vector_set_2, tf.ones([batch_size, bucket_size, 1])], axis=2) 490 | 491 | # Static shape info 492 | vector_set_1_size = vector_set_1.get_shape().as_list()[-1] 493 | vector_set_2_size = vector_set_2.get_shape().as_list()[-1] 494 | 495 | if not initializer: 496 | initializer = tf.orthogonal_initializer() 497 | 498 | # Mapping matrix 499 | bilinear_map = tf.get_variable( 500 | 'bilinear_map', [vector_set_1_size, 501 | output_size, vector_set_2_size], 502 | initializer=initializer) 503 | 504 | # The matrix operations and reshapings for bilinear mapping. 505 | # b: batch size (batch of buckets) 506 | # v1, v2: values (size of vectors) 507 | # n: tokens (size of bucket) 508 | # r: labels (output size), e.g. 1 if unlabeled or number of edge labels. 509 | 510 | # [b, n, v1] -> [b*n, v1] 511 | vector_set_1 = tf.reshape(vector_set_1, [-1, vector_set_1_size]) 512 | 513 | # [v1, r, v2] -> [v1, r*v2] 514 | bilinear_map = tf.reshape(bilinear_map, [vector_set_1_size, -1]) 515 | 516 | # [b*n, v1] x [v1, r*v2] -> [b*n, r*v2] 517 | bilinear_mapping = tf.matmul(vector_set_1, bilinear_map) 518 | 519 | # [b*n, r*v2] -> [b, n*r, v2] 520 | bilinear_mapping = tf.reshape( 521 | bilinear_mapping, 522 | [batch_size, bucket_size * output_size, vector_set_2_size]) 523 | 524 | # [b, n*r, v2] x [b, n, v2]T -> [b, n*r, n] 525 | bilinear_mapping = tf.matmul( 526 | bilinear_mapping, vector_set_2, adjoint_b=True) 527 | 528 | # [b, n*r, n] -> [b, n, r, n] 529 | bilinear_mapping = tf.reshape( 530 | bilinear_mapping, [batch_size, bucket_size, output_size, bucket_size]) 531 | return bilinear_mapping 532 | 533 | 534 | def create_model(bert_config, is_training, input_ids, input_mask, 535 | segment_ids, span_mask, num_labels,use_fgm=False, 536 | perturbation=None, spatial_dropout=None,embedding_dropout=0.0, 537 | bilstm=None,biaffine_size=150,pooling_type='last'): 538 | model = modeling.BertModel( 539 | config=bert_config, 540 | is_training=is_training, 541 | input_ids=input_ids, 542 | input_mask=input_mask, 543 | token_type_ids=segment_ids, 544 | use_one_hot_embeddings=False, 545 | use_fgm=use_fgm, 546 | perturbation=perturbation, 547 | spatial_dropout=spatial_dropout, 548 | embedding_dropout=embedding_dropout 549 | ) 550 | 551 | output_layer = model.get_sequence_output() 552 | 553 | if pooling_type != 'last': 554 | raise NotImplementedError('没实现。') 555 | 556 | batch_size, seq_length, hidden_size = modeling.get_shape_list( 557 | output_layer, expected_rank=3) 558 | 559 | if bilstm is not None and len(bilstm) == 2: 560 | tf.logging.info('Using Bi-LSTM') 561 | sequence_length = tf.reduce_sum(input_mask, axis=-1) 562 | with tf.variable_scope('bilstm', reuse=tf.AUTO_REUSE): 563 | outputs, states = tf.nn.bidirectional_dynamic_rnn( 564 | cell_fw=bilstm[0], 565 
| cell_bw=bilstm[1], 566 | dtype=tf.float32, 567 | sequence_length=sequence_length, 568 | inputs=output_layer 569 | ) 570 | output_layer = tf.concat(outputs, -1) 571 | 572 | if is_training: 573 | output_layer = tf.nn.dropout(output_layer, keep_prob=0.9) 574 | 575 | # Magic Number 576 | size = biaffine_size 577 | 578 | starts = tf.layers.dense(output_layer, size, kernel_initializer=tf.truncated_normal_initializer( 579 | stddev=0.02), name='start', reuse=tf.AUTO_REUSE) 580 | ends = tf.layers.dense(output_layer, size, kernel_initializer=tf.truncated_normal_initializer( 581 | stddev=0.02), name='end', reuse=tf.AUTO_REUSE) 582 | 583 | biaffine = biaffine_mapping( 584 | starts, 585 | ends, 586 | num_labels, 587 | add_bias_1=True, 588 | add_bias_2=True, 589 | initializer=tf.zeros_initializer(), 590 | name='biaffine') 591 | 592 | # [B,1,L] [B,L,1] -> [B,L,L] 593 | span_mask = tf.cast(span_mask, dtype=tf.bool) 594 | candidate_scores_mask = tf.logical_and(tf.expand_dims( 595 | span_mask, axis=1), tf.expand_dims(span_mask, axis=2)) 596 | # B,L,L 597 | sentence_ends_leq_starts = tf.tile( 598 | tf.expand_dims( 599 | tf.logical_not(tf.sequence_mask(tf.range(seq_length), seq_length)), 600 | 0), 601 | [batch_size, 1, 1] 602 | ) 603 | # B,L,L 604 | candidate_scores_mask = tf.logical_and( 605 | candidate_scores_mask, sentence_ends_leq_starts) 606 | # B*L*L 607 | flattened_candidate_scores_mask = tf.reshape(candidate_scores_mask, [-1]) 608 | 609 | def get_valid_scores(biaffine): 610 | # B,L,L,N 611 | candidate_scores = tf.transpose(biaffine, [0, 1, 3, 2]) 612 | candidate_scores = tf.boolean_mask(tf.reshape( 613 | candidate_scores, [-1, num_labels]), flattened_candidate_scores_mask) 614 | return candidate_scores 615 | 616 | # 只获取合法位置的logits,最终变成[X,num_labels],X大小与batch中数据相关 617 | candidate_ner_scores = get_valid_scores(biaffine) 618 | 619 | return candidate_ner_scores, model 620 | 621 | def multilabel_categorical_crossentropy(y_pred, y_true): 622 | """ 623 | 参考苏神的实现:https://github.com/bojone/bert4keras/blob/3161648d20bfe7f501297d4bb33a0bad1ffd4002/bert4keras/backend.py#L250 624 | y_pred: (batch_size, shaking_seq_len, type_size) 625 | y_true: (batch_size, shaking_seq_len, type_size) 626 | y_true and y_pred have the same shape,elements in y_true are either 0 or 1, 627 | 1 tags positive classes,0 tags negtive classes(means tok-pair does not have this type of link). 
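    In closed form the per-position loss is
        log(1 + sum_{i in neg} exp(y_pred_i)) + log(1 + sum_{j in pos} exp(-y_pred_j)),
    so at inference a pair-label is treated as positive when its raw score is greater than 0.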
628 | """ 629 | y_true = tf.cast(y_true,y_pred.dtype) 630 | y_pred = (1 - 2 * y_true) * y_pred 631 | y_neg = y_pred - y_true * 1e20 632 | y_pos = y_pred - (1 - y_true) * 1e20 633 | zeros = tf.zeros_like(y_pred[..., :1]) 634 | y_neg = tf.concat([y_neg, zeros], axis=-1) 635 | y_pos = tf.concat([y_pos, zeros], axis=-1) 636 | neg_loss = tf.reduce_logsumexp(y_neg, axis=-1) 637 | pos_loss = tf.reduce_logsumexp(y_pos, axis=-1) 638 | return neg_loss + pos_loss 639 | 640 | def model_fn_builder(bert_config, num_labels, init_checkpoint=None, learning_rate=None, 641 | num_train_steps=None, num_warmup_steps=None, 642 | use_one_hot_embeddings=False, hvd=None, amp=False): 643 | def model_fn(features, labels, mode, params): 644 | tf.compat.v1.logging.info("*** Features ***") 645 | for name in sorted(features.keys()): 646 | tf.compat.v1.logging.info( 647 | " name = %s, shape = %s" % (name, features[name].shape)) 648 | input_ids = features["input_ids"] 649 | input_mask = features["input_mask"] 650 | segment_ids = features["segment_ids"] 651 | span_mask = features["span_mask"] 652 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 653 | 654 | if is_training and FLAGS.bert_dropout > 0.0: 655 | bert_config.hidden_dropout_prob = FLAGS.bert_dropout 656 | bert_config.attention_probs_dropout_prob = FLAGS.bert_dropout 657 | 658 | batch_size = tf.shape(input_ids)[0] 659 | spatial_dropout_layer = None 660 | if is_training and FLAGS.spatial_dropout > 0.0: 661 | spatial_dropout_layer = tf.keras.layers.SpatialDropout1D( 662 | FLAGS.spatial_dropout) 663 | 664 | bilstm = None 665 | if FLAGS.use_bilstm: 666 | fw_cell = tf.nn.rnn_cell.LSTMCell(bert_config.hidden_size) 667 | bw_cell = tf.nn.rnn_cell.LSTMCell(bert_config.hidden_size) 668 | if is_training: 669 | fw_cell = lstm_dropout_warpper(fw_cell) 670 | bw_cell = lstm_dropout_warpper(bw_cell) 671 | bilstm = (fw_cell, bw_cell) 672 | 673 | reuse_model = FLAGS.use_fgm 674 | candidate_ner_scores, model = create_model( 675 | bert_config, is_training, input_ids, input_mask, segment_ids, span_mask, num_labels, 676 | spatial_dropout=spatial_dropout_layer, bilstm=bilstm, use_fgm=reuse_model, 677 | biaffine_size=FLAGS.biaffine_size,pooling_type=FLAGS.pooling_type,embedding_dropout=FLAGS.embedding_dropout 678 | ) 679 | 680 | output_spec = None 681 | 682 | if mode == tf.estimator.ModeKeys.TRAIN: 683 | tvars = tf.trainable_variables() 684 | initialized_variable_names = {} 685 | if init_checkpoint and (hvd is None or hvd.rank() == 0): 686 | (assignment_map, 687 | initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint, convert_electra=FLAGS.electra) 688 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 689 | tf.compat.v1.logging.info("**** Trainable Variables ****") 690 | 691 | for var in tvars: 692 | init_string = "" 693 | if var.name in initialized_variable_names: 694 | init_string = ", *INIT_FROM_CKPT*" 695 | tf.compat.v1.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 696 | init_string) 697 | 698 | gold_labels = features['gold_labels'] 699 | gold_labels = tf.reshape(gold_labels,[-1,num_labels]) 700 | # 根据-1的padding来获取真实的label,得到[X,num_labels],X与candidate_ner_scores一致 701 | gold_labels = tf.boolean_mask(gold_labels,tf.not_equal(gold_labels[...,0],-1)) 702 | total_loss = multilabel_categorical_crossentropy(candidate_ner_scores,gold_labels) 703 | total_loss = tf.reduce_sum(total_loss) / tf.to_float(batch_size) 704 | # 只计算有label的位置的准确率,避免大量0的干扰 705 | acc = 
tf.metrics.accuracy(gold_labels,tf.cast(tf.greater(candidate_ner_scores,0),gold_labels.dtype),weights=tf.cast(tf.greater(gold_labels,0),gold_labels.dtype)) 706 | 707 | tensor_to_log = { 708 | "accuracy": acc[1] * 100 709 | } 710 | if FLAGS.use_fgm: 711 | embedding_output = model.get_embedding_output() 712 | grad, = tf.gradients( 713 | total_loss, 714 | embedding_output, 715 | aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N) 716 | grad = tf.stop_gradient(grad) 717 | perturbation = modeling.scale_l2(grad, FLAGS.fgm_epsilon) 718 | adv_candidate_ner_scores, _ = create_model( 719 | bert_config, is_training, input_ids, input_mask, segment_ids, span_mask, num_labels, 720 | use_fgm=True, perturbation=perturbation, spatial_dropout=spatial_dropout_layer, bilstm=bilstm, 721 | biaffine_size=FLAGS.biaffine_size,pooling_type=FLAGS.pooling_type,embedding_dropout=FLAGS.embedding_dropout 722 | ) 723 | 724 | adv_loss = multilabel_categorical_crossentropy(adv_candidate_ner_scores,gold_labels) 725 | adv_loss = tf.reduce_sum(adv_loss) / tf.to_float(batch_size) 726 | 727 | total_loss = (total_loss + FLAGS.fgm_loss_ratio * 728 | adv_loss) / (1 + FLAGS.fgm_loss_ratio) 729 | 730 | train_op, _ = optimization.create_optimizer( 731 | total_loss, learning_rate, num_train_steps, num_warmup_steps, hvd, amp, head_lr_ratio=FLAGS.head_lr_ratio) 732 | output_spec = tf.estimator.EstimatorSpec( 733 | mode=mode, 734 | loss=total_loss, 735 | train_op=train_op, 736 | training_hooks=[tf.train.LoggingTensorHook(tensor_to_log, every_n_iter=50)]) 737 | elif mode == tf.estimator.ModeKeys.EVAL: 738 | # Fake metric 739 | def metric_fn(): 740 | unused_mean = tf.metrics.mean(tf.ones([2, 3])) 741 | return { 742 | "unused_mean": unused_mean 743 | } 744 | eval_metric_ops = metric_fn() 745 | output_spec = tf.estimator.EstimatorSpec( 746 | mode=mode, 747 | loss=tf.constant(1.0), 748 | eval_metric_ops=eval_metric_ops) 749 | elif mode == tf.estimator.ModeKeys.PREDICT: 750 | output_spec = tf.estimator.EstimatorSpec( 751 | mode=mode, 752 | predictions={ 753 | "score": tf.expand_dims(candidate_ner_scores, 0), 754 | 'batch_size': tf.expand_dims(batch_size, 0)} 755 | ) 756 | return output_spec 757 | 758 | return model_fn 759 | 760 | 761 | def main(_): 762 | # Set different seed for different model 763 | seed = FLAGS.seed + FLAGS.fold_id 764 | tf.random.set_random_seed(seed) 765 | random.seed(seed) 766 | np.random.seed(seed) 767 | 768 | start_time = time.time() 769 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) 770 | 771 | if FLAGS.horovod: 772 | import horovod.tensorflow as hvd 773 | hvd.init() 774 | 775 | processors = { 776 | "spo": SPOProcessor 777 | } 778 | 779 | if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict and not FLAGS.do_train_and_eval: 780 | raise ValueError( 781 | "At least one of `do_train` or `do_eval` must be True.") 782 | 783 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 784 | 785 | if FLAGS.max_seq_length > bert_config.max_position_embeddings: 786 | raise ValueError( 787 | "Cannot use sequence length %d because the BERT model " 788 | "was only trained up to sequence length %d" % 789 | (FLAGS.max_seq_length, bert_config.max_position_embeddings)) 790 | 791 | task_name = FLAGS.task_name.lower() 792 | if task_name not in processors: 793 | raise ValueError("Task not found: %s" % (task_name)) 794 | 795 | tf.io.gfile.makedirs(FLAGS.output_dir) 796 | 797 | processor = processors[task_name](FLAGS.fold_id,FLAGS.fold_num,FLAGS.max_seq_length) 798 | 799 | 
label_list = processor.get_all_labels() 800 | 801 | # 避免alignment的处理 802 | tokenizer = tokenization.SimpleTokenizer( 803 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 804 | 805 | master_process = True 806 | training_hooks = [] 807 | global_batch_size = FLAGS.train_batch_size 808 | hvd_rank = 0 809 | 810 | config = tf.compat.v1.ConfigProto() 811 | config.gpu_options.allow_growth = True 812 | config.allow_soft_placement = True 813 | 814 | if FLAGS.horovod: 815 | global_batch_size = FLAGS.train_batch_size * hvd.size() 816 | master_process = (hvd.rank() == 0) 817 | hvd_rank = hvd.rank() 818 | config.gpu_options.visible_device_list = str(hvd.local_rank()) 819 | if hvd.size() > 1: 820 | training_hooks.append(hvd.BroadcastGlobalVariablesHook(0)) 821 | 822 | if FLAGS.use_xla: 823 | config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1 824 | if FLAGS.amp: 825 | tf.enable_resource_variables() 826 | 827 | run_config = tf.estimator.RunConfig( 828 | model_dir=FLAGS.output_dir if master_process else None, 829 | session_config=config, 830 | log_step_count_steps=50, 831 | save_checkpoints_steps=FLAGS.save_checkpoints_steps if master_process else None, 832 | keep_checkpoint_max=1) 833 | 834 | if master_process: 835 | tf.compat.v1.logging.info("***** Configuaration *****") 836 | for key in FLAGS.__flags.keys(): 837 | tf.compat.v1.logging.info( 838 | ' {}: {}'.format(key, getattr(FLAGS, key))) 839 | tf.compat.v1.logging.info("**************************") 840 | 841 | train_examples = None 842 | num_train_steps = None 843 | num_warmup_steps = None 844 | 845 | if FLAGS.do_train or FLAGS.do_train_and_eval: 846 | train_examples = processor.get_train_examples(FLAGS.data_dir) 847 | num_train_steps = int( 848 | len(train_examples) / global_batch_size * FLAGS.num_train_epochs) 849 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) 850 | 851 | start_index = 0 852 | end_index = len(train_examples) 853 | 854 | if FLAGS.horovod: 855 | num_examples_per_rank = len(train_examples) // hvd.size() 856 | remainder = len(train_examples) % hvd.size() 857 | if hvd.rank() < remainder: 858 | start_index = hvd.rank() * (num_examples_per_rank+1) 859 | end_index = start_index + num_examples_per_rank + 1 860 | else: 861 | start_index = hvd.rank() * num_examples_per_rank + remainder 862 | end_index = start_index + (num_examples_per_rank) 863 | 864 | model_fn = model_fn_builder( 865 | bert_config=bert_config, 866 | num_labels=len(label_list), 867 | init_checkpoint=FLAGS.init_checkpoint, 868 | learning_rate=FLAGS.learning_rate if not FLAGS.horovod else FLAGS.learning_rate * hvd.size(), 869 | num_train_steps=num_train_steps, 870 | num_warmup_steps=num_warmup_steps, 871 | use_one_hot_embeddings=False, 872 | hvd=None if not FLAGS.horovod else hvd, 873 | amp=FLAGS.amp) 874 | 875 | estimator = tf.estimator.Estimator( 876 | model_fn=model_fn, 877 | config=run_config) 878 | 879 | if FLAGS.do_train or FLAGS.do_train_and_eval: 880 | tf.compat.v1.logging.info("***** Running training *****") 881 | tf.compat.v1.logging.info(" Num examples = %d", len(train_examples)) 882 | tf.compat.v1.logging.info(" Batch size = %d", FLAGS.train_batch_size) 883 | tf.compat.v1.logging.info(" Num steps = %d", num_train_steps) 884 | train_input_fn = generator_based_input_fn_builder( 885 | examples=train_examples[start_index:end_index], 886 | label_list=label_list, 887 | max_seq_length=FLAGS.max_seq_length, 888 | tokenizer=tokenizer, 889 | is_training=True, 890 | batch_size=FLAGS.train_batch_size 
891 | ) 892 | 893 | if FLAGS.do_predict or FLAGS.do_eval or FLAGS.do_train_and_eval: 894 | if FLAGS.do_eval or FLAGS.do_train_and_eval: 895 | predict_examples = processor.get_dev_examples(FLAGS.data_dir) 896 | else: 897 | predict_examples = processor.get_test_examples(FLAGS.data_dir) 898 | predict_batch_size = FLAGS.predict_batch_size 899 | tf.compat.v1.logging.info("***** Running prediction*****") 900 | tf.compat.v1.logging.info(" Num examples = %d", len(predict_examples)) 901 | tf.compat.v1.logging.info(" Batch size = %d", predict_batch_size) 902 | 903 | predict_input_fn = generator_based_input_fn_builder( 904 | examples=predict_examples, 905 | label_list=label_list, 906 | max_seq_length=FLAGS.max_seq_length, 907 | tokenizer=tokenizer, 908 | is_training=False, 909 | batch_size=predict_batch_size 910 | ) 911 | 912 | if FLAGS.do_train_and_eval: 913 | raise NotImplementedError('没实现,交给你们自己了。') 914 | else: 915 | if FLAGS.do_train: 916 | estimator.train(input_fn=train_input_fn, 917 | max_steps=num_train_steps, hooks=training_hooks) 918 | 919 | if FLAGS.do_eval or FLAGS.do_predict: 920 | if FLAGS.do_predict: 921 | raise NotImplementedError('没实现,参考eval的逻辑很容易实现。') 922 | idx = 0 923 | # TP - 预测对的数量 924 | # PN - 预测的数量 925 | # TN - 真实的数量 926 | TP,PN,TN = 1e-10,1e-10,1e-10 927 | for i, prediction in enumerate(tqdm(estimator.predict(input_fn=predict_input_fn, yield_single_examples=True), total=len(predict_examples)//predict_batch_size)): 928 | scores = prediction['score'] 929 | offset = 0 930 | bz = prediction['batch_size'] 931 | for j in range(bz): 932 | example = predict_examples[idx] 933 | text = example.text 934 | pred_text = example.text[:FLAGS.max_seq_length] 935 | size = len(pred_text) * (len(pred_text) + 1) // 2 936 | pred_score = scores[offset:offset+size] 937 | idx += 1 938 | offset += size 939 | ret = get_biaffine_predicate(pred_text,pred_score,label_list,processor.get_predicate_labels()) 940 | truth = set([(spo.s,spo.p,spo.o) for spo in example.label]) 941 | predict = set([(pred_text[s[0]:s[1]+1],p,pred_text[o[0]:o[1]+1]) for s,p,o in ret]) 942 | TP += len(truth & predict) 943 | TN += len(truth) 944 | PN += len(predict) 945 | print(f'precision {TP/PN}, recall {TP/TN}, f1 {2*TP/(PN+TN)}') 946 | 947 | if __name__ == "__main__": 948 | flags.mark_flag_as_required("data_dir") 949 | flags.mark_flag_as_required("task_name") 950 | flags.mark_flag_as_required("vocab_file") 951 | flags.mark_flag_as_required("bert_config_file") 952 | flags.mark_flag_as_required("output_dir") 953 | tf.compat.v1.app.run() 954 | -------------------------------------------------------------------------------- /code/modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """The main BERT model and related functions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import copy 23 | import json 24 | import math 25 | import re 26 | import six 27 | import tensorflow as tf 28 | import numpy as np 29 | 30 | def layer_norm(input_tensor, name=None): 31 | """Run layer normalization on the last dimension of the tensor.""" 32 | return tf.contrib.layers.layer_norm( 33 | inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name) 34 | 35 | 36 | def scale_l2(x, norm_length=1.0): 37 | # shape(x) = (batch, num_timesteps, d) 38 | # Divide x by max(abs(x)) for a numerically stable L2 norm. 39 | # 2norm(x) = a * 2norm(x/a) 40 | # Scale over the full sequence, dims (1, 2) 41 | alpha = tf.reduce_max(tf.abs(x), (1, 2), keep_dims=True) + 1e-12 42 | l2_norm = alpha * tf.sqrt( 43 | tf.reduce_sum(tf.pow(x / alpha, 2), (1, 2), keep_dims=True) + 1e-8) 44 | x_unit = x / l2_norm 45 | return norm_length * x_unit 46 | 47 | 48 | class BertConfig(object): 49 | """Configuration for `BertModel`.""" 50 | 51 | def __init__(self, 52 | vocab_size, 53 | hidden_size=768, 54 | num_hidden_layers=12, 55 | num_attention_heads=12, 56 | intermediate_size=3072, 57 | hidden_act="gelu", 58 | hidden_dropout_prob=0.1, 59 | attention_probs_dropout_prob=0.1, 60 | max_position_embeddings=512, 61 | type_vocab_size=16, 62 | initializer_range=0.02): 63 | """Constructs BertConfig. 64 | 65 | Args: 66 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. 67 | hidden_size: Size of the encoder layers and the pooler layer. 68 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 69 | num_attention_heads: Number of attention heads for each attention layer in 70 | the Transformer encoder. 71 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 72 | layer in the Transformer encoder. 73 | hidden_act: The non-linear activation function (function or string) in the 74 | encoder and pooler. 75 | hidden_dropout_prob: The dropout probability for all fully connected 76 | layers in the embeddings, encoder, and pooler. 77 | attention_probs_dropout_prob: The dropout ratio for the attention 78 | probabilities. 79 | max_position_embeddings: The maximum sequence length that this model might 80 | ever be used with. Typically set this to something large just in case 81 | (e.g., 512 or 1024 or 2048). 82 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 83 | `BertModel`. 84 | initializer_range: The stdev of the truncated_normal_initializer for 85 | initializing all weight matrices. 
86 | """ 87 | self.vocab_size = vocab_size 88 | self.hidden_size = hidden_size 89 | self.num_hidden_layers = num_hidden_layers 90 | self.num_attention_heads = num_attention_heads 91 | self.hidden_act = hidden_act 92 | self.intermediate_size = intermediate_size 93 | self.hidden_dropout_prob = hidden_dropout_prob 94 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 95 | self.max_position_embeddings = max_position_embeddings 96 | self.type_vocab_size = type_vocab_size 97 | self.initializer_range = initializer_range 98 | 99 | @classmethod 100 | def from_dict(cls, json_object): 101 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 102 | config = BertConfig(vocab_size=None) 103 | for (key, value) in six.iteritems(json_object): 104 | config.__dict__[key] = value 105 | return config 106 | 107 | @classmethod 108 | def from_json_file(cls, json_file): 109 | """Constructs a `BertConfig` from a json file of parameters.""" 110 | with tf.gfile.GFile(json_file, "r") as reader: 111 | text = reader.read() 112 | return cls.from_dict(json.loads(text)) 113 | 114 | def to_dict(self): 115 | """Serializes this instance to a Python dictionary.""" 116 | output = copy.deepcopy(self.__dict__) 117 | return output 118 | 119 | def to_json_string(self): 120 | """Serializes this instance to a JSON string.""" 121 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 122 | 123 | 124 | class BertModel(object): 125 | """BERT model ("Bidirectional Encoder Representations from Transformers"). 126 | 127 | Example usage: 128 | 129 | ```python 130 | # Already been converted into WordPiece token ids 131 | input_ids = tf.constant([[31, 51, 99], [15, 5, 0]]) 132 | input_mask = tf.constant([[1, 1, 1], [1, 1, 0]]) 133 | token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]]) 134 | 135 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512, 136 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 137 | 138 | model = modeling.BertModel(config=config, is_training=True, 139 | input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids) 140 | 141 | label_embeddings = tf.get_variable(...) 142 | pooled_output = model.get_pooled_output() 143 | logits = tf.matmul(pooled_output, label_embeddings) 144 | ... 145 | ``` 146 | """ 147 | 148 | def __init__(self, 149 | config, 150 | is_training, 151 | input_ids, 152 | input_mask=None, 153 | token_type_ids=None, 154 | use_one_hot_embeddings=False, 155 | use_fgm=False, 156 | perturbation=None, 157 | spatial_dropout=None, 158 | electra=False, 159 | embedding_dropout=0.0, 160 | embedding_file=None, 161 | scope='bert'): 162 | """Constructor for BertModel. 163 | 164 | Args: 165 | config: `BertConfig` instance. 166 | is_training: bool. true for training model, false for eval model. Controls 167 | whether dropout will be applied. 168 | input_ids: int32 Tensor of shape [batch_size, seq_length]. 169 | input_mask: (optional) int32 Tensor of shape [batch_size, seq_length]. 170 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 171 | use_one_hot_embeddings: (optional) bool. Whether to use one-hot word 172 | embeddings or tf.embedding_lookup() for the word embeddings. On the TPU, 173 | it is much faster if this is True, on the CPU or GPU, it is faster if 174 | this is False. 175 | use_fgm: whether to use FGM 176 | perturbation: FGM perturbation 177 | scope: (optional) variable scope. Defaults to "bert". 
178 | 179 | Raises: 180 | ValueError: The config is invalid or one of the input tensor shapes 181 | is invalid. 182 | """ 183 | config = copy.deepcopy(config) 184 | if not is_training: 185 | config.hidden_dropout_prob = 0.0 186 | config.attention_probs_dropout_prob = 0.0 187 | 188 | input_shape = get_shape_list(input_ids, expected_rank=2) 189 | batch_size = input_shape[0] 190 | seq_length = input_shape[1] 191 | 192 | if input_mask is None: 193 | input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32) 194 | 195 | if token_type_ids is None: 196 | token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32) 197 | 198 | with tf.variable_scope(scope, default_name="bert", reuse=tf.AUTO_REUSE if use_fgm else None): 199 | with tf.variable_scope("embeddings"): 200 | # Perform embedding lookup on the word ids. 201 | (embedding_output, self.embedding_table) = embedding_lookup( 202 | input_ids=input_ids, 203 | vocab_size=config.vocab_size, 204 | embedding_size=config.hidden_size, 205 | initializer_range=config.initializer_range, 206 | word_embedding_name="word_embeddings", 207 | use_one_hot_embeddings=use_one_hot_embeddings, 208 | embedding_file=embedding_file, 209 | embedding_dropout=embedding_dropout if (is_training and embedding_dropout >0.0) else 0.0) 210 | 211 | if use_fgm and perturbation is not None: 212 | embedding_output = embedding_output + perturbation 213 | else: 214 | embedding_output = embedding_output 215 | 216 | if is_training and spatial_dropout is not None: 217 | embedding_output = spatial_dropout(embedding_output,is_training) 218 | 219 | # Add positional embeddings and token type embeddings, then layer 220 | # normalize and perform dropout. 221 | self.embedding_output, self.position_embeddings = embedding_postprocessor( 222 | input_tensor=embedding_output, 223 | use_token_type=True, 224 | token_type_ids=token_type_ids, 225 | token_type_vocab_size=config.type_vocab_size, 226 | token_type_embedding_name="token_type_embeddings", 227 | use_position_embeddings=True, 228 | position_embedding_name="position_embeddings", 229 | initializer_range=config.initializer_range, 230 | max_position_embeddings=config.max_position_embeddings, 231 | dropout_prob=config.hidden_dropout_prob) 232 | 233 | with tf.variable_scope("encoder"): 234 | # This converts a 2D mask of shape [batch_size, seq_length] to a 3D 235 | # mask of shape [batch_size, seq_length, seq_length] which is used 236 | # for the attention scores. 237 | attention_mask = create_attention_mask_from_input_mask( 238 | input_ids, input_mask) 239 | 240 | # Run the stacked transformer. 241 | # `sequence_output` shape = [batch_size, seq_length, hidden_size]. 242 | self.all_encoder_layers = transformer_model( 243 | input_tensor=self.embedding_output, 244 | attention_mask=attention_mask, 245 | hidden_size=config.hidden_size, 246 | num_hidden_layers=config.num_hidden_layers, 247 | num_attention_heads=config.num_attention_heads, 248 | intermediate_size=config.intermediate_size, 249 | intermediate_act_fn=get_activation(config.hidden_act), 250 | hidden_dropout_prob=config.hidden_dropout_prob, 251 | attention_probs_dropout_prob=config.attention_probs_dropout_prob, 252 | initializer_range=config.initializer_range, 253 | do_return_all_layers=True) 254 | 255 | self.sequence_output = self.all_encoder_layers[-1] 256 | # The "pooler" converts the encoded sequence tensor of shape 257 | # [batch_size, seq_length, hidden_size] to a tensor of shape 258 | # [batch_size, hidden_size]. 
This is necessary for segment-level 259 | # (or segment-pair-level) classification tasks where we need a fixed 260 | # dimensional representation of the segment. 261 | if electra: 262 | # electra没有pooler层 263 | self.pooled_output = self.sequence_output[:,0] 264 | else: 265 | with tf.variable_scope("pooler"): 266 | # We "pool" the model by simply taking the hidden state corresponding 267 | # to the first token. We assume that this has been pre-trained 268 | first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1) 269 | self.pooled_output = tf.layers.dense( 270 | first_token_tensor, 271 | config.hidden_size, 272 | activation=tf.tanh, 273 | kernel_initializer=create_initializer(config.initializer_range)) 274 | 275 | def get_pooled_output(self): 276 | return self.pooled_output 277 | 278 | def get_sequence_output(self): 279 | """Gets final hidden layer of encoder. 280 | 281 | Returns: 282 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 283 | to the final hidden of the transformer encoder. 284 | """ 285 | return self.sequence_output 286 | 287 | def get_all_encoder_layers(self): 288 | return self.all_encoder_layers 289 | 290 | def get_position_embedding_output(self): 291 | return self.position_embeddings 292 | 293 | def get_embedding_output(self): 294 | """Gets output of the embedding lookup (i.e., input to the transformer). 295 | 296 | Returns: 297 | float Tensor of shape [batch_size, seq_length, hidden_size] corresponding 298 | to the output of the embedding layer, after summing the word 299 | embeddings with the positional embeddings and the token type embeddings, 300 | then performing layer normalization. This is the input to the transformer. 301 | """ 302 | return self.embedding_output 303 | 304 | def get_embedding_table(self): 305 | return self.embedding_table 306 | 307 | 308 | def gelu(input_tensor): 309 | """Gaussian Error Linear Unit. 310 | 311 | This is a smoother version of the RELU. 312 | Original paper: https://arxiv.org/abs/1606.08415 313 | 314 | Args: 315 | input_tensor: float Tensor to perform activation. 316 | 317 | Returns: 318 | `input_tensor` with the GELU activation applied. 319 | """ 320 | cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0))) 321 | return input_tensor * cdf 322 | 323 | 324 | def get_activation(activation_string): 325 | """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`. 326 | 327 | Args: 328 | activation_string: String name of the activation function. 329 | 330 | Returns: 331 | A Python function corresponding to the activation function. If 332 | `activation_string` is None, empty, or "linear", this will return None. 333 | If `activation_string` is not a string, it will return `activation_string`. 334 | 335 | Raises: 336 | ValueError: The `activation_string` does not correspond to a known 337 | activation. 338 | """ 339 | 340 | # We assume that anything that"s not a string is already an activation 341 | # function, so we just return it. 
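  # For example:
  #   get_activation("gelu")     -> the `gelu` function defined above
  #   get_activation("linear")   -> None
  #   get_activation(tf.nn.relu) -> tf.nn.relu (non-strings pass through unchanged)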
342 | if not isinstance(activation_string, six.string_types): 343 | return activation_string 344 | 345 | if not activation_string: 346 | return None 347 | 348 | act = activation_string.lower() 349 | if act == "linear": 350 | return None 351 | elif act == "relu": 352 | return tf.nn.relu 353 | elif act == "gelu": 354 | return gelu 355 | elif act == "tanh": 356 | return tf.tanh 357 | else: 358 | raise ValueError("Unsupported activation: %s" % act) 359 | 360 | 361 | def get_assignment_map_from_checkpoint(tvars, init_checkpoint, ignore_names=[], convert_electra=False): 362 | """Compute the union of the current variables and checkpoint variables.""" 363 | assignment_map = {} 364 | initialized_variable_names = {} 365 | 366 | name_to_variable = collections.OrderedDict() 367 | for var in tvars: 368 | name = var.name 369 | m = re.match("^(.*):\\d+$", name) 370 | if m is not None: 371 | name = m.group(1) 372 | if ignore_names and name in ignore_names: 373 | continue 374 | name_to_variable[name] = var 375 | 376 | init_vars = tf.train.list_variables(init_checkpoint) 377 | 378 | assignment_map = collections.OrderedDict() 379 | for x in init_vars: 380 | (name, var) = (x[0], x[1]) 381 | new_name = name 382 | if convert_electra: 383 | new_name = name.replace('electra','bert') 384 | if new_name not in name_to_variable: 385 | continue 386 | assignment_map[name] = new_name 387 | initialized_variable_names[new_name] = 1 388 | initialized_variable_names[new_name + ":0"] = 1 389 | 390 | return (assignment_map, initialized_variable_names) 391 | 392 | 393 | def dropout(input_tensor, dropout_prob): 394 | """Perform dropout. 395 | 396 | Args: 397 | input_tensor: float Tensor. 398 | dropout_prob: Python float. The probability of dropping out a value (NOT of 399 | *keeping* a dimension as in `tf.nn.dropout`). 400 | 401 | Returns: 402 | A version of `input_tensor` with dropout applied. 403 | """ 404 | if dropout_prob is None or dropout_prob == 0.0: 405 | return input_tensor 406 | 407 | output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob) 408 | return output 409 | 410 | 411 | def layer_norm_and_dropout(input_tensor, dropout_prob, name=None): 412 | """Runs layer normalization followed by dropout.""" 413 | output_tensor = layer_norm(input_tensor, name) 414 | output_tensor = dropout(output_tensor, dropout_prob) 415 | return output_tensor 416 | 417 | def create_initializer(initializer_range=0.02): 418 | """Creates a `truncated_normal_initializer` with the given range.""" 419 | return tf.truncated_normal_initializer(stddev=initializer_range) 420 | 421 | def load_pretrained_embedding(embedding_file, vocab_size, embedding_size): 422 | pretrained = np.random.normal(size=(vocab_size,embedding_size)) 423 | for i,line in enumerate(open(embedding_file)): 424 | fields = line.strip().split() 425 | word = fields[0] 426 | ebd = np.asarray(fields[1:]) 427 | if len(ebd) != embedding_size: 428 | tf.logging.warning(f'第{i}行embedding大小为{len(ebd)} != {embedding_size}') 429 | return None 430 | else: 431 | pretrained[i] = ebd 432 | return pretrained 433 | 434 | def embedding_lookup(input_ids, 435 | vocab_size, 436 | embedding_size=128, 437 | initializer_range=0.02, 438 | word_embedding_name="word_embeddings", 439 | use_one_hot_embeddings=False, 440 | embedding_file=None, 441 | embedding_dropout=0.0): 442 | """Looks up words embeddings for id tensor. 443 | 444 | Args: 445 | input_ids: int32 Tensor of shape [batch_size, seq_length] containing word 446 | ids. 447 | vocab_size: int. Size of the embedding vocabulary. 448 | embedding_size: int. 
Width of the word embeddings. 449 | initializer_range: float. Embedding initialization range. 450 | word_embedding_name: string. Name of the embedding table. 451 | use_one_hot_embeddings: bool. If True, use one-hot method for word 452 | embeddings. If False, use `tf.nn.embedding_lookup()`. One hot is better 453 | for TPUs. 454 | 455 | Returns: 456 | float Tensor of shape [batch_size, seq_length, embedding_size]. 457 | """ 458 | # This function assumes that the input is of shape [batch_size, seq_length, 459 | # num_inputs]. 460 | # 461 | # If the input is a 2D tensor of shape [batch_size, seq_length], we 462 | # reshape to [batch_size, seq_length, 1]. 463 | if input_ids.shape.ndims == 2: 464 | input_ids = tf.expand_dims(input_ids, axis=[-1]) 465 | 466 | if embedding_file: 467 | tf.logging.info(f'从{embedding_file}加载预训练词向量') 468 | pretrained = load_pretrained_embedding(embedding_file,vocab_size,embedding_size) 469 | if pretrained is not None: 470 | initializer = tf.constant_initializer(value=pretrained) 471 | embedding_table = tf.get_variable( 472 | name=word_embedding_name, 473 | initializer=lambda : initializer([vocab_size,embedding_size])) 474 | else: 475 | raise Exception('初始化词向量失败') 476 | else: 477 | embedding_table = tf.get_variable( 478 | name=word_embedding_name, 479 | shape=[vocab_size, embedding_size], 480 | initializer=create_initializer(initializer_range)) 481 | 482 | if embedding_dropout > 0.0: 483 | mask = tf.nn.dropout(tf.ones([vocab_size]),keep_prob=1-embedding_dropout) * (1-embedding_dropout) 484 | mask = tf.expand_dims(mask,1) 485 | embedding_table = mask * embedding_table 486 | 487 | if use_one_hot_embeddings: 488 | flat_input_ids = tf.reshape(input_ids, [-1]) 489 | one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size) 490 | output = tf.matmul(one_hot_input_ids, embedding_table) 491 | else: 492 | output = tf.nn.embedding_lookup(embedding_table, input_ids) 493 | 494 | input_shape = get_shape_list(input_ids) 495 | 496 | output = tf.reshape(output, 497 | input_shape[0:-1] + [input_shape[-1] * embedding_size]) 498 | return (output, embedding_table) 499 | 500 | 501 | def embedding_postprocessor(input_tensor, 502 | use_token_type=False, 503 | token_type_ids=None, 504 | token_type_vocab_size=16, 505 | token_type_embedding_name="token_type_embeddings", 506 | use_position_embeddings=True, 507 | position_embedding_name="position_embeddings", 508 | initializer_range=0.02, 509 | max_position_embeddings=512, 510 | dropout_prob=0.1): 511 | """Performs various post-processing on a word embedding tensor. 512 | 513 | Args: 514 | input_tensor: float Tensor of shape [batch_size, seq_length, 515 | embedding_size]. 516 | use_token_type: bool. Whether to add embeddings for `token_type_ids`. 517 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 518 | Must be specified if `use_token_type` is True. 519 | token_type_vocab_size: int. The vocabulary size of `token_type_ids`. 520 | token_type_embedding_name: string. The name of the embedding table variable 521 | for token type ids. 522 | use_position_embeddings: bool. Whether to add position embeddings for the 523 | position of each token in the sequence. 524 | position_embedding_name: string. The name of the embedding table variable 525 | for positional embeddings. 526 | initializer_range: float. Range of the weight initialization. 527 | max_position_embeddings: int. Maximum sequence length that might ever be 528 | used with this model. 
This can be longer than the sequence length of 529 | input_tensor, but cannot be shorter. 530 | dropout_prob: float. Dropout probability applied to the final output tensor. 531 | 532 | Returns: 533 | float tensor with same shape as `input_tensor`. 534 | 535 | Raises: 536 | ValueError: One of the tensor shapes or input values is invalid. 537 | """ 538 | input_shape = get_shape_list(input_tensor, expected_rank=3) 539 | batch_size = input_shape[0] 540 | seq_length = input_shape[1] 541 | width = input_shape[2] 542 | 543 | output = input_tensor 544 | 545 | if use_token_type: 546 | if token_type_ids is None: 547 | raise ValueError("`token_type_ids` must be specified if" 548 | "`use_token_type` is True.") 549 | token_type_table = tf.get_variable( 550 | name=token_type_embedding_name, 551 | shape=[token_type_vocab_size, width], 552 | initializer=create_initializer(initializer_range)) 553 | # This vocab will be small so we always do one-hot here, since it is always 554 | # faster for a small vocabulary. 555 | flat_token_type_ids = tf.reshape(token_type_ids, [-1]) 556 | one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size) 557 | token_type_embeddings = tf.matmul(one_hot_ids, token_type_table) 558 | token_type_embeddings = tf.reshape(token_type_embeddings, 559 | [batch_size, seq_length, width]) 560 | output += token_type_embeddings 561 | 562 | if use_position_embeddings: 563 | assert_op = tf.assert_less_equal(seq_length, max_position_embeddings) 564 | with tf.control_dependencies([assert_op]): 565 | full_position_embeddings = tf.get_variable( 566 | name=position_embedding_name, 567 | shape=[max_position_embeddings, width], 568 | initializer=create_initializer(initializer_range)) 569 | # Since the position embedding table is a learned variable, we create it 570 | # using a (long) sequence length `max_position_embeddings`. The actual 571 | # sequence length might be shorter than this, for faster training of 572 | # tasks that do not have long sequences. 573 | # 574 | # So `full_position_embeddings` is effectively an embedding table 575 | # for position [0, 1, 2, ..., max_position_embeddings-1], and the current 576 | # sequence has positions [0, 1, 2, ... seq_length-1], so we can just 577 | # perform a slice. 578 | position_embeddings = tf.slice(full_position_embeddings, [0, 0], 579 | [seq_length, -1]) 580 | num_dims = len(output.shape.as_list()) 581 | 582 | # Only the last two dimensions are relevant (`seq_length` and `width`), so 583 | # we broadcast among the first dimensions, which is typically just 584 | # the batch size. 585 | position_broadcast_shape = [] 586 | for _ in range(num_dims - 2): 587 | position_broadcast_shape.append(1) 588 | position_broadcast_shape.extend([seq_length, width]) 589 | position_embeddings = tf.reshape(position_embeddings, 590 | position_broadcast_shape) 591 | output += position_embeddings 592 | 593 | output = layer_norm_and_dropout(output, dropout_prob) 594 | return output, position_embeddings 595 | 596 | 597 | def create_attention_mask_from_input_mask(from_tensor, to_mask): 598 | """Create 3D attention mask from a 2D tensor mask. 599 | 600 | Args: 601 | from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...]. 602 | to_mask: int32 Tensor of shape [batch_size, to_seq_length]. 603 | 604 | Returns: 605 | float Tensor of shape [batch_size, from_seq_length, to_seq_length]. 
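  Example (illustrative):
    to_mask = [[1, 1, 0]] with from_seq_length = 2 gives

      [[[1., 1., 0.],
        [1., 1., 0.]]]

    i.e. every from-position may attend only to the unmasked to-positions.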
606 | """ 607 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 608 | batch_size = from_shape[0] 609 | from_seq_length = from_shape[1] 610 | 611 | to_shape = get_shape_list(to_mask, expected_rank=2) 612 | to_seq_length = to_shape[1] 613 | 614 | to_mask = tf.cast( 615 | tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32) 616 | 617 | # We don't assume that `from_tensor` is a mask (although it could be). We 618 | # don't actually care if we attend *from* padding tokens (only *to* padding) 619 | # tokens so we create a tensor of all ones. 620 | # 621 | # `broadcast_ones` = [batch_size, from_seq_length, 1] 622 | broadcast_ones = tf.ones( 623 | shape=[batch_size, from_seq_length, 1], dtype=tf.float32) 624 | 625 | # Here we broadcast along two dimensions to create the mask. 626 | mask = broadcast_ones * to_mask 627 | 628 | return mask 629 | 630 | 631 | def attention_layer(from_tensor, 632 | to_tensor, 633 | attention_mask=None, 634 | num_attention_heads=1, 635 | size_per_head=512, 636 | query_act=None, 637 | key_act=None, 638 | value_act=None, 639 | attention_probs_dropout_prob=0.0, 640 | initializer_range=0.02, 641 | do_return_2d_tensor=False, 642 | batch_size=None, 643 | from_seq_length=None, 644 | to_seq_length=None): 645 | """Performs multi-headed attention from `from_tensor` to `to_tensor`. 646 | 647 | This is an implementation of multi-headed attention based on "Attention 648 | is all you Need". If `from_tensor` and `to_tensor` are the same, then 649 | this is self-attention. Each timestep in `from_tensor` attends to the 650 | corresponding sequence in `to_tensor`, and returns a fixed-with vector. 651 | 652 | This function first projects `from_tensor` into a "query" tensor and 653 | `to_tensor` into "key" and "value" tensors. These are (effectively) a list 654 | of tensors of length `num_attention_heads`, where each tensor is of shape 655 | [batch_size, seq_length, size_per_head]. 656 | 657 | Then, the query and key tensors are dot-producted and scaled. These are 658 | softmaxed to obtain attention probabilities. The value tensors are then 659 | interpolated by these probabilities, then concatenated back to a single 660 | tensor and returned. 661 | 662 | In practice, the multi-headed attention are done with transposes and 663 | reshapes rather than actual separate tensors. 664 | 665 | Args: 666 | from_tensor: float Tensor of shape [batch_size, from_seq_length, 667 | from_width]. 668 | to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width]. 669 | attention_mask: (optional) int32 Tensor of shape [batch_size, 670 | from_seq_length, to_seq_length]. The values should be 1 or 0. The 671 | attention scores will effectively be set to -infinity for any positions in 672 | the mask that are 0, and will be unchanged for positions that are 1. 673 | num_attention_heads: int. Number of attention heads. 674 | size_per_head: int. Size of each attention head. 675 | query_act: (optional) Activation function for the query transform. 676 | key_act: (optional) Activation function for the key transform. 677 | value_act: (optional) Activation function for the value transform. 678 | attention_probs_dropout_prob: (optional) float. Dropout probability of the 679 | attention probabilities. 680 | initializer_range: float. Range of the weight initializer. 681 | do_return_2d_tensor: bool. If True, the output will be of shape [batch_size 682 | * from_seq_length, num_attention_heads * size_per_head]. 
If False, the 683 | output will be of shape [batch_size, from_seq_length, num_attention_heads 684 | * size_per_head]. 685 | batch_size: (Optional) int. If the input is 2D, this might be the batch size 686 | of the 3D version of the `from_tensor` and `to_tensor`. 687 | from_seq_length: (Optional) If the input is 2D, this might be the seq length 688 | of the 3D version of the `from_tensor`. 689 | to_seq_length: (Optional) If the input is 2D, this might be the seq length 690 | of the 3D version of the `to_tensor`. 691 | 692 | Returns: 693 | float Tensor of shape [batch_size, from_seq_length, 694 | num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is 695 | true, this will be of shape [batch_size * from_seq_length, 696 | num_attention_heads * size_per_head]). 697 | 698 | Raises: 699 | ValueError: Any of the arguments or tensor shapes are invalid. 700 | """ 701 | 702 | def transpose_for_scores(input_tensor, batch_size, num_attention_heads, 703 | seq_length, width): 704 | output_tensor = tf.reshape( 705 | input_tensor, [batch_size, seq_length, num_attention_heads, width]) 706 | 707 | output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3]) 708 | return output_tensor 709 | 710 | from_shape = get_shape_list(from_tensor, expected_rank=[2, 3]) 711 | to_shape = get_shape_list(to_tensor, expected_rank=[2, 3]) 712 | 713 | if len(from_shape) != len(to_shape): 714 | raise ValueError( 715 | "The rank of `from_tensor` must match the rank of `to_tensor`.") 716 | 717 | if len(from_shape) == 3: 718 | batch_size = from_shape[0] 719 | from_seq_length = from_shape[1] 720 | to_seq_length = to_shape[1] 721 | elif len(from_shape) == 2: 722 | if (batch_size is None or from_seq_length is None or to_seq_length is None): 723 | raise ValueError( 724 | "When passing in rank 2 tensors to attention_layer, the values " 725 | "for `batch_size`, `from_seq_length`, and `to_seq_length` " 726 | "must all be specified.") 727 | 728 | # Scalar dimensions referenced here: 729 | # B = batch size (number of sequences) 730 | # F = `from_tensor` sequence length 731 | # T = `to_tensor` sequence length 732 | # N = `num_attention_heads` 733 | # H = `size_per_head` 734 | 735 | from_tensor_2d = reshape_to_matrix(from_tensor) 736 | to_tensor_2d = reshape_to_matrix(to_tensor) 737 | 738 | # `query_layer` = [B*F, N*H] 739 | query_layer = tf.layers.dense( 740 | from_tensor_2d, 741 | num_attention_heads * size_per_head, 742 | activation=query_act, 743 | name="query", 744 | kernel_initializer=create_initializer(initializer_range)) 745 | 746 | # `key_layer` = [B*T, N*H] 747 | key_layer = tf.layers.dense( 748 | to_tensor_2d, 749 | num_attention_heads * size_per_head, 750 | activation=key_act, 751 | name="key", 752 | kernel_initializer=create_initializer(initializer_range)) 753 | 754 | # `value_layer` = [B*T, N*H] 755 | value_layer = tf.layers.dense( 756 | to_tensor_2d, 757 | num_attention_heads * size_per_head, 758 | activation=value_act, 759 | name="value", 760 | kernel_initializer=create_initializer(initializer_range)) 761 | 762 | # `query_layer` = [B, N, F, H] 763 | query_layer = transpose_for_scores(query_layer, batch_size, 764 | num_attention_heads, from_seq_length, 765 | size_per_head) 766 | 767 | # `key_layer` = [B, N, T, H] 768 | key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads, 769 | to_seq_length, size_per_head) 770 | 771 | # Take the dot product between "query" and "key" to get the raw 772 | # attention scores. 
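  # In einsum terms: attention_scores[b, n, f, t] = sum_h query_layer[b, n, f, h] * key_layer[b, n, t, h],
  # i.e. a batched QK^T per head, scaled below by 1/sqrt(size_per_head) as in the Transformer paper.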
773 | # `attention_scores` = [B, N, F, T] 774 | attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) 775 | attention_scores = tf.multiply(attention_scores, 776 | 1.0 / math.sqrt(float(size_per_head))) 777 | 778 | if attention_mask is not None: 779 | # `attention_mask` = [B, 1, F, T] 780 | attention_mask = tf.expand_dims(attention_mask, axis=[1]) 781 | 782 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 783 | # masked positions, this operation will create a tensor which is 0.0 for 784 | # positions we want to attend and -10000.0 for masked positions. 785 | adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0 786 | 787 | # Since we are adding it to the raw scores before the softmax, this is 788 | # effectively the same as removing these entirely. 789 | attention_scores += adder 790 | 791 | # Normalize the attention scores to probabilities. 792 | # `attention_probs` = [B, N, F, T] 793 | attention_probs = tf.nn.softmax(attention_scores) 794 | 795 | # This is actually dropping out entire tokens to attend to, which might 796 | # seem a bit unusual, but is taken from the original Transformer paper. 797 | attention_probs = dropout(attention_probs, attention_probs_dropout_prob) 798 | 799 | # `value_layer` = [B, T, N, H] 800 | value_layer = tf.reshape( 801 | value_layer, 802 | [batch_size, to_seq_length, num_attention_heads, size_per_head]) 803 | 804 | # `value_layer` = [B, N, T, H] 805 | value_layer = tf.transpose(value_layer, [0, 2, 1, 3]) 806 | 807 | # `context_layer` = [B, N, F, H] 808 | context_layer = tf.matmul(attention_probs, value_layer) 809 | 810 | # `context_layer` = [B, F, N, H] 811 | context_layer = tf.transpose(context_layer, [0, 2, 1, 3]) 812 | 813 | if do_return_2d_tensor: 814 | # `context_layer` = [B*F, N*H] 815 | context_layer = tf.reshape( 816 | context_layer, 817 | [batch_size * from_seq_length, num_attention_heads * size_per_head]) 818 | else: 819 | # `context_layer` = [B, F, N*H] 820 | context_layer = tf.reshape( 821 | context_layer, 822 | [batch_size, from_seq_length, num_attention_heads * size_per_head]) 823 | 824 | return context_layer 825 | 826 | 827 | def transformer_model(input_tensor, 828 | attention_mask=None, 829 | hidden_size=768, 830 | num_hidden_layers=12, 831 | num_attention_heads=12, 832 | intermediate_size=3072, 833 | intermediate_act_fn=gelu, 834 | hidden_dropout_prob=0.1, 835 | attention_probs_dropout_prob=0.1, 836 | initializer_range=0.02, 837 | do_return_all_layers=False): 838 | """Multi-headed, multi-layer Transformer from "Attention is All You Need". 839 | 840 | This is almost an exact implementation of the original Transformer encoder. 841 | 842 | See the original paper: 843 | https://arxiv.org/abs/1706.03762 844 | 845 | Also see: 846 | https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py 847 | 848 | Args: 849 | input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size]. 850 | attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length, 851 | seq_length], with 1 for positions that can be attended to and 0 in 852 | positions that should not be. 853 | hidden_size: int. Hidden size of the Transformer. 854 | num_hidden_layers: int. Number of layers (blocks) in the Transformer. 855 | num_attention_heads: int. Number of attention heads in the Transformer. 856 | intermediate_size: int. The size of the "intermediate" (a.k.a., feed 857 | forward) layer. 858 | intermediate_act_fn: function. 
The non-linear activation function to apply 859 | to the output of the intermediate/feed-forward layer. 860 | hidden_dropout_prob: float. Dropout probability for the hidden layers. 861 | attention_probs_dropout_prob: float. Dropout probability of the attention 862 | probabilities. 863 | initializer_range: float. Range of the initializer (stddev of truncated 864 | normal). 865 | do_return_all_layers: Whether to also return all layers or just the final 866 | layer. 867 | 868 | Returns: 869 | float Tensor of shape [batch_size, seq_length, hidden_size], the final 870 | hidden layer of the Transformer. 871 | 872 | Raises: 873 | ValueError: A Tensor shape or parameter is invalid. 874 | """ 875 | if hidden_size % num_attention_heads != 0: 876 | raise ValueError( 877 | "The hidden size (%d) is not a multiple of the number of attention " 878 | "heads (%d)" % (hidden_size, num_attention_heads)) 879 | 880 | attention_head_size = int(hidden_size / num_attention_heads) 881 | input_shape = get_shape_list(input_tensor, expected_rank=3) 882 | batch_size = input_shape[0] 883 | seq_length = input_shape[1] 884 | input_width = input_shape[2] 885 | 886 | # The Transformer performs sum residuals on all layers so the input needs 887 | # to be the same as the hidden size. 888 | if input_width != hidden_size: 889 | raise ValueError("The width of the input tensor (%d) != hidden size (%d)" % 890 | (input_width, hidden_size)) 891 | 892 | # We keep the representation as a 2D tensor to avoid re-shaping it back and 893 | # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on 894 | # the GPU/CPU but may not be free on the TPU, so we want to minimize them to 895 | # help the optimizer. 896 | prev_output = reshape_to_matrix(input_tensor) 897 | 898 | all_layer_outputs = [] 899 | for layer_idx in range(num_hidden_layers): 900 | with tf.variable_scope("layer_%d" % layer_idx): 901 | layer_input = prev_output 902 | 903 | with tf.variable_scope("attention"): 904 | attention_heads = [] 905 | with tf.variable_scope("self"): 906 | attention_head = attention_layer( 907 | from_tensor=layer_input, 908 | to_tensor=layer_input, 909 | attention_mask=attention_mask, 910 | num_attention_heads=num_attention_heads, 911 | size_per_head=attention_head_size, 912 | attention_probs_dropout_prob=attention_probs_dropout_prob, 913 | initializer_range=initializer_range, 914 | do_return_2d_tensor=True, 915 | batch_size=batch_size, 916 | from_seq_length=seq_length, 917 | to_seq_length=seq_length) 918 | attention_heads.append(attention_head) 919 | 920 | attention_output = None 921 | if len(attention_heads) == 1: 922 | attention_output = attention_heads[0] 923 | else: 924 | # In the case where we have other sequences, we just concatenate 925 | # them to the self-attention head before the projection. 926 | attention_output = tf.concat(attention_heads, axis=-1) 927 | 928 | # Run a linear projection of `hidden_size` then add a residual 929 | # with `layer_input`. 930 | with tf.variable_scope("output"): 931 | attention_output = tf.layers.dense( 932 | attention_output, 933 | hidden_size, 934 | kernel_initializer=create_initializer(initializer_range)) 935 | attention_output = dropout(attention_output, hidden_dropout_prob) 936 | attention_output = layer_norm(attention_output + layer_input) 937 | 938 | # The activation is only applied to the "intermediate" hidden layer. 
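      # Per position, this feed-forward sub-layer computes
      #   layer_output = layer_norm(attention_output + dropout(dense_out(intermediate_act_fn(dense_in(attention_output)))))
      # where dense_in projects hidden_size -> intermediate_size and dense_out projects back to hidden_size.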
939 | with tf.variable_scope("intermediate"): 940 | intermediate_output = tf.layers.dense( 941 | attention_output, 942 | intermediate_size, 943 | activation=intermediate_act_fn, 944 | kernel_initializer=create_initializer(initializer_range)) 945 | 946 | # Down-project back to `hidden_size` then add the residual. 947 | with tf.variable_scope("output"): 948 | layer_output = tf.layers.dense( 949 | intermediate_output, 950 | hidden_size, 951 | kernel_initializer=create_initializer(initializer_range)) 952 | layer_output = dropout(layer_output, hidden_dropout_prob) 953 | layer_output = layer_norm(layer_output + attention_output) 954 | prev_output = layer_output 955 | all_layer_outputs.append(layer_output) 956 | 957 | if do_return_all_layers: 958 | final_outputs = [] 959 | for layer_output in all_layer_outputs: 960 | final_output = reshape_from_matrix(layer_output, input_shape) 961 | final_outputs.append(final_output) 962 | return final_outputs 963 | else: 964 | final_output = reshape_from_matrix(prev_output, input_shape) 965 | return final_output 966 | 967 | 968 | def get_shape_list(tensor, expected_rank=None, name=None): 969 | """Returns a list of the shape of tensor, preferring static dimensions. 970 | 971 | Args: 972 | tensor: A tf.Tensor object to find the shape of. 973 | expected_rank: (optional) int. The expected rank of `tensor`. If this is 974 | specified and the `tensor` has a different rank, and exception will be 975 | thrown. 976 | name: Optional name of the tensor for the error message. 977 | 978 | Returns: 979 | A list of dimensions of the shape of tensor. All static dimensions will 980 | be returned as python integers, and dynamic dimensions will be returned 981 | as tf.Tensor scalars. 982 | """ 983 | if name is None: 984 | name = tensor.name 985 | 986 | if expected_rank is not None: 987 | assert_rank(tensor, expected_rank, name) 988 | 989 | shape = tensor.shape.as_list() 990 | 991 | non_static_indexes = [] 992 | for (index, dim) in enumerate(shape): 993 | if dim is None: 994 | non_static_indexes.append(index) 995 | 996 | if not non_static_indexes: 997 | return shape 998 | 999 | dyn_shape = tf.shape(tensor) 1000 | for index in non_static_indexes: 1001 | shape[index] = dyn_shape[index] 1002 | return shape 1003 | 1004 | 1005 | def reshape_to_matrix(input_tensor): 1006 | """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix).""" 1007 | ndims = input_tensor.shape.ndims 1008 | if ndims < 2: 1009 | raise ValueError("Input tensor must have at least rank 2. Shape = %s" % 1010 | (input_tensor.shape)) 1011 | if ndims == 2: 1012 | return input_tensor 1013 | 1014 | width = input_tensor.shape[-1] 1015 | output_tensor = tf.reshape(input_tensor, [-1, width]) 1016 | return output_tensor 1017 | 1018 | 1019 | def reshape_from_matrix(output_tensor, orig_shape_list): 1020 | """Reshapes a rank 2 tensor back to its original rank >= 2 tensor.""" 1021 | if len(orig_shape_list) == 2: 1022 | return output_tensor 1023 | 1024 | output_shape = get_shape_list(output_tensor) 1025 | 1026 | orig_dims = orig_shape_list[0:-1] 1027 | width = output_shape[-1] 1028 | 1029 | return tf.reshape(output_tensor, orig_dims + [width]) 1030 | 1031 | 1032 | def assert_rank(tensor, expected_rank, name=None): 1033 | """Raises an exception if the tensor rank is not of the expected rank. 1034 | 1035 | Args: 1036 | tensor: A tf.Tensor to check the rank of. 1037 | expected_rank: Python integer or list of integers, expected rank. 1038 | name: Optional name of the tensor for the error message. 
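    Example (illustrative):
      assert_rank(tf.zeros([2, 3]), expected_rank=2)       # passes
      assert_rank(tf.zeros([2, 3]), expected_rank=[2, 3])  # passes, rank may be 2 or 3
      assert_rank(tf.zeros([2, 3]), expected_rank=3)       # raises ValueError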
1039 | 1040 | Raises: 1041 | ValueError: If the expected shape doesn't match the actual shape. 1042 | """ 1043 | if name is None: 1044 | name = tensor.name 1045 | 1046 | expected_rank_dict = {} 1047 | if isinstance(expected_rank, six.integer_types): 1048 | expected_rank_dict[expected_rank] = True 1049 | else: 1050 | for x in expected_rank: 1051 | expected_rank_dict[x] = True 1052 | 1053 | actual_rank = tensor.shape.ndims 1054 | if actual_rank not in expected_rank_dict: 1055 | scope_name = tf.get_variable_scope().name 1056 | raise ValueError( 1057 | "For the tensor `%s` in scope `%s`, the actual rank " 1058 | "`%d` (shape = %s) is not equal to the expected rank `%s`" % 1059 | (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank))) 1060 | --------------------------------------------------------------------------------