├── .dockerignore ├── .gitignore ├── .vscode └── settings.json ├── Dockerfile ├── README.md ├── code ├── assemble.py ├── conlleval.py ├── create_raw_text.py ├── electra-pretrain │ ├── .gitignore │ ├── LICENSE │ ├── README.md │ ├── build_pretraining_dataset.py │ ├── config │ │ ├── base_discriminator_config.json │ │ ├── base_generator_config.json │ │ ├── large_discriminator_config.json │ │ └── large_generator_config.json │ ├── configure_pretraining.py │ ├── model │ │ ├── __init__.py │ │ ├── modeling.py │ │ ├── optimization.py │ │ └── tokenization.py │ ├── pretrain.sh │ ├── pretrain │ │ ├── __init__.py │ │ ├── pretrain_data.py │ │ └── pretrain_helpers.py │ ├── run_pretraining.py │ └── util │ │ ├── __init__.py │ │ ├── training_utils.py │ │ └── utils.py ├── modeling.py ├── optimization.py ├── pipeline.py ├── prepare.sh ├── pretrain.sh ├── run.sh ├── run_biaffine_ner.py ├── simple_run.sh ├── tokenization.py └── utils.py └── user_data └── extra_data ├── dev.txt ├── test.txt └── train.txt /.dockerignore: -------------------------------------------------------------------------------- 1 | .git/ 2 | code/__pycache__ 3 | __pycache__/ 4 | user_data/models/ 5 | user_data/pretrain_tfrecords/ 6 | user_data/texts/ 7 | user_data/tcdata/ 8 | user_data/emb/ 9 | user_data/chinese_roberta_wwm_ext_L-12_H-768_A-12/ -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | tcdata/ 132 | user_data/* 133 | !user_data/extra_data 134 | !user_data/track3 135 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.pythonPath": "/home/xueyou/.conda/envs/jason_py3/bin/python" 3 | } -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/tensorflow:19.10-py3 2 | 3 | # set noninteractive installation 4 | ENV DEBIAN_FRONTEND=noninteractive 5 | 6 | # install tzdata & curl package 7 | RUN apt update && apt-get install -y tzdata wget curl 8 | 9 | RUN ln -fs /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \ 10 | && dpkg-reconfigure -f noninteractive tzdata 11 | 12 | # pretained models and datas 13 | COPY user_data/electra /user_data/electra 14 | COPY user_data/extra_data /user_data/extra_data 15 | # COPY user_data/track3 /user_data/track3 16 | 17 | # add code 18 | COPY Dockerfile /Dockerfile 19 | 20 | COPY code /code 21 | 22 | WORKDIR /code 23 | CMD ["sh","run.sh"] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CCKS2021-赛道二-中文NLP地址要素解析 2 | 3 | 团队:xueyouluo 4 | 5 | 初赛:1 - 93.63 6 | 7 | 复赛:3 - 91.32 8 | 9 | > 这里的代码是复赛的全流程代码,需要在32G显存的卡上才能正常跑通,如果没有这么大的显存,可以考虑将seq_length改成32,以及减小batch size。 10 | 11 | ## 解决方案 12 | 13 | ### 初赛 14 | 15 | 整体还是以预训练+finetune的思路,主要在模型结构、预训练、模型泛化能力提升、数据增强、融合、伪标签、后处理等方面做了优化。 16 | 17 | #### 模型 18 | 19 | 现在的实体识别方案很多,包括BERT+CRF的序列标注、基于Span的方法、基于MRC的方法,我这里使用的是基于BERT的Biaffine结构,直接预测文本构成的所有span的类别。相比单纯基于span预测和基于MRC的预测,Biaffine的结构可以同时考虑所有span之间的关系,从而提高预测的准确率。 20 | 21 | > Biaffine意思双仿射,如果`W*X`是单仿射的话,`X*W*Y`就是双仿射了。本质上就是输入一个长度为`L`的序列,预测一个`L*L*C`的tensor,预测每个span的类别信息。 22 | 23 | 具体来说参考了论文[Named Entity Recognition as Dependency Parsing](https://arxiv.org/abs/2005.07150),但是稍有区别: 24 | 25 | - 纯粹基于bert进行finetune,不利用fasttext、bert等做context embedding抽取,这也是为了简化模型 26 | - 不区分char word的embedding,默认就是char【中文的BERT基本都是char】 27 | - 原来的论文中有上下文的多句话,这里默认都是一句话【数据决定】 28 | - 同时改进了原有greedy的decoding方法,使用基于DAG的动态规划算法找到全局最优解 29 | 30 | 但是这种方法也有一些局限: 31 | 32 | - 对边界判断不是特别准 33 | - 有大量的负样本 34 | 35 | > 原来我也实现过[Biaffine-BERT-NER](https://github.com/xueyouluo/Biaffine-BERT-NER),但这里的版本优化了一些。 36 | 37 | #### 预训练 38 | 39 | 在比较了大部分开源的预训练模型后,哈工大的electra效果比较好,因此我们采用了electra的预训练方法。使用了本赛道的所有数据+赛道三的初赛所有数据,构建了预训练样本,分别继续预训练base和large的模型33K步【大概15个epoch】。 40 | 41 | > 继续预训练模型可以提升1个百分点左右的效果,还是非常有效的。 42 | 43 | #### 泛化能力提升 44 | 45 | 这些应该是属于比较基本的操作了,主要包括: 46 | 47 | - 使用了对抗学习(FGM)的方法,但代价是训练速度慢了一倍 48 | - 在Dropout方面加入了spatial dropout和embedding dropout 49 | - 使用SWA的方法避免局部最优解 50 | 51 | 需要在验证集上调参找到比较合适的值。 52 | 53 | #### 数据增强 54 | 55 | 我们用到了开源的一份地址解析数据,来自[《Neural Chinese Address 
Parsing》](https://github.com/leodotnet/neural-chinese-address-parsing)。参考赛道二的标注规范,使用规则将数据进行清洗,并用这份数据作为数据增强的语料。同时利用统计信息稍微优化了一下数据,即认为一个span如果被标注次数大于10,并且有一个类别占比不到10%且标注数量小于5就认为是不合理的并将其抛弃。 56 | 57 | 我们使用了同类型实体替换的方法进行数据增强,然后将预训练后的模型在这份数据上finetune。最后用赛道本身的数据进行二次finetune。初赛上,上面的流程走下来可以在dev上达到94.71,线上92.56。 58 | 59 | #### 融合 60 | 61 | 融合的提升非常明显。在融合上,我们使用了electra-base和electra-large两个模型,分别进行预训练和finetune,然后5-fold。 62 | 63 | 最后对实体进行投票,其中base权重1/3,large权重2/3,只选择投票结果大于3的实体作为最终结果。 64 | 65 | > 初赛上,base单独5-fold融合为93.0,large单独5-fold融合为93.477。二者加权融合为93.537。 66 | 67 | #### 伪标签 68 | 69 | 在融合的基础上,我们进一步使用了伪标签,即将上面的融合后预测的测试集结果作为伪标签,重新训练了base模型的一个fold,再进行预测,最终线上可以到93.5920。后面我也实验了训练5-fold的模型,测试下来可以到93.6087。 70 | 71 | #### 后处理 72 | 73 | 我这边后处理比较简单,主要对特殊符号进行了处理,由于一些特殊符号在训练集没有见过,导致模型预测错误。对于包含特殊符号的实体,如果特殊符号是在实体的边界,那么直接去除特殊符号,保留原来的实体类型;如果不是,则去除这个实体。在伪标签结果的基础上加后处理,线上到93.6212。 74 | 75 | #### 实验结果 76 | 77 | | 序号 | 实验 | Dev指标 | 线上指标 | 78 | | :--: | :-------------------------------------: | ------- | :------: | 79 | | 1 | Biaffine + roberta ext | 92.15 | | 80 | | 2 | Biaffine + google bert | 92.33 | | 81 | | 3 | 2 + FGM | 92.79 | | 82 | | 4 | 3 + spatial dropout + embedding dropout | 92.94 | 90.65 | 83 | | 5 | 4 + extra data + finetune | 93.74 | | 84 | | 6 | 5 + 数据增强 | 93.98 | | 85 | | 7 | 6 + roberta ext pretrain | 94.15 | 92.08 | 86 | | 8 | 5 + electra base | 94.19 | | 87 | | 9 | 5 + electra large | 94.32 | 92.13 | 88 | | 10 | 5 + electra base pretrain | 94.71 | 92.56 | 89 | | 11 | 5 + electra large pretrain | 94.54 | | 90 | | 12 | 10 + 5-fold | - | 93.009 | 91 | | 13 | 11 + 5-fold | - | 93.499 | 92 | | 14 | 12 + 13 | - | 93.537 | 93 | | 15 | 14 + pseudo tag | - | 93.62 | 94 | | 16 | 15 + 5-fold | - | 93.63 | 95 | 96 | ### 复赛 97 | 98 | 复赛上我对原来的流程基本没有做什么改动【主要也是我也没想到什么好改进的点了】,就是预训练改了一下。 99 | 100 | 复赛由于线上训练时间12h的限制,我不可能跑那么久的预训练了【我线下训练large的模型花了20多个小时😂】,因此预训练的语料只用了本赛道的数据集+开源的数据集来减少预训练的时间。 101 | 102 | > 唯一非常折腾我的是,large模型在复赛的时候效果一直比不上base模型,可能是预训练不够导致的。 103 | 104 | 我在复赛的时候都是全流程提交的,直接线上调参了。大概的结果如下【都是5-fold】: 105 | 106 | | 序号 | 实验 | 线上指标 | 107 | | :--: | :----------------------: | :------: | 108 | | 1 | Electra-base | 89.15 | 109 | | 2 | Electra-large | 89.58 | 110 | | 3 | Electra-base + pretrain | 90.74 | 111 | | 4 | Electra-large + pretrain | 90.75 | 112 | | 5 | 3 + 4 | 91.08 | 113 | | 6 | 5 + fake 1-fold | 91.31 | 114 | | 7 | 6 + fake 5-fold | 91.32 | 115 | 116 | 最终复赛的结果就是91.32,离第一还是有2个千分点差距的。更多细节就看代码吧,毕竟全都在代码里面了。 117 | 118 | ## 运行 119 | 120 | ### 运行环境 121 | 122 | 我们选择了英伟达提供的[docker](nvcr.io/nvidia/tensorflow:19.10-py3)作为基础镜像进行训练,主要是为了避免配环境的各种问题。 123 | 124 | 具体: 125 | 126 | - Unbuntu == 16.04 127 | - Python == 3.6.8 128 | - GPU V100 32G 129 | - 1.14.0 <= Tensorflow-gpu <= 1.15.* 130 | 131 | ### 数据准备 132 | 133 | #### 赛道数据 134 | 135 | 这里不提供比赛的数据,大家自己下载好放在tcdata目录下。 136 | 137 | #### 预训练模型 138 | 139 | 预训练模型我们使用了哈工大开源的[中文ELECTRA模型](https://github.com/ymcui/Chinese-ELECTRA#%E5%A4%A7%E8%AF%AD%E6%96%99%E7%89%88%E6%96%B0%E7%89%88180g%E6%95%B0%E6%8D%AE),具体为大语料版本的模型: 140 | 141 | - [ELECTRA-180g-large, Chinese](https://drive.google.com/file/d/1P9yAuW0-HR7WvZ2r2weTnx3slo6f5u9q/view?usp=sharing) 142 | 143 | - [ELECTRA-180g-base, Chinese](https://drive.google.com/file/d/1RlmfBgyEwKVBFagafYvJgyCGuj7cTHfh/view?usp=sharing) 144 | 145 | 下载后解压在user_data/electra目录下。 146 | 147 | #### 额外数据 148 | 149 | 下载[neural-chinese-address-parsing](https://github.com/leodotnet/neural-chinese-address-parsing)中data目录下train、dev、test数据到user_data/extra_data目录下。 150 | 151 | #### 目录结构 152 | 153 | ``` 154 | ├── code 155 | │   ├── electra-pretrain 156 | │   └── ... 
157 | ├── tcdata 158 | │   ├── dev.conll 159 | │   ├── final_test.txt 160 | │   └── train.conll 161 | ├── user_data 162 | │   ├── electra 163 | │   │   ├── electra_180g_base 164 | │   │   │   ├── base_discriminator_config.json 165 | │   │   │   ├── base_generator_config.json 166 | │   │   │   ├── electra_180g_base.ckpt.data-00000-of-00001 167 | │   │   │   ├── electra_180g_base.ckpt.index 168 | │   │   │   ├── electra_180g_base.ckpt.meta 169 | │   │   │   └── vocab.txt 170 | │   │   └── electra_180g_large 171 | │   │   ├── electra_180g_large.ckpt.data-00000-of-00001 172 | │   │   ├── electra_180g_large.ckpt.index 173 | │   │   ├── electra_180g_large.ckpt.meta 174 | │   │   ├── large_discriminator_config.json 175 | │   │   ├── large_generator_config.json 176 | │   │   └── vocab.txt 177 | │   ├── extra_data 178 | │   │   ├── dev.txt 179 | │   │   ├── test.txt 180 | │   │   └── train.txt 181 | │   └── track3 # 这里可以不需要 182 | │   ├── final_test.txt #这是初赛的测试集 183 | │   ├── Xeon3NLP_round1_test_20210524.txt #可以不用,复赛没有使用这个数据 184 | │   └── Xeon3NLP_round1_train_20210524.txt #可以不用,复赛没有使用这个数据 185 | ``` 186 | 187 | ### 运行 188 | 189 | 在code目录下运行 190 | 191 | ``` 192 | sh run.sh 193 | ``` 194 | 195 | 具体训练细节参考`pipeline.py`文件。 196 | 197 | 也有一个简化版本的,把seq_len改成了32,没有5-fold,自己测试跑下来dev上大概为94。 198 | 199 | ``` 200 | sh simple_run.sh 201 | ``` 202 | -------------------------------------------------------------------------------- /code/assemble.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 模型结果融合 3 | ''' 4 | import re 5 | from collections import Counter, defaultdict 6 | from glob import glob 7 | 8 | from utils import convert_data_format, iob_iobes 9 | 10 | 11 | def refine_entity(w,s,e): 12 | # 去除包含特殊字符的实体 13 | if re.findall('[,。()()]',w): 14 | nw = w.strip(',。()()') 15 | if not nw: 16 | return False,None 17 | else: 18 | start = w.find(nw) 19 | s = s + start 20 | e = s + len(nw) - 1 21 | return True,(s,e) 22 | else: 23 | return True,(s,e) 24 | 25 | def convert(entity, refine=False): 26 | tmp = [] 27 | for k,words in entity.items(): 28 | for w,spans in words.items(): 29 | for span in spans: 30 | if refine: 31 | should_keep,span = refine_entity(w,span[0],span[1]) 32 | if not should_keep: 33 | continue 34 | tmp.append((k,w,span[0],span[1])) 35 | return tmp 36 | 37 | def get_entities(text,tags): 38 | tag_words = [] 39 | word = '' 40 | tag = '' 41 | for i,(c,t) in enumerate(zip(text,tags)): 42 | if t[0] in ['B','S','O']: 43 | if word: 44 | tag_words.append((word,i,tag)) 45 | if t[0] == 'O': 46 | word = '' 47 | tag = '' 48 | continue 49 | word = c 50 | tag = t[2:] 51 | else: 52 | word += c 53 | if word: 54 | tag_words.append((word,i+1,tag)) 55 | 56 | entities = {} 57 | for w,i,t in tag_words: 58 | if t not in entities: 59 | entities[t] = {} 60 | if w in entities[t]: 61 | entities[t][w].append([i-len(w),i-1]) 62 | else: 63 | entities[t][w] = [[i-len(w),i-1]] 64 | return entities 65 | 66 | def check_special(text): 67 | text = re.sub('[\u4e00-\u9fa5]','',text) 68 | text = re.sub('[0A-]','',text) 69 | if text.strip(): 70 | return True 71 | else: 72 | return False 73 | 74 | def merge_by_4_tuple(raw_texts,data,weights,threshold=3.0, refine=False): 75 | ''' 76 | 根据(类型、实体文本、起始位置、结束位置)四元组进行投票确定最终的结果 77 | ''' 78 | new_tags = [] 79 | ent_cnt = 0 80 | special_cnt = 0 81 | check_fail = 0 82 | fail_cnt = 0 83 | 84 | for i,gtags in enumerate(data): 85 | _,text = raw_texts[i] 86 | cnt = Counter() 87 | assert len(weights) == len(gtags), 'weight {} != tags 
{}'.format(len(weights),len(gtags)) 88 | for j,tags in enumerate(gtags): 89 | entities = convert(get_entities(text,tags)) 90 | ratio = weights[j] 91 | for x in entities: 92 | cnt[x] += ratio 93 | 94 | ntags = ['O'] * len(text) 95 | for m,n in cnt.most_common(): 96 | # k = 类型, w = 实体文本, s = 实体起始位置, e = 实体结束位置 97 | (k,w,s,e) = m 98 | if n < threshold: 99 | fail_cnt += 1 100 | continue 101 | 102 | if refine: 103 | should_keep,span = refine_entity(w,s,e) 104 | if not should_keep: 105 | continue 106 | else: 107 | s,e = span 108 | 109 | # 检查是否有其他实体占据span 110 | if not all(x=='O' for x in ntags[s:e+1]): 111 | continue 112 | 113 | ent_cnt += 1 114 | try: 115 | if check_special(text[s:e+1]): 116 | special_cnt += 1 117 | except: 118 | check_fail += 1 119 | ntags[s:e+1] = ['I-'+k] * (e-s+1) 120 | ntags[s] = 'B-'+k 121 | new_tags.append(iob_iobes(ntags)) 122 | 123 | with open('/tmp/entity_cnt.txt','w') as f: 124 | f.write('fail_cnt - {}, ent_cnt - {}, special_cnt - {}\n'.format(fail_cnt,ent_cnt,special_cnt)) 125 | 126 | return new_tags 127 | 128 | 129 | def assemble_fake(): 130 | base_dir = '../user_data/models' 131 | output_file= '../user_data/tcdata/fake.conll' 132 | 133 | patterns = [ 134 | base_dir + '/k-fold/bif_electra_base_pretrain_fold_*/export/f1_export/result.txt', 135 | base_dir + '/k-fold/bif_electra_large_pretrain_fold_*/export/f1_export/result.txt', 136 | ] 137 | 138 | weights = [1/2] * 5 + [1/2] * 5 139 | threshold = 3.0 140 | refine = True 141 | 142 | data = [] 143 | raw_texts = [] 144 | 145 | for pattern in patterns: 146 | for fname in glob(pattern): 147 | for i,line in enumerate(open(fname)): 148 | idx,text,tags = line.strip().split('\x01') 149 | if len(data) <= i: 150 | data.append([]) 151 | data[i].append(tags.split(' ')) 152 | if len(raw_texts) <= i: 153 | raw_texts.append((idx,text)) 154 | 155 | assert len(data[0]) == len(weights) 156 | new_tags = merge_by_4_tuple(raw_texts,data,weights,threshold,refine) 157 | 158 | seen_texts = set() 159 | 160 | with open(output_file,'w') as f: 161 | for (idx,text),tags in zip(raw_texts,new_tags): 162 | if len(text) != len(tags): 163 | continue 164 | if text in seen_texts: 165 | continue 166 | else: 167 | seen_texts.add(text) 168 | 169 | for c,t in zip(text,tags): 170 | f.write(c + ' ' + t + '\n') 171 | f.write('\n') 172 | 173 | def assemble_final(): 174 | base_dir = '../user_data/models' 175 | output_file= './result.txt' 176 | 177 | patterns = [ 178 | base_dir + '/k-fold/bif_fake_tags_fold_*/export/f1_export/result.txt', 179 | ] 180 | 181 | weights = [1] * 5 182 | threshold = 3.0 183 | refine = True 184 | 185 | data = [] 186 | raw_texts = [] 187 | 188 | for pattern in patterns: 189 | for fname in glob(pattern): 190 | for i,line in enumerate(open(fname)): 191 | idx,text,tags = line.strip().split('\x01') 192 | if len(data) <= i: 193 | data.append([]) 194 | data[i].append(tags.split(' ')) 195 | if len(raw_texts) <= i: 196 | raw_texts.append((idx,text)) 197 | 198 | assert len(data[0]) == len(weights) 199 | new_tags = merge_by_4_tuple(raw_texts,data,weights,threshold,refine) 200 | 201 | with open(output_file,'w') as f: 202 | for (idx,text),tags in zip(raw_texts,new_tags): 203 | assert len(text) == len(tags) 204 | f.write('\x01'.join([idx,text,' '.join(tags)]) + '\n') -------------------------------------------------------------------------------- /code/conlleval.py: -------------------------------------------------------------------------------- 1 | # Python version of the evaluation script from CoNLL'00- 2 | # Originates from: 
https://github.com/spyysalo/conlleval.py 3 | 4 | 5 | # Intentional differences: 6 | # - accept any space as delimiter by default 7 | # - optional file argument (default STDIN) 8 | # - option to set boundary (-b argument) 9 | # - LaTeX output (-l argument) not supported 10 | # - raw tags (-r argument) not supported 11 | 12 | import sys 13 | import re 14 | import codecs 15 | from collections import defaultdict, namedtuple 16 | 17 | ANY_SPACE = '' 18 | 19 | 20 | class FormatError(Exception): 21 | pass 22 | 23 | Metrics = namedtuple('Metrics', 'tp fp fn prec rec fscore') 24 | 25 | 26 | class EvalCounts(object): 27 | def __init__(self): 28 | self.correct_chunk = 0 # number of correctly identified chunks 29 | self.correct_tags = 0 # number of correct chunk tags 30 | self.found_correct = 0 # number of chunks in corpus 31 | self.found_guessed = 0 # number of identified chunks 32 | self.token_counter = 0 # token counter (ignores sentence breaks) 33 | 34 | # counts by type 35 | self.t_correct_chunk = defaultdict(int) 36 | self.t_found_correct = defaultdict(int) 37 | self.t_found_guessed = defaultdict(int) 38 | 39 | 40 | def parse_args(argv): 41 | import argparse 42 | parser = argparse.ArgumentParser( 43 | description='evaluate tagging results using CoNLL criteria', 44 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 45 | ) 46 | arg = parser.add_argument 47 | arg('-b', '--boundary', metavar='STR', default='-X-', 48 | help='sentence boundary') 49 | arg('-d', '--delimiter', metavar='CHAR', default=ANY_SPACE, 50 | help='character delimiting items in input') 51 | arg('-o', '--otag', metavar='CHAR', default='O', 52 | help='alternative outside tag') 53 | arg('file', nargs='?', default=None) 54 | return parser.parse_args(argv) 55 | 56 | 57 | def parse_tag(t): 58 | m = re.match(r'^([^-]*)-(.*)$', t) 59 | return m.groups() if m else (t, '') 60 | 61 | 62 | def evaluate(iterable, options=None): 63 | if options is None: 64 | options = parse_args([]) # use defaults 65 | 66 | counts = EvalCounts() 67 | num_features = None # number of features per line 68 | in_correct = False # currently processed chunks is correct until now 69 | last_correct = 'O' # previous chunk tag in corpus 70 | last_correct_type = '' # type of previously identified chunk tag 71 | last_guessed = 'O' # previously identified chunk tag 72 | last_guessed_type = '' # type of previous chunk tag in corpus 73 | 74 | for line in iterable: 75 | line = line.rstrip('\r\n') 76 | 77 | if options.delimiter == ANY_SPACE: 78 | features = line.split() 79 | else: 80 | features = line.split(options.delimiter) 81 | 82 | if num_features is None: 83 | num_features = len(features) 84 | elif num_features != len(features) and len(features) != 0: 85 | raise FormatError('unexpected number of features: %d (%d)' % 86 | (len(features), num_features), line) 87 | 88 | if len(features) == 0 or features[0] == options.boundary: 89 | features = [options.boundary, 'O', 'O'] 90 | if len(features) < 3: 91 | raise FormatError('unexpected number of features in line %s' % line) 92 | 93 | guessed, guessed_type = parse_tag(features.pop()) 94 | correct, correct_type = parse_tag(features.pop()) 95 | first_item = features.pop(0) 96 | 97 | if first_item == options.boundary: 98 | guessed = 'O' 99 | 100 | end_correct = end_of_chunk(last_correct, correct, 101 | last_correct_type, correct_type) 102 | end_guessed = end_of_chunk(last_guessed, guessed, 103 | last_guessed_type, guessed_type) 104 | start_correct = start_of_chunk(last_correct, correct, 105 | last_correct_type, correct_type) 
106 | start_guessed = start_of_chunk(last_guessed, guessed, 107 | last_guessed_type, guessed_type) 108 | 109 | if in_correct: 110 | if (end_correct and end_guessed and 111 | last_guessed_type == last_correct_type): 112 | in_correct = False 113 | counts.correct_chunk += 1 114 | counts.t_correct_chunk[last_correct_type] += 1 115 | elif (end_correct != end_guessed or guessed_type != correct_type): 116 | in_correct = False 117 | 118 | if start_correct and start_guessed and guessed_type == correct_type: 119 | in_correct = True 120 | 121 | if start_correct: 122 | counts.found_correct += 1 123 | counts.t_found_correct[correct_type] += 1 124 | if start_guessed: 125 | counts.found_guessed += 1 126 | counts.t_found_guessed[guessed_type] += 1 127 | if first_item != options.boundary: 128 | if correct == guessed and guessed_type == correct_type: 129 | counts.correct_tags += 1 130 | counts.token_counter += 1 131 | 132 | last_guessed = guessed 133 | last_correct = correct 134 | last_guessed_type = guessed_type 135 | last_correct_type = correct_type 136 | 137 | if in_correct: 138 | counts.correct_chunk += 1 139 | counts.t_correct_chunk[last_correct_type] += 1 140 | 141 | return counts 142 | 143 | 144 | def uniq(iterable): 145 | seen = set() 146 | return [i for i in iterable if not (i in seen or seen.add(i))] 147 | 148 | 149 | def calculate_metrics(correct, guessed, total): 150 | tp, fp, fn = correct, guessed-correct, total-correct 151 | p = 0 if tp + fp == 0 else 1.*tp / (tp + fp) 152 | r = 0 if tp + fn == 0 else 1.*tp / (tp + fn) 153 | f = 0 if p + r == 0 else 2 * p * r / (p + r) 154 | return Metrics(tp, fp, fn, p, r, f) 155 | 156 | 157 | def metrics(counts): 158 | c = counts 159 | overall = calculate_metrics( 160 | c.correct_chunk, c.found_guessed, c.found_correct 161 | ) 162 | by_type = {} 163 | for t in uniq(list(c.t_found_correct) + list(c.t_found_guessed)): 164 | by_type[t] = calculate_metrics( 165 | c.t_correct_chunk[t], c.t_found_guessed[t], c.t_found_correct[t] 166 | ) 167 | return overall, by_type 168 | 169 | 170 | def report(counts, out=None): 171 | if out is None: 172 | out = sys.stdout 173 | 174 | overall, by_type = metrics(counts) 175 | 176 | c = counts 177 | out.write('processed %d tokens with %d phrases; ' % 178 | (c.token_counter, c.found_correct)) 179 | out.write('found: %d phrases; correct: %d.\n' % 180 | (c.found_guessed, c.correct_chunk)) 181 | 182 | if c.token_counter > 0: 183 | out.write('accuracy: %6.2f%%; ' % 184 | (100.*c.correct_tags/c.token_counter)) 185 | out.write('precision: %6.2f%%; ' % (100.*overall.prec)) 186 | out.write('recall: %6.2f%%; ' % (100.*overall.rec)) 187 | out.write('FB1: %6.2f\n' % (100.*overall.fscore)) 188 | 189 | for i, m in sorted(by_type.items()): 190 | out.write('%17s: ' % i) 191 | out.write('precision: %6.2f%%; ' % (100.*m.prec)) 192 | out.write('recall: %6.2f%%; ' % (100.*m.rec)) 193 | out.write('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i])) 194 | 195 | 196 | def report_notprint(counts, out=None): 197 | if out is None: 198 | out = sys.stdout 199 | 200 | overall, by_type = metrics(counts) 201 | 202 | c = counts 203 | final_report = [] 204 | line = [] 205 | line.append('processed %d tokens with %d phrases; ' % 206 | (c.token_counter, c.found_correct)) 207 | line.append('found: %d phrases; correct: %d.\n' % 208 | (c.found_guessed, c.correct_chunk)) 209 | final_report.append("".join(line)) 210 | 211 | if c.token_counter > 0: 212 | line = [] 213 | line.append('accuracy: %6.2f%%; ' % 214 | (100.*c.correct_tags/c.token_counter)) 215 | 
line.append('precision: %6.2f%%; ' % (100.*overall.prec)) 216 | line.append('recall: %6.2f%%; ' % (100.*overall.rec)) 217 | line.append('FB1: %6.2f\n' % (100.*overall.fscore)) 218 | final_report.append("".join(line)) 219 | 220 | for i, m in sorted(by_type.items()): 221 | line = [] 222 | line.append('%17s: ' % i) 223 | line.append('precision: %6.2f%%; ' % (100.*m.prec)) 224 | line.append('recall: %6.2f%%; ' % (100.*m.rec)) 225 | line.append('FB1: %6.2f %d\n' % (100.*m.fscore, c.t_found_guessed[i])) 226 | final_report.append("".join(line)) 227 | return final_report 228 | 229 | 230 | def end_of_chunk(prev_tag, tag, prev_type, type_): 231 | # check if a chunk ended between the previous and current word 232 | # arguments: previous and current chunk tags, previous and current types 233 | chunk_end = False 234 | 235 | if prev_tag == 'E': chunk_end = True 236 | if prev_tag == 'S': chunk_end = True 237 | 238 | if prev_tag == 'B' and tag == 'B': chunk_end = True 239 | if prev_tag == 'B' and tag == 'S': chunk_end = True 240 | if prev_tag == 'B' and tag == 'O': chunk_end = True 241 | if prev_tag == 'I' and tag == 'B': chunk_end = True 242 | if prev_tag == 'I' and tag == 'S': chunk_end = True 243 | if prev_tag == 'I' and tag == 'O': chunk_end = True 244 | 245 | if prev_tag != 'O' and prev_tag != '.' and prev_type != type_: 246 | chunk_end = True 247 | 248 | # these chunks are assumed to have length 1 249 | if prev_tag == ']': chunk_end = True 250 | if prev_tag == '[': chunk_end = True 251 | 252 | return chunk_end 253 | 254 | 255 | def start_of_chunk(prev_tag, tag, prev_type, type_): 256 | # check if a chunk started between the previous and current word 257 | # arguments: previous and current chunk tags, previous and current types 258 | chunk_start = False 259 | 260 | if tag == 'B': chunk_start = True 261 | if tag == 'S': chunk_start = True 262 | 263 | if prev_tag == 'E' and tag == 'E': chunk_start = True 264 | if prev_tag == 'E' and tag == 'I': chunk_start = True 265 | if prev_tag == 'S' and tag == 'E': chunk_start = True 266 | if prev_tag == 'S' and tag == 'I': chunk_start = True 267 | if prev_tag == 'O' and tag == 'E': chunk_start = True 268 | if prev_tag == 'O' and tag == 'I': chunk_start = True 269 | 270 | if tag != 'O' and tag != '.' 
and prev_type != type_: 271 | chunk_start = True 272 | 273 | # these chunks are assumed to have length 1 274 | if tag == '[': chunk_start = True 275 | if tag == ']': chunk_start = True 276 | 277 | return chunk_start 278 | 279 | 280 | def return_report(input_file): 281 | with codecs.open(input_file, "r", "utf8") as f: 282 | counts = evaluate(f) 283 | return report_notprint(counts) 284 | 285 | 286 | def main(argv): 287 | args = parse_args(argv[1:]) 288 | 289 | if args.file is None: 290 | counts = evaluate(sys.stdin, args) 291 | else: 292 | with open(args.file) as f: 293 | counts = evaluate(f, args) 294 | report(counts) 295 | 296 | if __name__ == '__main__': 297 | sys.exit(main(sys.argv)) -------------------------------------------------------------------------------- /code/create_raw_text.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import random 4 | 5 | from collections import Counter,defaultdict 6 | 7 | from utils import normalize, read_data, convert_back_to_bio, convert_data_format, iob_iobes 8 | 9 | random.seed(20190525) 10 | 11 | TCDATA_DIR = '../user_data/tcdata/' 12 | USERDATA_DIR = '../user_data/' 13 | 14 | 15 | def read_conll(fname): 16 | lines = [] 17 | line = '' 18 | for x in open(fname): 19 | x = x.strip() 20 | if not x: 21 | lines.append(line) 22 | line = '' 23 | continue 24 | else: 25 | line += x.split(' ')[0] 26 | return lines 27 | 28 | def read_track3(fname): 29 | lines = [] 30 | for x in open(fname): 31 | x = json.loads(x) 32 | lines.append(x['query']) 33 | for y in x['candidate']: 34 | lines.append(y['text']) 35 | return [normalize(x) for x in lines] 36 | 37 | def create_preatrain_data(): 38 | # 构建预训练语料 39 | data = open(TCDATA_DIR + 'final_test.txt').readlines() 40 | data = [x.strip().split('\x01')[1] for x in data] 41 | train = read_conll(TCDATA_DIR + 'train.conll') 42 | dev = read_conll(TCDATA_DIR + 'dev.conll') 43 | # 复赛没有使用 44 | # train3 = read_track3( 45 | # USERDATA_DIR + 'track3/Xeon3NLP_round1_train_20210524.txt') 46 | # test3 = read_track3( 47 | # USERDATA_DIR + 'track3/Xeon3NLP_round1_test_20210524.txt') 48 | extra_data = read_data([USERDATA_DIR + 'extra_data/train.txt', USERDATA_DIR + 49 | 'extra_data/dev.txt', USERDATA_DIR + 'extra_data/test.txt']) 50 | extra_data = [''.join([x[0] for x in item]) for item in extra_data] 51 | extra_data = [normalize(x) for x in extra_data] 52 | # old_test = open(USERDATA_DIR + 'track3/final_test.txt').readlines() 53 | # old_test = [x.strip().split('\x01')[1] for x in old_test] 54 | 55 | texts = list(set(data+train+dev+extra_data)) 56 | texts = [t for t in texts if t.strip()] 57 | random.shuffle(texts) 58 | 59 | with open(USERDATA_DIR + 'texts/raw_text.txt', 'w') as f: 60 | for x in texts: 61 | f.write(x+'\n') 62 | 63 | def convert_distance(item,tags): 64 | # 根据规则将assit中与距离相关的转换为distance标签 65 | text = item['text'] 66 | spans = [x for x in re.finditer('(0+|(十?[一二三四五六七八九几]+(十|百)?[一二三四五六七八九几]?))米',text)] 67 | for sp in spans: 68 | start,end = sp.span() 69 | if tags[start][2:] == 'assist': 70 | tags[start:end] = ['I-distance'] * (end-start) 71 | tags[start] = 'B-distance' 72 | if end < len(tags) and tags[end][0] == 'I': 73 | tags[end] = 'B' + tags[end][1:] 74 | return tags,spans 75 | 76 | def convert_village(item,tags): 77 | # 根据规则转换village_group标签 78 | text = item['text'] 79 | spans = [x for x in re.finditer('(0+|(十?[一二三四五六七八九])|([一二三四五六七八九]十[一二三四五六七八九]?))[组队社]',text)] 80 | for sp in spans: 81 | start,end = sp.span() 82 | if start > 0 and tags[start-1][2:] == 
'community': 83 | tags[start:end] = ['I-village_group'] * (end-start) 84 | tags[start] = 'B-village_group' 85 | if end < len(tags) and tags[end][0] == 'I': 86 | tags[end] = 'B' + tags[end][1:] 87 | return tags, spans 88 | 89 | def convert_intersection(item,tags,pattern): 90 | # 根据在训练验证集出现过的intersection字段对标签进行转换 91 | text = item['text'] 92 | spans = [x for x in re.finditer(pattern,text)] 93 | for sp in spans: 94 | start,end = sp.span() 95 | if tags[start][2:] == 'assist' or text[start:end] == '路口': 96 | if text[start:end] == '路口': 97 | if tags[start][2:] == 'road' and tags[start+1][2:] == 'assist': 98 | start = start + 1 99 | elif tags[start-1][2:] == 'road' and text[start-1] not in ['街','路']: 100 | tags[start] = 'I-road' 101 | start = start + 1 102 | tags[start:end] = ['I-intersection'] * (end-start) 103 | tags[start] = 'B-intersection' 104 | if end < len(tags) and tags[end][0] == 'I': 105 | tags[end] = 'B' + tags[end][1:] 106 | return tags,spans 107 | 108 | def get_intersection_pattern(): 109 | # 根据赛道2的训练数据获取路口的模式匹配 110 | train = read_data(TCDATA_DIR+'train.conll') 111 | dev = read_data(TCDATA_DIR+'dev.conll') 112 | train = [convert_data_format(x) for x in train] 113 | dev = [convert_data_format(x) for x in dev] 114 | 115 | inter_cnt = Counter() 116 | for x in train+dev: 117 | inter = x['label'].get('intersection','') 118 | if inter: 119 | for k in inter: 120 | inter_cnt[k] += 1 121 | 122 | inter_words = [x[0] for x in inter_cnt.most_common() if len(x[0]) > 1] 123 | pattern = '|'.join(['({})'.format(x) for x in inter_words]) 124 | return pattern 125 | 126 | def check_devzone(name): 127 | for x in ['经济开发区','园区','开发区','工业园','工业区','科技园','工业园区','创意园','产业园','软件谷','软件园','电商园','智慧国','智慧园','未来科技城','科创中心','机电城','工业城','商务园']: 128 | if name.endswith(x): 129 | return True 130 | return False 131 | 132 | def convert_data_format_v2(sentence): 133 | word = '' 134 | tag = '' 135 | text = '' 136 | tag_words = [] 137 | for i,(c,t) in enumerate(sentence): 138 | c = normalize(c) 139 | if t[0] in ['B','S','O']: 140 | if word: 141 | tag_words.append((word,len(text),tag)) 142 | if t[0] == 'O': 143 | word = '' 144 | tag = '' 145 | continue 146 | word = c 147 | tag = t[2:] 148 | else: 149 | word += c 150 | text += c 151 | 152 | if word: 153 | tag_words.append((word,len(text),tag)) 154 | 155 | 156 | entities = {} 157 | for w,i,t in tag_words: 158 | if check_devzone(w): 159 | t = 'devzone' 160 | if t not in entities: 161 | entities[t] = {} 162 | if w in entities[t]: 163 | entities[t][w].append([i-len(w),i-1]) 164 | else: 165 | entities[t][w] = [[i-len(w),i-1]] 166 | 167 | return {"text":text,"label":entities} 168 | 169 | def _get_refine_entity(raw_files): 170 | data = read_data(raw_files) 171 | ent_tp_cnt = defaultdict(Counter) 172 | ent_cnt = Counter() 173 | for sentence in data: 174 | entities = convert_data_format(sentence)['label'] 175 | for k in entities: 176 | for name in entities[k]: 177 | ent_tp_cnt[name][k] += 1 178 | ent_cnt[name] += 1 179 | 180 | for name in ent_tp_cnt: 181 | if ent_cnt[name] < 10: 182 | continue 183 | if len(ent_tp_cnt[name]) == 1: 184 | continue 185 | if len(ent_tp_cnt[name]) >= 2: 186 | pop = [] 187 | for tp in ent_tp_cnt[name]: 188 | if ent_tp_cnt[name][tp] / ent_cnt[name] < 0.1 and ent_tp_cnt[name][tp] < 5: 189 | pop.append(tp) 190 | for tp in pop: 191 | ent_tp_cnt[name].pop(tp) 192 | return ent_tp_cnt 193 | 194 | def _fix_data(ent_tp_cnt, update_files, iob=False): 195 | data = read_data(update_files) 196 | new_data = [] 197 | wcnt = 0 198 | for sentence in data: 199 | entities = 
convert_data_format(sentence)['label'] 200 | new_entities = {} 201 | for k in entities: 202 | for name in entities[k]: 203 | spans = entities[k][name] 204 | cnt = ent_tp_cnt[name] 205 | nk = k 206 | if k not in cnt: 207 | # print(''.join([w[0] for w in sentence])) 208 | try: 209 | nk = ent_tp_cnt[name].most_common(1)[0][0] 210 | except: 211 | # print('no entity', name,ent_tp_cnt[name],k,entities[k]) 212 | continue 213 | # print("wrong:",name,k,'->',nk) 214 | wcnt += 1 215 | new_entities[nk] = {} 216 | new_entities[nk][name] = spans 217 | if iob: 218 | tags = convert_back_to_bio(new_entities,[w[0] for w in sentence]) 219 | else: 220 | tags = iob_iobes(convert_back_to_bio(new_entities,[w[0] for w in sentence])) 221 | new_data.append([(a[0],b) for a,b in zip(sentence,tags)]) 222 | print('# total wrong',wcnt) 223 | return new_data 224 | 225 | def fix_data(): 226 | ent_tp_cnt = _get_refine_entity([TCDATA_DIR + 'train.conll', TCDATA_DIR + 'dev.conll',TCDATA_DIR + 'extra_train.conll']) 227 | extra_files = TCDATA_DIR + 'extra_train.conll' 228 | new_data = _fix_data(ent_tp_cnt,extra_files,iob=True) 229 | with open(TCDATA_DIR + 'extra_train_v2.conll','w') as f: 230 | for s in new_data: 231 | for x in s: 232 | f.write(x[0] + ' ' + x[1] + '\n') 233 | f.write('\n') 234 | 235 | new_data = _fix_data(ent_tp_cnt,TCDATA_DIR + 'train.conll') 236 | with open(TCDATA_DIR + 'train_v2.conll','w') as f: 237 | for s in new_data: 238 | for x in s: 239 | f.write(x[0] + ' ' + x[1] + '\n') 240 | f.write('\n') 241 | new_data = _fix_data(ent_tp_cnt,TCDATA_DIR + 'dev.conll') 242 | with open(TCDATA_DIR + 'dev_v2.conll','w') as f: 243 | for s in new_data: 244 | for x in s: 245 | f.write(x[0] + ' ' + x[1] + '\n') 246 | f.write('\n') 247 | 248 | def create_extra_train_data(): 249 | # 额外的训练数据 250 | # 数据来源:https://github.com/leodotnet/neural-chinese-address-parsing 251 | data = read_data([USERDATA_DIR + 'extra_data/train.txt', USERDATA_DIR + 252 | 'extra_data/dev.txt', USERDATA_DIR + 'extra_data/test.txt']) 253 | pattern = get_intersection_pattern() 254 | 255 | new_data = [] 256 | for sentence in data: 257 | item = convert_data_format_v2(sentence) 258 | tags = convert_back_to_bio(item['label'],item['text']) 259 | 260 | # 对数据标签进行映射 261 | new_tags = [] 262 | for i,t in enumerate(tags): 263 | tt = t[2:] 264 | if tt in ['country','roomno','otherinfo','redundant']: 265 | new_tags.append('O') 266 | elif tt == 'person': 267 | new_tags.append(t[:2] + 'subpoi') 268 | elif tt == 'devZone': 269 | new_tags.append(t[:2] + 'devzone') 270 | elif tt in ['subRoad','subroad']: 271 | new_tags.append(t[:2] + 'road') 272 | elif tt in ['subRoadno','subroadno']: 273 | new_tags.append(t[:2] + 'roadno') 274 | else: 275 | new_tags.append(t) 276 | 277 | # 处理distance 278 | new_tags,_ = convert_distance(item,new_tags) 279 | # 处理village_group 280 | new_tags,_ = convert_village(item, new_tags) 281 | # 处理intersection 282 | new_tags,_ = convert_intersection(item,new_tags,pattern) 283 | 284 | # 两个路之间的和字改成O 285 | spans = re.finditer('与|和',item['text']) 286 | for sp in spans: 287 | start,end = sp.span() 288 | if new_tags[start][2:]=='assist' and start > 0 and new_tags[start-1][2:] == 'road' and start < len(new_tags) and new_tags[start+1][2:] == 'road': 289 | new_tags[start] = 'O' 290 | 291 | # 去除噪声开头 292 | valid_start = ['B-prov','B-city','B-district','B-town','B-road','B-poi','B-devzone','B-community'] 293 | for i,t in enumerate(new_tags): 294 | if t not in valid_start: 295 | continue 296 | break 297 | new_tags = new_tags[i:] 298 | text = item['text'][i:] 299 | 
300 | # 去除过短文本 301 | if len(text) <= 2: 302 | continue 303 | 304 | text = normalize(text) 305 | assert len(new_tags) == len(text),(text,new_tags,item,sentence) 306 | s = [(a,b) for a,b in zip(text,new_tags)] 307 | new_data.append(s) 308 | 309 | with open(TCDATA_DIR + 'extra_train.conll','w') as f: 310 | for s in new_data: 311 | for x in s: 312 | f.write(x[0] + ' ' + x[1] + '\n') 313 | f.write('\n') 314 | 315 | if __name__ == '__main__': 316 | print('# create pretrain data') 317 | create_preatrain_data() 318 | print('# create extra data') 319 | create_extra_train_data() 320 | print('# fix wrong data') 321 | fix_data() 322 | -------------------------------------------------------------------------------- /code/electra-pretrain/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /code/electra-pretrain/LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /code/electra-pretrain/README.md: -------------------------------------------------------------------------------- 1 | # Electra Pretrain 2 | 3 | 在哈工大训练的electra基础上使用领域数据继续进行预训练,一般能够提升下游任务效果。 4 | 5 | ## 改动 6 | 7 | - 由于我们的语料是单句粒度,修改数据构建方法,只构建单句的语料 8 | - 针对中文,使用更简单的tokenizer,即将所有字符直接拆分【主要是适配下游的NER任务】 9 | - 修改预训练代码,支持加载预训练的模型的参数 10 | 11 | ## 使用 12 | 13 | 新建个DATA_DIR,然后在里面新建texts目录,将文本数据放入。 14 | 15 | 需要根据自己的语料,修改configure_pretraining的参数,包括max_seq_len,num_train_steps等。 16 | 17 | 运行pretrain.sh【根据自己的实际场景修改参数】。 18 | 19 | > 建议自己阅读run_pretrain.py的代码,理解里面的各种参数配置。 20 | 21 | ## 效果 22 | 23 | 在ccks2021-track2赛道上进行了测试,用track2和track3的数据继续预训练electra-base,训练33k步后,指标为: 24 | 25 | ```python 26 | disc_accuracy = 0.96376425 27 | disc_auc = 0.97588205 28 | disc_loss = 0.11515158 29 | disc_precision = 0.79076445 30 | disc_recall = 0.32165003 31 | global_step = 33000 32 | loss = 6.575825 33 | masked_lm_accuracy = 0.7298883 34 | masked_lm_loss = 1.2599187 35 | sampled_masked_lm_accuracy = 0.6684708 36 | ``` 37 | 38 | 在track2这个NER任务上,直接使用中文的electra-base模型,dev的F1指标为94.19,继续预训练后可以提升到94.74【线上为92.567,单模型】。 -------------------------------------------------------------------------------- /code/electra-pretrain/build_pretraining_dataset.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
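
> Editor's illustration for the electra-pretrain README above: it describes the workflow (put raw text under a `texts` directory, adjust `configure_pretraining.py`, run `pretrain.sh`) but does not show a concrete call to this dataset-building script. The sketch below drives the tfrecord build programmatically, equivalent to the CLI defined in `main()` further down in this file. The paths are assumptions inferred from the top-level directory layout (`user_data/texts`, `user_data/pretrain_tfrecords`, the ELECTRA vocab location); the repo's own `prepare.sh`/`pretrain.sh` may wire this differently.

```python
# Minimal sketch of invoking the dataset build in this script programmatically;
# equivalent to the CLI in main() below. All paths are assumptions taken from the
# top-level directory layout, not values confirmed by the repo's own scripts.
import os
from argparse import Namespace

from build_pretraining_dataset import write_examples  # this module

args = Namespace(
    corpus_dir="../../user_data/texts",   # raw_text.txt written by code/create_raw_text.py (assumed location)
    vocab_file="../../user_data/electra/electra_180g_base/vocab.txt",
    output_dir="../../user_data/pretrain_tfrecords",
    max_seq_length=64,        # addresses are short; the README suggests tuning this to the corpus
    num_processes=1,
    blanks_separate_docs=False,
)
os.makedirs(args.output_dir, exist_ok=True)  # main() uses utils.rmkdir from util for this step
write_examples(0, args)                      # single process -> job_id 0
```
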
15 | 16 | """Writes out text data as tfrecords that ELECTRA can be pre-trained on.""" 17 | 18 | import argparse 19 | import multiprocessing 20 | import os 21 | import random 22 | import time 23 | import tensorflow.compat.v1 as tf 24 | 25 | from model import tokenization 26 | from util import utils 27 | 28 | 29 | def create_int_feature(values): 30 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 31 | return feature 32 | 33 | 34 | class ExampleBuilder(object): 35 | """Given a stream of input text, creates pretraining examples.""" 36 | 37 | def __init__(self, tokenizer, max_length): 38 | self._tokenizer = tokenizer 39 | self._current_sentences = [] 40 | self._max_length = max_length 41 | 42 | def add_line(self, line): 43 | """Adds a line of text to the current example being built.""" 44 | line = line.strip().replace("\n", " ") 45 | bert_tokens = self._tokenizer.tokenize(line) 46 | bert_tokids = self._tokenizer.convert_tokens_to_ids(bert_tokens) 47 | self._current_sentences.append(bert_tokids) 48 | return self._create_example() 49 | 50 | def _create_example(self): 51 | """Creates a pre-training example from the current list of sentences.""" 52 | first_segment = [] 53 | for sentence in self._current_sentences: 54 | first_segment += sentence 55 | 56 | # trim to max_length while accounting for not-yet-added [CLS]/[SEP] tokens 57 | first_segment = first_segment[:self._max_length - 2] 58 | 59 | # prepare to start building the next example 60 | self._current_sentences = [] 61 | return self._make_tf_example(first_segment, None) 62 | 63 | def _make_tf_example(self, first_segment, second_segment): 64 | """Converts two "segments" of text into a tf.train.Example.""" 65 | vocab = self._tokenizer.vocab 66 | input_ids = [vocab["[CLS]"]] + first_segment + [vocab["[SEP]"]] 67 | segment_ids = [0] * len(input_ids) 68 | if second_segment: 69 | input_ids += second_segment + [vocab["[SEP]"]] 70 | segment_ids += [1] * (len(second_segment) + 1) 71 | input_mask = [1] * len(input_ids) 72 | input_ids += [0] * (self._max_length - len(input_ids)) 73 | input_mask += [0] * (self._max_length - len(input_mask)) 74 | segment_ids += [0] * (self._max_length - len(segment_ids)) 75 | tf_example = tf.train.Example(features=tf.train.Features(feature={ 76 | "input_ids": create_int_feature(input_ids), 77 | "input_mask": create_int_feature(input_mask), 78 | "segment_ids": create_int_feature(segment_ids) 79 | })) 80 | return tf_example 81 | 82 | 83 | class ExampleWriter(object): 84 | """Writes pre-training examples to disk.""" 85 | 86 | def __init__(self, job_id, vocab_file, output_dir, max_seq_length, 87 | num_jobs, blanks_separate_docs, 88 | num_out_files=1): 89 | self._blanks_separate_docs = blanks_separate_docs 90 | tokenizer = tokenization.SimpleTokenizer(vocab_file=vocab_file) 91 | self._example_builder = ExampleBuilder(tokenizer, max_seq_length) 92 | self._writers = [] 93 | for i in range(num_out_files): 94 | if i % num_jobs == job_id: 95 | output_fname = os.path.join( 96 | output_dir, "pretrain_data.tfrecord-{:}-of-{:}".format( 97 | i, num_out_files)) 98 | self._writers.append(tf.io.TFRecordWriter(output_fname)) 99 | self.n_written = 0 100 | 101 | def write_examples(self, input_file): 102 | """Writes out examples from the provided input file.""" 103 | with tf.io.gfile.GFile(input_file) as f: 104 | for line in f: 105 | line = line.strip() 106 | if line or self._blanks_separate_docs: 107 | example = self._example_builder.add_line(line) 108 | if example: 109 | self._writers[self.n_written % 
len(self._writers)].write( 110 | example.SerializeToString()) 111 | self.n_written += 1 112 | if self.n_written % 5000 == 0: 113 | print('processed',self.n_written) 114 | 115 | def finish(self): 116 | for writer in self._writers: 117 | writer.close() 118 | 119 | 120 | def write_examples(job_id, args): 121 | """A single process creating and writing out pre-processed examples.""" 122 | 123 | def log(*args): 124 | msg = " ".join(map(str, args)) 125 | print("Job {}:".format(job_id), msg) 126 | 127 | log("Creating example writer") 128 | example_writer = ExampleWriter( 129 | job_id=job_id, 130 | vocab_file=args.vocab_file, 131 | output_dir=args.output_dir, 132 | max_seq_length=args.max_seq_length, 133 | num_jobs=args.num_processes, 134 | blanks_separate_docs=args.blanks_separate_docs 135 | ) 136 | log("Writing tf examples") 137 | fnames = sorted(tf.io.gfile.listdir(args.corpus_dir)) 138 | fnames = [f for (i, f) in enumerate(fnames) 139 | if i % args.num_processes == job_id] 140 | random.shuffle(fnames) 141 | start_time = time.time() 142 | for file_no, fname in enumerate(fnames): 143 | if file_no > 0: 144 | elapsed = time.time() - start_time 145 | log("processed {:}/{:} files ({:.1f}%), ELAPSED: {:}s, ETA: {:}s, " 146 | "{:} examples written".format( 147 | file_no, len(fnames), 100.0 * file_no / len(fnames), int(elapsed), 148 | int((len(fnames) - file_no) / (file_no / elapsed)), 149 | example_writer.n_written)) 150 | example_writer.write_examples(os.path.join(args.corpus_dir, fname)) 151 | example_writer.finish() 152 | log("Done!") 153 | 154 | 155 | def main(): 156 | parser = argparse.ArgumentParser(description=__doc__) 157 | parser.add_argument("--corpus-dir", required=True, 158 | help="Location of pre-training text files.") 159 | parser.add_argument("--vocab-file", required=True, 160 | help="Location of vocabulary file.") 161 | parser.add_argument("--output-dir", required=True, 162 | help="Where to write out the tfrecords.") 163 | parser.add_argument("--max-seq-length", default=64, type=int, 164 | help="Number of tokens per example.") 165 | parser.add_argument("--num-processes", default=1, type=int, 166 | help="Parallelize across multiple processes.") 167 | parser.add_argument("--blanks-separate-docs", default=False, type=bool, 168 | help="Whether blank lines indicate document boundaries.") 169 | args = parser.parse_args() 170 | 171 | utils.rmkdir(args.output_dir) 172 | if args.num_processes == 1: 173 | write_examples(0, args) 174 | else: 175 | jobs = [] 176 | for i in range(args.num_processes): 177 | job = multiprocessing.Process(target=write_examples, args=(i, args)) 178 | jobs.append(job) 179 | job.start() 180 | for job in jobs: 181 | job.join() 182 | 183 | 184 | if __name__ == "__main__": 185 | main() 186 | -------------------------------------------------------------------------------- /code/electra-pretrain/config/base_discriminator_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "embedding_size": 768, 5 | "hidden_act": "gelu", 6 | "hidden_dropout_prob": 0.1, 7 | "hidden_size": 768, 8 | "initializer_range": 0.02, 9 | "intermediate_size": 3072, 10 | "layer_norm_eps": 1e-12, 11 | "max_position_embeddings": 512, 12 | "model_type": "electra", 13 | "num_attention_heads": 12, 14 | "num_hidden_layers": 12, 15 | "pad_token_id": 0, 16 | "summary_activation": "gelu", 17 | "summary_last_dropout": 0.1, 18 | "summary_type": "first", 19 | "summary_use_proj": true, 20 | 
"type_vocab_size": 2, 21 | "vocab_size": 21128 22 | } -------------------------------------------------------------------------------- /code/electra-pretrain/config/base_generator_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "embedding_size": 768, 5 | "hidden_act": "gelu", 6 | "hidden_dropout_prob": 0.1, 7 | "hidden_size": 192, 8 | "initializer_range": 0.02, 9 | "intermediate_size": 768, 10 | "layer_norm_eps": 1e-12, 11 | "max_position_embeddings": 512, 12 | "model_type": "electra", 13 | "num_attention_heads": 3, 14 | "num_hidden_layers": 12, 15 | "pad_token_id": 0, 16 | "summary_activation": "gelu", 17 | "summary_last_dropout": 0.1, 18 | "summary_type": "first", 19 | "summary_use_proj": true, 20 | "type_vocab_size": 2, 21 | "vocab_size": 21128 22 | } -------------------------------------------------------------------------------- /code/electra-pretrain/config/large_discriminator_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "embedding_size": 1024, 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 1024, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 4096, 9 | "layer_norm_eps": 1e-12, 10 | "max_position_embeddings": 512, 11 | "model_type": "electra", 12 | "num_attention_heads": 16, 13 | "num_hidden_layers": 24, 14 | "pad_token_id": 0, 15 | "summary_activation": "gelu", 16 | "summary_last_dropout": 0.1, 17 | "summary_type": "first", 18 | "summary_use_proj": true, 19 | "type_vocab_size": 2, 20 | "vocab_size": 21128 21 | } -------------------------------------------------------------------------------- /code/electra-pretrain/config/large_generator_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "embedding_size": 1024, 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 256, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 1024, 9 | "layer_norm_eps": 1e-12, 10 | "max_position_embeddings": 512, 11 | "model_type": "electra", 12 | "num_attention_heads": 4, 13 | "num_hidden_layers": 24, 14 | "pad_token_id": 0, 15 | "summary_activation": "gelu", 16 | "summary_last_dropout": 0.1, 17 | "summary_type": "first", 18 | "summary_use_proj": true, 19 | "type_vocab_size": 2, 20 | "vocab_size": 21128 21 | } -------------------------------------------------------------------------------- /code/electra-pretrain/configure_pretraining.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Config controlling hyperparameters for pre-training ELECTRA.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | 24 | 25 | class PretrainingConfig(object): 26 | """Defines pre-training hyperparameters.""" 27 | 28 | def __init__(self, model_name, data_dir, **kwargs): 29 | self.model_name = model_name 30 | self.init_checkpoint = kwargs.get('init_checkpoint',"/nfs/users/xueyou/data/bert_pretrain/electra_180g_base/electra_180g_base.ckpt") 31 | self.embedding_file = kwargs.get('embedding_file',None) 32 | self.debug = False # debug mode for quickly running things 33 | self.do_train = True # pre-train ELECTRA 34 | self.do_eval = False # evaluate generator/discriminator on unlabeled data 35 | 36 | # loss functions 37 | self.electra_objective = True # if False, use the BERT objective instead 38 | self.gen_weight = 1.0 # masked language modeling / generator loss 39 | self.disc_weight = 50.0 # discriminator loss 40 | self.mask_prob = 0.15 # percent of input tokens to mask out / replace 41 | 42 | # optimization 43 | self.learning_rate = 2e-4 44 | self.lr_decay_power = 1.0 # linear weight decay by default 45 | self.weight_decay_rate = 0.01 46 | self.num_warmup_steps = 700 47 | self.use_amp = False 48 | self.accumulation_step = 1 49 | 50 | # training settings 51 | self.iterations_per_loop = 200 52 | self.save_checkpoints_steps = 30000 53 | self.num_train_steps = 7000 54 | self.num_eval_steps = 100 55 | 56 | # model settings 57 | self.model_size = "large" # one of "small", "base", or "large" 58 | # override the default transformer hparams for the provided model size; see 59 | # modeling.BertConfig for the possible hparams and util.training_utils for 60 | # the defaults 61 | self.model_hparam_overrides = ( 62 | kwargs["model_hparam_overrides"] 63 | if "model_hparam_overrides" in kwargs else {}) 64 | self.embedding_size = None # bert hidden size by default 65 | self.vocab_size = 21128 # number of tokens in the vocabulary 66 | self.do_lower_case = True # lowercase the input? 67 | 68 | # generator settings 69 | self.uniform_generator = False # generator is uniform at random 70 | self.untied_generator_embeddings = False # tie generator/discriminator 71 | # token embeddings? 72 | self.untied_generator = True # tie all generator/discriminator weights? 
73 | self.generator_layers = 1.0 # frac of discriminator layers for generator 74 | self.generator_hidden_size = 0.25 # frac of discrim hidden size for gen 75 | self.disallow_correct = False # force the generator to sample incorrect 76 | # tokens (so 15% of tokens are always 77 | # fake) 78 | self.temperature = 1.0 # temperature for sampling from generator 79 | 80 | # batch sizes 81 | self.max_seq_length = 64 82 | self.train_batch_size = 32 83 | self.eval_batch_size = 128 84 | 85 | # TPU settings 86 | self.use_tpu = False 87 | self.num_tpu_cores = 1 88 | self.tpu_job_name = None 89 | self.tpu_name = None # cloud TPU to use for training 90 | self.tpu_zone = None # GCE zone where the Cloud TPU is located in 91 | self.gcp_project = None # project name for the Cloud TPU-enabled project 92 | 93 | # default locations of data files 94 | self.pretrain_tfrecords = os.path.join( 95 | data_dir, "pretrain_tfrecords/pretrain_data.tfrecord*") 96 | self.vocab_file = kwargs.get('vocab_file','/nfs/users/xueyou/data/bert_pretrain/electra_180g_base/vocab.txt') 97 | self.model_dir = os.path.join(data_dir, "models", model_name) 98 | results_dir = os.path.join(self.model_dir, "results") 99 | self.results_txt = os.path.join(results_dir, "unsup_results.txt") 100 | self.results_pkl = os.path.join(results_dir, "unsup_results.pkl") 101 | 102 | # update defaults with passed-in hyperparameters 103 | self.update(kwargs) 104 | 105 | self.max_predictions_per_seq = int((self.mask_prob + 0.005) * 106 | self.max_seq_length) 107 | 108 | # debug-mode settings 109 | if self.debug: 110 | self.train_batch_size = 8 111 | self.num_train_steps = 20 112 | self.eval_batch_size = 4 113 | self.iterations_per_loop = 1 114 | self.num_eval_steps = 2 115 | 116 | # defaults for different-sized model 117 | if self.model_size == "small": 118 | self.embedding_size = 256 119 | if self.model_size == "base": 120 | self.embedding_size = 768 121 | 122 | if self.model_size == 'large': 123 | self.embedding_size = 1024 124 | 125 | # passed-in-arguments override (for example) debug-mode defaults 126 | self.update(kwargs) 127 | 128 | def update(self, kwargs): 129 | for k, v in kwargs.items(): 130 | if k not in self.__dict__: 131 | raise ValueError("Unknown hparam " + k) 132 | self.__dict__[k] = v 133 | -------------------------------------------------------------------------------- /code/electra-pretrain/model/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. -------------------------------------------------------------------------------- /code/electra-pretrain/model/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Functions and classes related to optimization (weight updates). 17 | Modified from the original BERT code to allow for having separate learning 18 | rates for different layers of the network. 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import collections 26 | import re 27 | import tensorflow.compat.v1 as tf 28 | 29 | 30 | def create_optimizer( 31 | loss, learning_rate, num_train_steps, weight_decay_rate=0.0, use_tpu=False, 32 | warmup_steps=0, warmup_proportion=0, lr_decay_power=1.0, 33 | layerwise_lr_decay_power=-1, n_transformer_layers=None, 34 | amp=False,accumulation_step=1): 35 | """Creates an optimizer and training op.""" 36 | global_step = tf.train.get_or_create_global_step() 37 | learning_rate = tf.train.polynomial_decay( 38 | learning_rate, 39 | global_step, 40 | num_train_steps, 41 | end_learning_rate=0.0, 42 | power=lr_decay_power, 43 | cycle=False) 44 | warmup_steps = max(num_train_steps * warmup_proportion, warmup_steps) 45 | learning_rate *= tf.minimum( 46 | 1.0, tf.cast(global_step, tf.float32) / tf.cast(warmup_steps, tf.float32)) 47 | 48 | if layerwise_lr_decay_power > 0: 49 | learning_rate = _get_layer_lrs(learning_rate, layerwise_lr_decay_power, 50 | n_transformer_layers) 51 | optimizer = AdamWeightDecayOptimizer( 52 | learning_rate=learning_rate, 53 | weight_decay_rate=weight_decay_rate, 54 | beta_1=0.9, 55 | beta_2=0.999, 56 | epsilon=1e-6, 57 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 58 | if use_tpu: 59 | optimizer = tf.tpu.CrossShardOptimizer(optimizer) 60 | 61 | tvars = tf.trainable_variables() 62 | 63 | if amp: 64 | optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer) 65 | 66 | grads_and_vars = optimizer.compute_gradients(loss * 1.0 / accumulation_step, tvars) 67 | 68 | if accumulation_step > 1: 69 | print('### Using Gradient Accumulation with {} ###'.format(accumulation_step)) 70 | local_step = tf.get_variable(name="local_step", shape=[], dtype=tf.int32, trainable=False, 71 | initializer=tf.zeros_initializer) 72 | batch_finite = tf.get_variable(name="batch_finite", shape=[], dtype=tf.bool, trainable=False, 73 | initializer=tf.ones_initializer) 74 | accum_vars = [tf.get_variable( 75 | name=tvar.name.split(":")[0] + "/accum", 76 | shape=tvar.shape.as_list(), 77 | dtype=tf.float32, 78 | trainable=False, 79 | initializer=tf.zeros_initializer()) for tvar in tf.trainable_variables()] 80 | 81 | reset_step = tf.cast(tf.math.equal(local_step % accumulation_step, 0), dtype=tf.bool) 82 | local_step = tf.cond(reset_step, lambda:local_step.assign(tf.ones_like(local_step)), lambda:local_step.assign_add(1)) 83 | 84 | grads_and_vars_and_accums = [(gv[0],gv[1],accum_vars[i]) for i, gv in enumerate(grads_and_vars) if gv[0] is not None] 85 | grads, tvars, accum_vars = list(zip(*grads_and_vars_and_accums)) 86 | 87 | all_are_finite = 
tf.reduce_all([tf.reduce_all(tf.is_finite(g)) for g in grads]) if amp else tf.constant(True, dtype=tf.bool) 88 | batch_finite = tf.cond(reset_step, 89 | lambda: batch_finite.assign(tf.math.logical_and(tf.constant(True, dtype=tf.bool), all_are_finite)), 90 | lambda: batch_finite.assign(tf.math.logical_and(batch_finite, all_are_finite))) 91 | 92 | # This is how the model was pre-trained. 93 | # ensure global norm is a finite number 94 | # to prevent clip_by_global_norm from having a hizzy fit. 95 | (clipped_grads, _) = tf.clip_by_global_norm( 96 | grads, clip_norm=1.0, 97 | use_norm=tf.cond( 98 | all_are_finite, 99 | lambda: tf.global_norm(grads), 100 | lambda: tf.constant(1.0))) 101 | 102 | accum_vars = tf.cond(reset_step, 103 | lambda: [accum_vars[i].assign(grad) for i, grad in enumerate(clipped_grads)], 104 | lambda: [accum_vars[i].assign_add(grad) for i, grad in enumerate(clipped_grads)]) 105 | 106 | def update(accum_vars): 107 | return optimizer.apply_gradients(list(zip(accum_vars, tvars))) 108 | 109 | update_step = tf.identity(tf.cast(tf.math.equal(local_step % accumulation_step, 0), dtype=tf.bool), name="update_step") 110 | update_op = tf.cond(update_step, 111 | lambda: update(accum_vars), lambda: tf.no_op()) 112 | 113 | new_global_step = tf.cond(tf.math.logical_and(update_step, batch_finite), 114 | lambda: global_step+1, 115 | lambda: global_step) 116 | new_global_step = tf.identity(new_global_step, name='step_update') 117 | train_op = tf.group(update_op, [global_step.assign(new_global_step)]) 118 | 119 | else: 120 | grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None] 121 | grads, tvars = list(zip(*grads_and_vars)) 122 | all_are_finite = tf.reduce_all( 123 | [tf.reduce_all(tf.is_finite(g)) for g in grads]) if amp else tf.constant(True, dtype=tf.bool) 124 | 125 | # This is how the model was pre-trained. 126 | # ensure global norm is a finite number 127 | # to prevent clip_by_global_norm from having a hizzy fit. 
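    # NOTE: the block below clips gradients to a global norm of 1.0. When AMP is
    # enabled and some gradients are non-finite, a constant norm of 1.0 is passed
    # as `use_norm` so clip_by_global_norm itself stays well-defined, and
    # `global_step` a few lines further down is only incremented when all
    # gradients are finite.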
128 | (clipped_grads, _) = tf.clip_by_global_norm( 129 | grads, clip_norm=1.0, 130 | use_norm=tf.cond( 131 | all_are_finite, 132 | lambda: tf.global_norm(grads), 133 | lambda: tf.constant(1.0))) 134 | 135 | train_op = optimizer.apply_gradients( 136 | list(zip(clipped_grads, tvars))) 137 | 138 | new_global_step = tf.cond(all_are_finite, lambda: global_step + 1, lambda: global_step) 139 | new_global_step = tf.identity(new_global_step, name='step_update') 140 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 141 | 142 | # grads = tf.gradients(loss, tvars) 143 | # (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 144 | # train_op = optimizer.apply_gradients( 145 | # zip(grads, tvars), global_step=global_step) 146 | # new_global_step = global_step + 1 147 | # train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 148 | return train_op 149 | 150 | 151 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 152 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 153 | 154 | def __init__(self, 155 | learning_rate, 156 | weight_decay_rate=0.0, 157 | beta_1=0.9, 158 | beta_2=0.999, 159 | epsilon=1e-6, 160 | exclude_from_weight_decay=None, 161 | name="AdamWeightDecayOptimizer"): 162 | """Constructs a AdamWeightDecayOptimizer.""" 163 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 164 | 165 | self.learning_rate = learning_rate 166 | self.weight_decay_rate = weight_decay_rate 167 | self.beta_1 = beta_1 168 | self.beta_2 = beta_2 169 | self.epsilon = epsilon 170 | self.exclude_from_weight_decay = exclude_from_weight_decay 171 | 172 | def _apply_gradients(self, grads_and_vars, learning_rate): 173 | """See base class.""" 174 | assignments = [] 175 | for (grad, param) in grads_and_vars: 176 | if grad is None or param is None: 177 | continue 178 | 179 | param_name = self._get_variable_name(param.name) 180 | 181 | m = tf.get_variable( 182 | name=param_name + "/adam_m", 183 | shape=param.shape.as_list(), 184 | dtype=tf.float32, 185 | trainable=False, 186 | initializer=tf.zeros_initializer()) 187 | v = tf.get_variable( 188 | name=param_name + "/adam_v", 189 | shape=param.shape.as_list(), 190 | dtype=tf.float32, 191 | trainable=False, 192 | initializer=tf.zeros_initializer()) 193 | 194 | # Standard Adam update. 195 | next_m = ( 196 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 197 | next_v = ( 198 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 199 | tf.square(grad))) 200 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 201 | 202 | # Just adding the square of the weights to the loss function is *not* 203 | # the correct way of using L2 regularization/weight decay with Adam, 204 | # since that will interact with the m and v parameters in strange ways. 205 | # 206 | # Instead we want ot decay the weights in a manner that doesn't interact 207 | # with the m/v parameters. This is equivalent to adding the square 208 | # of the weights to the loss with plain (non-momentum) SGD. 
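      # NOTE: the resulting update below is
      #   param <- param - learning_rate * (m / (sqrt(v) + epsilon) + weight_decay_rate * param)
      # for parameters not excluded from weight decay; as in BERT's optimizer,
      # Adam's bias-correction terms are omitted.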
209 | if self.weight_decay_rate > 0: 210 | if self._do_use_weight_decay(param_name): 211 | update += self.weight_decay_rate * param 212 | 213 | update_with_lr = learning_rate * update 214 | next_param = param - update_with_lr 215 | 216 | assignments.extend( 217 | [param.assign(next_param), 218 | m.assign(next_m), 219 | v.assign(next_v)]) 220 | 221 | return assignments 222 | 223 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 224 | if isinstance(self.learning_rate, dict): 225 | key_to_grads_and_vars = {} 226 | for grad, var in grads_and_vars: 227 | update_for_var = False 228 | for key in self.learning_rate: 229 | if key in var.name: 230 | update_for_var = True 231 | if key not in key_to_grads_and_vars: 232 | key_to_grads_and_vars[key] = [] 233 | key_to_grads_and_vars[key].append((grad, var)) 234 | if not update_for_var: 235 | raise ValueError("No learning rate specified for variable", var) 236 | assignments = [] 237 | for key, key_grads_and_vars in key_to_grads_and_vars.items(): 238 | assignments += self._apply_gradients(key_grads_and_vars, 239 | self.learning_rate[key]) 240 | else: 241 | assignments = self._apply_gradients(grads_and_vars, self.learning_rate) 242 | return tf.group(*assignments, name=name) 243 | 244 | def _do_use_weight_decay(self, param_name): 245 | """Whether to use L2 weight decay for `param_name`.""" 246 | if not self.weight_decay_rate: 247 | return False 248 | if self.exclude_from_weight_decay: 249 | for r in self.exclude_from_weight_decay: 250 | if re.search(r, param_name) is not None: 251 | return False 252 | return True 253 | 254 | def _get_variable_name(self, param_name): 255 | """Get the variable name from the tensor name.""" 256 | m = re.match("^(.*):\\d+$", param_name) 257 | if m is not None: 258 | param_name = m.group(1) 259 | return param_name 260 | 261 | 262 | def _get_layer_lrs(learning_rate, layer_decay, n_layers): 263 | """Have lower learning rates for layers closer to the input.""" 264 | key_to_depths = collections.OrderedDict({ 265 | "/embeddings/": 0, 266 | "/embeddings_project/": 0, 267 | "task_specific/": n_layers + 2, 268 | }) 269 | for layer in range(n_layers): 270 | key_to_depths["encoder/layer_" + str(layer) + "/"] = layer + 1 271 | return { 272 | key: learning_rate * (layer_decay ** (n_layers + 2 - depth)) 273 | for key, depth in key_to_depths.items() 274 | } 275 | -------------------------------------------------------------------------------- /code/electra-pretrain/model/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
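# NOTE: SimpleTokenizer below is this project's addition for the downstream NER
# task: it lowercases the input and splits it into individual characters, mapping
# anything outside the vocabulary to [UNK] (see the electra-pretrain README).
# Illustrative example (not executed here; assumes every character is in vocab.txt,
# whose path is hypothetical):
#
#   tokenizer = SimpleTokenizer(vocab_file="vocab.txt")
#   tokenizer.tokenize("浙江省杭州市余杭区1号")
#   # -> ['浙', '江', '省', '杭', '州', '市', '余', '杭', '区', '1', '号']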
15 | 16 | """Tokenization classes, the same as used for BERT.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | import unicodedata 24 | import six 25 | import tensorflow.compat.v1 as tf 26 | 27 | 28 | 29 | def convert_to_unicode(text): 30 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 31 | if six.PY3: 32 | if isinstance(text, str): 33 | return text 34 | elif isinstance(text, bytes): 35 | return text.decode("utf-8", "ignore") 36 | else: 37 | raise ValueError("Unsupported string type: %s" % (type(text))) 38 | elif six.PY2: 39 | if isinstance(text, str): 40 | return text.decode("utf-8", "ignore") 41 | elif isinstance(text, unicode): 42 | return text 43 | else: 44 | raise ValueError("Unsupported string type: %s" % (type(text))) 45 | else: 46 | raise ValueError("Not running on Python2 or Python 3?") 47 | 48 | 49 | def printable_text(text): 50 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 51 | 52 | # These functions want `str` for both Python2 and Python3, but in one case 53 | # it's a Unicode string and in the other it's a byte string. 54 | if six.PY3: 55 | if isinstance(text, str): 56 | return text 57 | elif isinstance(text, bytes): 58 | return text.decode("utf-8", "ignore") 59 | else: 60 | raise ValueError("Unsupported string type: %s" % (type(text))) 61 | elif six.PY2: 62 | if isinstance(text, str): 63 | return text 64 | elif isinstance(text, unicode): 65 | return text.encode("utf-8") 66 | else: 67 | raise ValueError("Unsupported string type: %s" % (type(text))) 68 | else: 69 | raise ValueError("Not running on Python2 or Python 3?") 70 | 71 | 72 | def load_vocab(vocab_file): 73 | """Loads a vocabulary file into a dictionary.""" 74 | vocab = collections.OrderedDict() 75 | index = 0 76 | with tf.io.gfile.GFile(vocab_file, "r") as reader: 77 | while True: 78 | token = convert_to_unicode(reader.readline()) 79 | if not token: 80 | break 81 | token = token.strip() 82 | vocab[token] = index 83 | index += 1 84 | return vocab 85 | 86 | 87 | def convert_by_vocab(vocab, items): 88 | """Converts a sequence of [tokens|ids] using the vocab.""" 89 | output = [] 90 | for item in items: 91 | output.append(vocab[item]) 92 | return output 93 | 94 | 95 | def convert_tokens_to_ids(vocab, tokens): 96 | return convert_by_vocab(vocab, tokens) 97 | 98 | 99 | def convert_ids_to_tokens(inv_vocab, ids): 100 | return convert_by_vocab(inv_vocab, ids) 101 | 102 | 103 | def whitespace_tokenize(text): 104 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 105 | text = text.strip() 106 | if not text: 107 | return [] 108 | tokens = text.split() 109 | return tokens 110 | 111 | class SimpleTokenizer(object): 112 | def __init__(self, vocab_file): 113 | self.vocab = load_vocab(vocab_file) 114 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 115 | 116 | def tokenize(self, text): 117 | text = text.lower() 118 | return [token if token in self.vocab else '[UNK]' for token in text.strip()] 119 | 120 | def convert_tokens_to_ids(self, tokens): 121 | return convert_by_vocab(self.vocab, tokens) 122 | 123 | def convert_ids_to_tokens(self, ids): 124 | return convert_by_vocab(self.inv_vocab, ids) 125 | 126 | 127 | class FullTokenizer(object): 128 | """Runs end-to-end tokenziation.""" 129 | 130 | def __init__(self, vocab_file, do_lower_case=True): 131 | self.vocab = load_vocab(vocab_file) 132 | self.inv_vocab = {v: k for k, v in 
self.vocab.items()} 133 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 134 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 135 | 136 | def tokenize(self, text): 137 | split_tokens = [] 138 | for token in self.basic_tokenizer.tokenize(text): 139 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 140 | split_tokens.append(sub_token) 141 | 142 | return split_tokens 143 | 144 | def convert_tokens_to_ids(self, tokens): 145 | return convert_by_vocab(self.vocab, tokens) 146 | 147 | def convert_ids_to_tokens(self, ids): 148 | return convert_by_vocab(self.inv_vocab, ids) 149 | 150 | 151 | class BasicTokenizer(object): 152 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 153 | 154 | def __init__(self, do_lower_case=True): 155 | """Constructs a BasicTokenizer. 156 | 157 | Args: 158 | do_lower_case: Whether to lower case the input. 159 | """ 160 | self.do_lower_case = do_lower_case 161 | 162 | def tokenize(self, text): 163 | """Tokenizes a piece of text.""" 164 | text = convert_to_unicode(text) 165 | text = self._clean_text(text) 166 | 167 | # This was added on November 1st, 2018 for the multilingual and Chinese 168 | # models. This is also applied to the English models now, but it doesn't 169 | # matter since the English models were not trained on any Chinese data 170 | # and generally don't have any Chinese data in them (there are Chinese 171 | # characters in the vocabulary because Wikipedia does have some Chinese 172 | # words in the English Wikipedia.). 173 | text = self._tokenize_chinese_chars(text) 174 | 175 | orig_tokens = whitespace_tokenize(text) 176 | split_tokens = [] 177 | for token in orig_tokens: 178 | if self.do_lower_case: 179 | token = token.lower() 180 | token = self._run_strip_accents(token) 181 | split_tokens.extend(self._run_split_on_punc(token)) 182 | 183 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 184 | return output_tokens 185 | 186 | def _run_strip_accents(self, text): 187 | """Strips accents from a piece of text.""" 188 | text = unicodedata.normalize("NFD", text) 189 | output = [] 190 | for char in text: 191 | cat = unicodedata.category(char) 192 | if cat == "Mn": 193 | continue 194 | output.append(char) 195 | return "".join(output) 196 | 197 | def _run_split_on_punc(self, text): 198 | """Splits punctuation on a piece of text.""" 199 | chars = list(text) 200 | i = 0 201 | start_new_word = True 202 | output = [] 203 | while i < len(chars): 204 | char = chars[i] 205 | if _is_punctuation(char): 206 | output.append([char]) 207 | start_new_word = True 208 | else: 209 | if start_new_word: 210 | output.append([]) 211 | start_new_word = False 212 | output[-1].append(char) 213 | i += 1 214 | 215 | return ["".join(x) for x in output] 216 | 217 | def _tokenize_chinese_chars(self, text): 218 | """Adds whitespace around any CJK character.""" 219 | output = [] 220 | for char in text: 221 | cp = ord(char) 222 | if self._is_chinese_char(cp): 223 | output.append(" ") 224 | output.append(char) 225 | output.append(" ") 226 | else: 227 | output.append(char) 228 | return "".join(output) 229 | 230 | def _is_chinese_char(self, cp): 231 | """Checks whether CP is the codepoint of a CJK character.""" 232 | # This defines a "chinese character" as anything in the CJK Unicode block: 233 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 234 | # 235 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 236 | # despite its name. 
The modern Korean Hangul alphabet is a different block, 237 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 238 | # space-separated words, so they are not treated specially and handled 239 | # like the all of the other languages. 240 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 241 | (cp >= 0x3400 and cp <= 0x4DBF) or # 242 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 243 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 244 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 245 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 246 | (cp >= 0xF900 and cp <= 0xFAFF) or # 247 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 248 | return True 249 | 250 | return False 251 | 252 | def _clean_text(self, text): 253 | """Performs invalid character removal and whitespace cleanup on text.""" 254 | output = [] 255 | for char in text: 256 | cp = ord(char) 257 | if cp == 0 or cp == 0xfffd or _is_control(char): 258 | continue 259 | if _is_whitespace(char): 260 | output.append(" ") 261 | else: 262 | output.append(char) 263 | return "".join(output) 264 | 265 | 266 | class WordpieceTokenizer(object): 267 | """Runs WordPiece tokenziation.""" 268 | 269 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 270 | self.vocab = vocab 271 | self.unk_token = unk_token 272 | self.max_input_chars_per_word = max_input_chars_per_word 273 | 274 | def tokenize(self, text): 275 | """Tokenizes a piece of text into its word pieces. 276 | 277 | This uses a greedy longest-match-first algorithm to perform tokenization 278 | using the given vocabulary. 279 | 280 | For example: 281 | input = "unaffable" 282 | output = ["un", "##aff", "##able"] 283 | 284 | Args: 285 | text: A single token or whitespace separated tokens. This should have 286 | already been passed through `BasicTokenizer. 287 | 288 | Returns: 289 | A list of wordpiece tokens. 290 | """ 291 | 292 | text = convert_to_unicode(text) 293 | 294 | output_tokens = [] 295 | for token in whitespace_tokenize(text): 296 | chars = list(token) 297 | if len(chars) > self.max_input_chars_per_word: 298 | output_tokens.append(self.unk_token) 299 | continue 300 | 301 | is_bad = False 302 | start = 0 303 | sub_tokens = [] 304 | while start < len(chars): 305 | end = len(chars) 306 | cur_substr = None 307 | while start < end: 308 | substr = "".join(chars[start:end]) 309 | if start > 0: 310 | substr = "##" + substr 311 | if substr in self.vocab: 312 | cur_substr = substr 313 | break 314 | end -= 1 315 | if cur_substr is None: 316 | is_bad = True 317 | break 318 | sub_tokens.append(cur_substr) 319 | start = end 320 | 321 | if is_bad: 322 | output_tokens.append(self.unk_token) 323 | else: 324 | output_tokens.extend(sub_tokens) 325 | return output_tokens 326 | 327 | 328 | def _is_whitespace(char): 329 | """Checks whether `chars` is a whitespace character.""" 330 | # \t, \n, and \r are technically contorl characters but we treat them 331 | # as whitespace since they are generally considered as such. 332 | if char == " " or char == "\t" or char == "\n" or char == "\r": 333 | return True 334 | cat = unicodedata.category(char) 335 | if cat == "Zs": 336 | return True 337 | return False 338 | 339 | 340 | def _is_control(char): 341 | """Checks whether `chars` is a control character.""" 342 | # These are technically control characters but we count them as whitespace 343 | # characters. 
344 | if char == "\t" or char == "\n" or char == "\r": 345 | return False 346 | cat = unicodedata.category(char) 347 | if cat.startswith("C"): 348 | return True 349 | return False 350 | 351 | 352 | def _is_punctuation(char): 353 | """Checks whether `chars` is a punctuation character.""" 354 | cp = ord(char) 355 | # We treat all non-letter/number ASCII as punctuation. 356 | # Characters such as "^", "$", and "`" are not in the Unicode 357 | # Punctuation class but we treat them as punctuation anyways, for 358 | # consistency. 359 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 360 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 361 | return True 362 | cat = unicodedata.category(char) 363 | if cat.startswith("P"): 364 | return True 365 | return False 366 | -------------------------------------------------------------------------------- /code/electra-pretrain/pretrain.sh: -------------------------------------------------------------------------------- 1 | export DATA_DIR=../../user_data 2 | export ELECTRA_DIR=../../user_data/electra 3 | 4 | echo 'Prepare pretraining data...' 5 | python build_pretraining_dataset.py \ 6 | --corpus-dir=${DATA_DIR}/texts \ 7 | --max-seq-length=64 \ 8 | --vocab-file=${ELECTRA_DIR}/electra_180g_base/vocab.txt \ 9 | --output-dir=${DATA_DIR}/pretrain_tfrecords 10 | 11 | echo "Pretrain base electra model ~= 1 hour on V100" 12 | python run_pretraining.py \ 13 | --data-dir=${DATA_DIR} \ 14 | --model-name=base \ 15 | --hparams='{"use_amp": true, "learning_rate": 0.0002,"model_size": "base","eval_batch_size":128,"train_batch_size": 128, "init_checkpoint": "../../user_data/electra/electra_180g_base/electra_180g_base.ckpt", "vocab_file": "../../user_data/electra/electra_180g_base/vocab.txt"}' 16 | 17 | 18 | echo "pretrain large electra model ~= 2.2 hours on V100" 19 | python run_pretraining.py \ 20 | --data-dir=${DATA_DIR} \ 21 | --model-name=large \ 22 | --hparams='{"num_train_steps": 5000, "num_warmup_steps": 500, "model_size": "large", "train_batch_size": 43, "learning_rate": 5e-05, "init_checkpoint": "../../user_data/electra/electra_180g_large/electra_180g_large.ckpt", "use_amp": true, "accumulation_step": 3, "vocab_file": "../../user_data/electra/electra_180g_large/vocab.txt"}' 23 | -------------------------------------------------------------------------------- /code/electra-pretrain/pretrain/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. -------------------------------------------------------------------------------- /code/electra-pretrain/pretrain/pretrain_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Helpers for preparing pre-training data and supplying them to the model.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | 24 | import numpy as np 25 | import tensorflow.compat.v1 as tf 26 | 27 | import configure_pretraining 28 | from model import tokenization 29 | from util import utils 30 | 31 | 32 | def get_input_fn(config: configure_pretraining.PretrainingConfig, is_training, 33 | num_cpu_threads=8): 34 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 35 | 36 | input_files = [] 37 | for input_pattern in config.pretrain_tfrecords.split(","): 38 | input_files.extend(tf.io.gfile.glob(input_pattern)) 39 | 40 | def input_fn(params): 41 | """The actual input function.""" 42 | # batch_size = params["batch_size"] 43 | batch_size = config.train_batch_size 44 | 45 | 46 | name_to_features = { 47 | "input_ids": tf.io.FixedLenFeature([config.max_seq_length], tf.int64), 48 | "input_mask": tf.io.FixedLenFeature([config.max_seq_length], tf.int64), 49 | "segment_ids": tf.io.FixedLenFeature([config.max_seq_length], tf.int64), 50 | } 51 | 52 | d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files)) 53 | d = d.repeat() 54 | d = d.shuffle(buffer_size=len(input_files)) 55 | 56 | # `cycle_length` is the number of parallel files that get read. 57 | cycle_length = min(num_cpu_threads, len(input_files)) 58 | 59 | # `sloppy` mode means that the interleaving is not exact. This adds 60 | # even more randomness to the training pipeline. 61 | d = d.apply( 62 | tf.data.experimental.parallel_interleave( 63 | tf.data.TFRecordDataset, 64 | sloppy=is_training, 65 | cycle_length=cycle_length)) 66 | d = d.shuffle(buffer_size=10000) 67 | 68 | # We must `drop_remainder` on training because the TPU requires fixed 69 | # size dimensions. For eval, we assume we are evaluating on the CPU or GPU 70 | # and we *don"t* want to drop the remainder, otherwise we wont cover 71 | # every sample. 72 | d = d.apply( 73 | tf.data.experimental.map_and_batch( 74 | lambda record: _decode_record(record, name_to_features), 75 | batch_size=batch_size, 76 | num_parallel_batches=num_cpu_threads, 77 | drop_remainder=True)) 78 | return d 79 | 80 | return input_fn 81 | 82 | 83 | def _decode_record(record, name_to_features): 84 | """Decodes a record to a TensorFlow example.""" 85 | example = tf.io.parse_single_example(record, name_to_features) 86 | 87 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 88 | # So cast all int64 to int32. 
89 | for name in list(example.keys()): 90 | t = example[name] 91 | if t.dtype == tf.int64: 92 | t = tf.cast(t, tf.int32) 93 | example[name] = t 94 | 95 | return example 96 | 97 | 98 | # model inputs - it's a bit nicer to use a namedtuple rather than keep the 99 | # features as a dict 100 | Inputs = collections.namedtuple( 101 | "Inputs", ["input_ids", "input_mask", "segment_ids", "masked_lm_positions", 102 | "masked_lm_ids", "masked_lm_weights"]) 103 | 104 | 105 | def features_to_inputs(features): 106 | return Inputs( 107 | input_ids=features["input_ids"], 108 | input_mask=features["input_mask"], 109 | segment_ids=features["segment_ids"], 110 | masked_lm_positions=(features["masked_lm_positions"] 111 | if "masked_lm_positions" in features else None), 112 | masked_lm_ids=(features["masked_lm_ids"] 113 | if "masked_lm_ids" in features else None), 114 | masked_lm_weights=(features["masked_lm_weights"] 115 | if "masked_lm_weights" in features else None), 116 | ) 117 | 118 | 119 | def get_updated_inputs(inputs, **kwargs): 120 | features = inputs._asdict() 121 | for k, v in kwargs.items(): 122 | features[k] = v 123 | return features_to_inputs(features) 124 | 125 | 126 | ENDC = "\033[0m" 127 | COLORS = ["\033[" + str(n) + "m" for n in list(range(91, 97)) + [90]] 128 | RED = COLORS[0] 129 | BLUE = COLORS[3] 130 | CYAN = COLORS[5] 131 | GREEN = COLORS[1] 132 | 133 | 134 | def print_tokens(inputs: Inputs, inv_vocab, updates_mask=None): 135 | """Pretty-print model inputs.""" 136 | pos_to_tokid = {} 137 | for tokid, pos, weight in zip( 138 | inputs.masked_lm_ids[0], inputs.masked_lm_positions[0], 139 | inputs.masked_lm_weights[0]): 140 | if weight == 0: 141 | pass 142 | else: 143 | pos_to_tokid[pos] = tokid 144 | 145 | text = "" 146 | provided_update_mask = (updates_mask is not None) 147 | if not provided_update_mask: 148 | updates_mask = np.zeros_like(inputs.input_ids) 149 | for pos, (tokid, um) in enumerate( 150 | zip(inputs.input_ids[0], updates_mask[0])): 151 | token = inv_vocab[tokid] 152 | if token == "[PAD]": 153 | break 154 | if pos in pos_to_tokid: 155 | token = RED + token + " (" + inv_vocab[pos_to_tokid[pos]] + ")" + ENDC 156 | if provided_update_mask: 157 | assert um == 1 158 | else: 159 | if provided_update_mask: 160 | assert um == 0 161 | text += token + " " 162 | utils.log(tokenization.printable_text(text)) 163 | -------------------------------------------------------------------------------- /code/electra-pretrain/pretrain/pretrain_helpers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Helper functions for pre-training. These mainly deal with the gathering and 17 | scattering needed so the generator only makes predictions for the small number 18 | of masked tokens. 
19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import tensorflow.compat.v1 as tf 26 | 27 | import configure_pretraining 28 | from model import modeling 29 | from model import tokenization 30 | from pretrain import pretrain_data 31 | 32 | 33 | def gather_positions(sequence, positions): 34 | """Gathers the vectors at the specific positions over a minibatch. 35 | 36 | Args: 37 | sequence: A [batch_size, seq_length] or 38 | [batch_size, seq_length, depth] tensor of values 39 | positions: A [batch_size, n_positions] tensor of indices 40 | 41 | Returns: A [batch_size, n_positions] or 42 | [batch_size, n_positions, depth] tensor of the values at the indices 43 | """ 44 | shape = modeling.get_shape_list(sequence, expected_rank=[2, 3]) 45 | depth_dimension = (len(shape) == 3) 46 | if depth_dimension: 47 | B, L, D = shape 48 | else: 49 | B, L = shape 50 | D = 1 51 | sequence = tf.expand_dims(sequence, -1) 52 | position_shift = tf.expand_dims(L * tf.range(B), -1) 53 | flat_positions = tf.reshape(positions + position_shift, [-1]) 54 | flat_sequence = tf.reshape(sequence, [B * L, D]) 55 | gathered = tf.gather(flat_sequence, flat_positions) 56 | if depth_dimension: 57 | return tf.reshape(gathered, [B, -1, D]) 58 | else: 59 | return tf.reshape(gathered, [B, -1]) 60 | 61 | 62 | def scatter_update(sequence, updates, positions): 63 | """Scatter-update a sequence. 64 | 65 | Args: 66 | sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth] tensor 67 | updates: A tensor of size batch_size*seq_len(*depth) 68 | positions: A [batch_size, n_positions] tensor 69 | 70 | Returns: A tuple of two tensors. First is a [batch_size, seq_len] or 71 | [batch_size, seq_len, depth] tensor of "sequence" with elements at 72 | "positions" replaced by the values at "updates." Updates to index 0 are 73 | ignored. If there are duplicated positions the update is only applied once. 74 | Second is a [batch_size, seq_len] mask tensor of which inputs were updated. 
75 | """ 76 | shape = modeling.get_shape_list(sequence, expected_rank=[2, 3]) 77 | depth_dimension = (len(shape) == 3) 78 | if depth_dimension: 79 | B, L, D = shape 80 | else: 81 | B, L = shape 82 | D = 1 83 | sequence = tf.expand_dims(sequence, -1) 84 | N = modeling.get_shape_list(positions)[1] 85 | 86 | shift = tf.expand_dims(L * tf.range(B), -1) 87 | flat_positions = tf.reshape(positions + shift, [-1, 1]) 88 | flat_updates = tf.reshape(updates, [-1, D]) 89 | updates = tf.scatter_nd(flat_positions, flat_updates, [B * L, D]) 90 | updates = tf.reshape(updates, [B, L, D]) 91 | 92 | flat_updates_mask = tf.ones([B * N], tf.int32) 93 | updates_mask = tf.scatter_nd(flat_positions, flat_updates_mask, [B * L]) 94 | updates_mask = tf.reshape(updates_mask, [B, L]) 95 | not_first_token = tf.concat([tf.zeros((B, 1), tf.int32), 96 | tf.ones((B, L - 1), tf.int32)], -1) 97 | updates_mask *= not_first_token 98 | updates_mask_3d = tf.expand_dims(updates_mask, -1) 99 | 100 | # account for duplicate positions 101 | if sequence.dtype == tf.float32: 102 | updates_mask_3d = tf.cast(updates_mask_3d, tf.float32) 103 | updates /= tf.maximum(1.0, updates_mask_3d) 104 | else: 105 | assert sequence.dtype == tf.int32 106 | updates = tf.math.floordiv(updates, tf.maximum(1, updates_mask_3d)) 107 | updates_mask = tf.minimum(updates_mask, 1) 108 | updates_mask_3d = tf.minimum(updates_mask_3d, 1) 109 | 110 | updated_sequence = (((1 - updates_mask_3d) * sequence) + 111 | (updates_mask_3d * updates)) 112 | if not depth_dimension: 113 | updated_sequence = tf.squeeze(updated_sequence, -1) 114 | 115 | return updated_sequence, updates_mask 116 | 117 | 118 | def _get_candidates_mask(inputs: pretrain_data.Inputs, vocab, 119 | disallow_from_mask=None): 120 | """Returns a mask tensor of positions in the input that can be masked out.""" 121 | ignore_ids = [vocab["[SEP]"], vocab["[CLS]"], vocab["[MASK]"]] 122 | candidates_mask = tf.ones_like(inputs.input_ids, tf.bool) 123 | for ignore_id in ignore_ids: 124 | candidates_mask &= tf.not_equal(inputs.input_ids, ignore_id) 125 | candidates_mask &= tf.cast(inputs.input_mask, tf.bool) 126 | if disallow_from_mask is not None: 127 | candidates_mask &= ~disallow_from_mask 128 | return candidates_mask 129 | 130 | 131 | def mask(config: configure_pretraining.PretrainingConfig, 132 | inputs: pretrain_data.Inputs, mask_prob, proposal_distribution=1.0, 133 | disallow_from_mask=None, already_masked=None): 134 | """Implementation of dynamic masking. The optional arguments aren't needed for 135 | BERT/ELECTRA and are from early experiments in "strategically" masking out 136 | tokens instead of uniformly at random. 137 | 138 | Args: 139 | config: configure_pretraining.PretrainingConfig 140 | inputs: pretrain_data.Inputs containing input input_ids/input_mask 141 | mask_prob: percent of tokens to mask 142 | proposal_distribution: for non-uniform masking can be a [B, L] tensor 143 | of scores for masking each position. 
144 | disallow_from_mask: a boolean tensor of [B, L] of positions that should 145 | not be masked out 146 | already_masked: a boolean tensor of [B, N] of already masked-out tokens 147 | for multiple rounds of masking 148 | Returns: a pretrain_data.Inputs with masking added 149 | """ 150 | # Get the batch size, sequence length, and max masked-out tokens 151 | N = config.max_predictions_per_seq 152 | B, L = modeling.get_shape_list(inputs.input_ids) 153 | 154 | # Find indices where masking out a token is allowed 155 | vocab = tokenization.FullTokenizer( 156 | config.vocab_file, do_lower_case=config.do_lower_case).vocab 157 | candidates_mask = _get_candidates_mask(inputs, vocab, disallow_from_mask) 158 | 159 | # Set the number of tokens to mask out per example 160 | num_tokens = tf.cast(tf.reduce_sum(inputs.input_mask, -1), tf.float32) 161 | num_to_predict = tf.maximum(1, tf.minimum( 162 | N, tf.cast(tf.round(num_tokens * mask_prob), tf.int32))) 163 | masked_lm_weights = tf.cast(tf.sequence_mask(num_to_predict, N), tf.float32) 164 | if already_masked is not None: 165 | masked_lm_weights *= (1 - already_masked) 166 | 167 | # Get a probability of masking each position in the sequence 168 | candidate_mask_float = tf.cast(candidates_mask, tf.float32) 169 | sample_prob = (proposal_distribution * candidate_mask_float) 170 | sample_prob /= tf.reduce_sum(sample_prob, axis=-1, keepdims=True) 171 | 172 | # Sample the positions to mask out 173 | sample_prob = tf.stop_gradient(sample_prob) 174 | sample_logits = tf.log(sample_prob) 175 | masked_lm_positions = tf.random.categorical( 176 | sample_logits, N, dtype=tf.int32) 177 | masked_lm_positions *= tf.cast(masked_lm_weights, tf.int32) 178 | 179 | # Get the ids of the masked-out tokens 180 | shift = tf.expand_dims(L * tf.range(B), -1) 181 | flat_positions = tf.reshape(masked_lm_positions + shift, [-1, 1]) 182 | masked_lm_ids = tf.gather_nd(tf.reshape(inputs.input_ids, [-1]), 183 | flat_positions) 184 | masked_lm_ids = tf.reshape(masked_lm_ids, [B, -1]) 185 | masked_lm_ids *= tf.cast(masked_lm_weights, tf.int32) 186 | 187 | # Update the input ids 188 | replace_with_mask_positions = masked_lm_positions * tf.cast( 189 | tf.less(tf.random.uniform([B, N]), 0.85), tf.int32) 190 | inputs_ids, _ = scatter_update( 191 | inputs.input_ids, tf.fill([B, N], vocab["[MASK]"]), 192 | replace_with_mask_positions) 193 | 194 | return pretrain_data.get_updated_inputs( 195 | inputs, 196 | input_ids=tf.stop_gradient(inputs_ids), 197 | masked_lm_positions=masked_lm_positions, 198 | masked_lm_ids=masked_lm_ids, 199 | masked_lm_weights=masked_lm_weights 200 | ) 201 | 202 | 203 | def unmask(inputs: pretrain_data.Inputs): 204 | unmasked_input_ids, _ = scatter_update( 205 | inputs.input_ids, inputs.masked_lm_ids, inputs.masked_lm_positions) 206 | return pretrain_data.get_updated_inputs(inputs, input_ids=unmasked_input_ids) 207 | 208 | 209 | def sample_from_softmax(logits, disallow=None): 210 | if disallow is not None: 211 | logits -= 1000.0 * disallow 212 | uniform_noise = tf.random.uniform( 213 | modeling.get_shape_list(logits), minval=0, maxval=1) 214 | gumbel_noise = -tf.log(-tf.log(uniform_noise + 1e-9) + 1e-9) 215 | return tf.one_hot(tf.argmax(tf.nn.softmax(logits + gumbel_noise), -1, 216 | output_type=tf.int32), logits.shape[-1]) 217 | -------------------------------------------------------------------------------- /code/electra-pretrain/run_pretraining.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 
The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Pre-trains an ELECTRA model.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import argparse 23 | import collections 24 | import json 25 | 26 | import tensorflow.compat.v1 as tf 27 | 28 | import configure_pretraining 29 | from model import modeling 30 | from model import optimization 31 | from pretrain import pretrain_data 32 | from pretrain import pretrain_helpers 33 | from util import training_utils 34 | from util import utils 35 | 36 | 37 | class PretrainingModel(object): 38 | """Transformer pre-training using the replaced-token-detection task.""" 39 | 40 | def __init__(self, config: configure_pretraining.PretrainingConfig, 41 | features, is_training): 42 | # Set up model config 43 | self._config = config 44 | self._bert_config = training_utils.get_bert_config(config) 45 | if config.debug: 46 | self._bert_config.num_hidden_layers = 3 47 | self._bert_config.hidden_size = 144 48 | self._bert_config.intermediate_size = 144 * 4 49 | self._bert_config.num_attention_heads = 4 50 | 51 | # Mask the input 52 | masked_inputs = pretrain_helpers.mask( 53 | config, pretrain_data.features_to_inputs(features), config.mask_prob) 54 | 55 | # Generator 56 | embedding_size = ( 57 | self._bert_config.hidden_size if config.embedding_size is None else 58 | config.embedding_size) 59 | if config.uniform_generator: 60 | mlm_output = self._get_masked_lm_output(masked_inputs, None) 61 | elif config.electra_objective and config.untied_generator: 62 | generator = self._build_transformer( 63 | masked_inputs, is_training, 64 | bert_config=get_generator_config(config, self._bert_config), 65 | embedding_size=(None if config.untied_generator_embeddings 66 | else embedding_size), 67 | untied_embeddings=config.untied_generator_embeddings, 68 | name="generator", 69 | embedding_file=config.embedding_file) 70 | mlm_output = self._get_masked_lm_output(masked_inputs, generator) 71 | else: 72 | generator = self._build_transformer( 73 | masked_inputs, is_training, embedding_size=embedding_size, 74 | embedding_file=config.embedding_file) 75 | mlm_output = self._get_masked_lm_output(masked_inputs, generator) 76 | fake_data = self._get_fake_data(masked_inputs, mlm_output.logits) 77 | self.mlm_output = mlm_output 78 | self.total_loss = config.gen_weight * mlm_output.loss 79 | 80 | # Discriminator 81 | disc_output = None 82 | if config.electra_objective: 83 | discriminator = self._build_transformer( 84 | fake_data.inputs, is_training, reuse=not config.untied_generator, 85 | embedding_size=embedding_size, 86 | embedding_file=config.embedding_file) 87 | disc_output = self._get_discriminator_output( 88 | fake_data.inputs, discriminator, fake_data.is_fake_tokens) 89 | self.total_loss += config.disc_weight * disc_output.loss 90 | 91 | # Evaluation 92 | eval_fn_inputs = { 93 | "input_ids": masked_inputs.input_ids, 94 | 
"masked_lm_preds": mlm_output.preds, 95 | "mlm_loss": mlm_output.per_example_loss, 96 | "masked_lm_ids": masked_inputs.masked_lm_ids, 97 | "masked_lm_weights": masked_inputs.masked_lm_weights, 98 | "input_mask": masked_inputs.input_mask 99 | } 100 | if config.electra_objective: 101 | eval_fn_inputs.update({ 102 | "disc_loss": disc_output.per_example_loss, 103 | "disc_labels": disc_output.labels, 104 | "disc_probs": disc_output.probs, 105 | "disc_preds": disc_output.preds, 106 | "sampled_tokids": tf.argmax(fake_data.sampled_tokens, -1, 107 | output_type=tf.int32) 108 | }) 109 | eval_fn_keys = eval_fn_inputs.keys() 110 | eval_fn_values = [eval_fn_inputs[k] for k in eval_fn_keys] 111 | 112 | def metric_fn(*args): 113 | """Computes the loss and accuracy of the model.""" 114 | d = {k: arg for k, arg in zip(eval_fn_keys, args)} 115 | metrics = dict() 116 | metrics["masked_lm_accuracy"] = tf.metrics.accuracy( 117 | labels=tf.reshape(d["masked_lm_ids"], [-1]), 118 | predictions=tf.reshape(d["masked_lm_preds"], [-1]), 119 | weights=tf.reshape(d["masked_lm_weights"], [-1])) 120 | metrics["masked_lm_loss"] = tf.metrics.mean( 121 | values=tf.reshape(d["mlm_loss"], [-1]), 122 | weights=tf.reshape(d["masked_lm_weights"], [-1])) 123 | if config.electra_objective: 124 | metrics["sampled_masked_lm_accuracy"] = tf.metrics.accuracy( 125 | labels=tf.reshape(d["masked_lm_ids"], [-1]), 126 | predictions=tf.reshape(d["sampled_tokids"], [-1]), 127 | weights=tf.reshape(d["masked_lm_weights"], [-1])) 128 | if config.disc_weight > 0: 129 | metrics["disc_loss"] = tf.metrics.mean(d["disc_loss"]) 130 | metrics["disc_auc"] = tf.metrics.auc( 131 | d["disc_labels"] * d["input_mask"], 132 | d["disc_probs"] * tf.cast(d["input_mask"], tf.float32)) 133 | metrics["disc_accuracy"] = tf.metrics.accuracy( 134 | labels=d["disc_labels"], predictions=d["disc_preds"], 135 | weights=d["input_mask"]) 136 | metrics["disc_precision"] = tf.metrics.accuracy( 137 | labels=d["disc_labels"], predictions=d["disc_preds"], 138 | weights=d["disc_preds"] * d["input_mask"]) 139 | metrics["disc_recall"] = tf.metrics.accuracy( 140 | labels=d["disc_labels"], predictions=d["disc_preds"], 141 | weights=d["disc_labels"] * d["input_mask"]) 142 | return metrics 143 | self.eval_metrics = (metric_fn, eval_fn_values) 144 | 145 | def _get_masked_lm_output(self, inputs: pretrain_data.Inputs, model): 146 | """Masked language modeling softmax layer.""" 147 | masked_lm_weights = inputs.masked_lm_weights 148 | with tf.variable_scope("generator_predictions"): 149 | if self._config.uniform_generator: 150 | logits = tf.zeros(self._bert_config.vocab_size) 151 | logits_tiled = tf.zeros( 152 | modeling.get_shape_list(inputs.masked_lm_ids) + 153 | [self._bert_config.vocab_size]) 154 | logits_tiled += tf.reshape(logits, [1, 1, self._bert_config.vocab_size]) 155 | logits = logits_tiled 156 | else: 157 | relevant_hidden = pretrain_helpers.gather_positions( 158 | model.get_sequence_output(), inputs.masked_lm_positions) 159 | hidden = tf.layers.dense( 160 | relevant_hidden, 161 | units=modeling.get_shape_list(model.get_embedding_table())[-1], 162 | activation=modeling.get_activation(self._bert_config.hidden_act), 163 | kernel_initializer=modeling.create_initializer( 164 | self._bert_config.initializer_range)) 165 | hidden = modeling.layer_norm(hidden) 166 | output_bias = tf.get_variable( 167 | "output_bias", 168 | shape=[self._bert_config.vocab_size], 169 | initializer=tf.zeros_initializer()) 170 | logits = tf.matmul(hidden, model.get_embedding_table(), 171 | transpose_b=True) 
172 | logits = tf.nn.bias_add(logits, output_bias) 173 | 174 | oh_labels = tf.one_hot( 175 | inputs.masked_lm_ids, depth=self._bert_config.vocab_size, 176 | dtype=tf.float32) 177 | 178 | probs = tf.nn.softmax(logits) 179 | log_probs = tf.nn.log_softmax(logits) 180 | label_log_probs = -tf.reduce_sum(log_probs * oh_labels, axis=-1) 181 | 182 | numerator = tf.reduce_sum(inputs.masked_lm_weights * label_log_probs) 183 | denominator = tf.reduce_sum(masked_lm_weights) + 1e-6 184 | loss = numerator / denominator 185 | preds = tf.argmax(log_probs, axis=-1, output_type=tf.int32) 186 | 187 | MLMOutput = collections.namedtuple( 188 | "MLMOutput", ["logits", "probs", "loss", "per_example_loss", "preds"]) 189 | return MLMOutput( 190 | logits=logits, probs=probs, per_example_loss=label_log_probs, 191 | loss=loss, preds=preds) 192 | 193 | def _get_discriminator_output(self, inputs, discriminator, labels): 194 | """Discriminator binary classifier.""" 195 | with tf.variable_scope("discriminator_predictions"): 196 | hidden = tf.layers.dense( 197 | discriminator.get_sequence_output(), 198 | units=self._bert_config.hidden_size, 199 | activation=modeling.get_activation(self._bert_config.hidden_act), 200 | kernel_initializer=modeling.create_initializer( 201 | self._bert_config.initializer_range)) 202 | logits = tf.squeeze(tf.layers.dense(hidden, units=1), -1) 203 | weights = tf.cast(inputs.input_mask, tf.float32) 204 | labelsf = tf.cast(labels, tf.float32) 205 | losses = tf.nn.sigmoid_cross_entropy_with_logits( 206 | logits=logits, labels=labelsf) * weights 207 | per_example_loss = (tf.reduce_sum(losses, axis=-1) / 208 | (1e-6 + tf.reduce_sum(weights, axis=-1))) 209 | loss = tf.reduce_sum(losses) / (1e-6 + tf.reduce_sum(weights)) 210 | probs = tf.nn.sigmoid(logits) 211 | preds = tf.cast(tf.round((tf.sign(logits) + 1) / 2), tf.int32) 212 | DiscOutput = collections.namedtuple( 213 | "DiscOutput", ["loss", "per_example_loss", "probs", "preds", 214 | "labels"]) 215 | return DiscOutput( 216 | loss=loss, per_example_loss=per_example_loss, probs=probs, 217 | preds=preds, labels=labels, 218 | ) 219 | 220 | def _get_fake_data(self, inputs, mlm_logits): 221 | """Sample from the generator to create corrupted input.""" 222 | inputs = pretrain_helpers.unmask(inputs) 223 | disallow = tf.one_hot( 224 | inputs.masked_lm_ids, depth=self._bert_config.vocab_size, 225 | dtype=tf.float32) if self._config.disallow_correct else None 226 | sampled_tokens = tf.stop_gradient(pretrain_helpers.sample_from_softmax( 227 | mlm_logits / self._config.temperature, disallow=disallow)) 228 | sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32) 229 | updated_input_ids, masked = pretrain_helpers.scatter_update( 230 | inputs.input_ids, sampled_tokids, inputs.masked_lm_positions) 231 | labels = masked * (1 - tf.cast( 232 | tf.equal(updated_input_ids, inputs.input_ids), tf.int32)) 233 | updated_inputs = pretrain_data.get_updated_inputs( 234 | inputs, input_ids=updated_input_ids) 235 | FakedData = collections.namedtuple("FakedData", [ 236 | "inputs", "is_fake_tokens", "sampled_tokens"]) 237 | return FakedData(inputs=updated_inputs, is_fake_tokens=labels, 238 | sampled_tokens=sampled_tokens) 239 | 240 | def _build_transformer(self, inputs: pretrain_data.Inputs, is_training, 241 | bert_config=None, name="electra", reuse=False, embedding_file=None, **kwargs): 242 | """Build a transformer encoder network.""" 243 | if bert_config is None: 244 | bert_config = self._bert_config 245 | with tf.variable_scope(tf.get_variable_scope(), 
reuse=reuse): 246 | return modeling.BertModel( 247 | bert_config=bert_config, 248 | is_training=is_training, 249 | input_ids=inputs.input_ids, 250 | input_mask=inputs.input_mask, 251 | token_type_ids=inputs.segment_ids, 252 | use_one_hot_embeddings=self._config.use_tpu, 253 | scope=name, 254 | embedding_file=embedding_file, 255 | **kwargs) 256 | 257 | 258 | def get_generator_config(config: configure_pretraining.PretrainingConfig, 259 | bert_config: modeling.BertConfig): 260 | """Get model config for the generator network.""" 261 | gen_config = modeling.BertConfig.from_dict(bert_config.to_dict()) 262 | gen_config.hidden_size = int(round( 263 | bert_config.hidden_size * config.generator_hidden_size)) 264 | gen_config.num_hidden_layers = int(round( 265 | bert_config.num_hidden_layers * config.generator_layers)) 266 | gen_config.intermediate_size = 4 * gen_config.hidden_size 267 | gen_config.num_attention_heads = max(1, gen_config.hidden_size // 64) 268 | return gen_config 269 | 270 | 271 | def model_fn_builder(config: configure_pretraining.PretrainingConfig): 272 | """Build the model for training.""" 273 | 274 | def model_fn(features, labels, mode, params): 275 | """Build the model for training.""" 276 | model = PretrainingModel(config, features, 277 | mode == tf.estimator.ModeKeys.TRAIN) 278 | utils.log("Model is built!") 279 | if mode == tf.estimator.ModeKeys.TRAIN: 280 | tvars = tf.trainable_variables() 281 | # for t in tvars: 282 | # print(t) 283 | initialized_variable_names = {} 284 | init_checkpoint = config.init_checkpoint 285 | if init_checkpoint: 286 | (assignment_map, 287 | initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint,update_vocab=config.embedding_file is not None) 288 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 289 | utils.log("**** Trainable Variables ****") 290 | 291 | for var in tvars: 292 | init_string = "" 293 | if var.name in initialized_variable_names: 294 | init_string = ", *INIT_FROM_CKPT*" 295 | utils.log(" name = %s, shape = %s%s"% ( var.name, var.shape, 296 | init_string)) 297 | train_op = optimization.create_optimizer( 298 | model.total_loss, config.learning_rate, config.num_train_steps, 299 | weight_decay_rate=config.weight_decay_rate, 300 | use_tpu=config.use_tpu, 301 | warmup_steps=config.num_warmup_steps, 302 | lr_decay_power=config.lr_decay_power, 303 | amp=config.use_amp, 304 | accumulation_step=config.accumulation_step 305 | ) 306 | # output_spec = tf.estimator.tpu.TPUEstimatorSpec( 307 | output_spec = tf.estimator.EstimatorSpec( 308 | mode=mode, 309 | loss=model.total_loss, 310 | train_op=train_op, 311 | training_hooks=[training_utils.ETAHook( 312 | {} if config.use_tpu else dict(loss=model.total_loss), 313 | config.num_train_steps, config.iterations_per_loop, 314 | config.use_tpu,model_name=config.model_name)] 315 | ) 316 | elif mode == tf.estimator.ModeKeys.EVAL: 317 | output_spec = tf.estimator.tpu.TPUEstimatorSpec( 318 | mode=mode, 319 | loss=model.total_loss, 320 | eval_metrics=model.eval_metrics, 321 | evaluation_hooks=[training_utils.ETAHook( 322 | {} if config.use_tpu else dict(loss=model.total_loss), 323 | config.num_eval_steps, config.iterations_per_loop, 324 | config.use_tpu, is_training=False)]) 325 | else: 326 | raise ValueError("Only TRAIN and EVAL modes are supported") 327 | return output_spec 328 | 329 | return model_fn 330 | 331 | 332 | def train_or_eval(config: configure_pretraining.PretrainingConfig): 333 | """Run pre-training or evaluate the pre-trained model.""" 334 | 
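  # NOTE: this fork drives pretraining with a plain tf.estimator.Estimator and a
  # GPU session config (allow_growth / allow_soft_placement); the upstream
  # TPUEstimator / TPUConfig path is kept below only as commented-out code.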
if config.do_train == config.do_eval: 335 | raise ValueError("Exactly one of `do_train` or `do_eval` must be True.") 336 | if config.debug: 337 | utils.rmkdir(config.model_dir) 338 | utils.heading("Config:") 339 | utils.log_config(config) 340 | 341 | is_per_host = tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2 342 | tpu_cluster_resolver = None 343 | tf_config = tf.ConfigProto() 344 | tf_config.gpu_options.allow_growth = True 345 | tf_config.allow_soft_placement = True 346 | 347 | if config.use_tpu and config.tpu_name: 348 | tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver( 349 | config.tpu_name, zone=config.tpu_zone, project=config.gcp_project) 350 | 351 | run_config = tf.estimator.RunConfig( 352 | model_dir=config.model_dir, 353 | session_config=tf_config, 354 | save_checkpoints_steps=config.save_checkpoints_steps, 355 | keep_checkpoint_max=1) 356 | 357 | # tpu_config = tf.estimator.tpu.TPUConfig( 358 | # iterations_per_loop=config.iterations_per_loop, 359 | # num_shards=(config.num_tpu_cores if config.do_train else 360 | # config.num_tpu_cores), 361 | # tpu_job_name=config.tpu_job_name, 362 | # per_host_input_for_training=is_per_host) 363 | # run_config = tf.estimator.tpu.RunConfig( 364 | # cluster=tpu_cluster_resolver, 365 | # model_dir=config.model_dir, 366 | # save_checkpoints_steps=config.save_checkpoints_steps, 367 | # keep_checkpoint_max=1, 368 | # tpu_config=tpu_config) 369 | model_fn = model_fn_builder(config=config) 370 | 371 | estimator = tf.estimator.Estimator( 372 | model_fn=model_fn, 373 | config=run_config) 374 | # estimator = tf.estimator.tpu.TPUEstimator( 375 | # use_tpu=config.use_tpu, 376 | # model_fn=model_fn, 377 | # config=run_config, 378 | # train_batch_size=config.train_batch_size, 379 | # eval_batch_size=config.eval_batch_size) 380 | 381 | if config.do_train: 382 | utils.heading("Running training") 383 | estimator.train(input_fn=pretrain_data.get_input_fn(config, True), 384 | max_steps=config.num_train_steps) 385 | if config.do_eval: 386 | utils.heading("Running evaluation") 387 | result = estimator.evaluate( 388 | input_fn=pretrain_data.get_input_fn(config, False), 389 | steps=config.num_eval_steps) 390 | for key in sorted(result.keys()): 391 | utils.log(" {:} = {:}".format(key, str(result[key]))) 392 | return result 393 | 394 | 395 | def train_one_step(config: configure_pretraining.PretrainingConfig): 396 | """Builds an ELECTRA model an trains it for one step; useful for debugging.""" 397 | train_input_fn = pretrain_data.get_input_fn(config, True) 398 | features = tf.data.make_one_shot_iterator(train_input_fn(dict( 399 | batch_size=config.train_batch_size))).get_next() 400 | model = PretrainingModel(config, features, True) 401 | with tf.Session() as sess: 402 | sess.run(tf.global_variables_initializer()) 403 | utils.log(sess.run(model.total_loss)) 404 | 405 | 406 | def main(): 407 | parser = argparse.ArgumentParser(description=__doc__) 408 | parser.add_argument("--data-dir", required=True, 409 | help="Location of data files (model weights, etc).") 410 | parser.add_argument("--model-name", required=True, 411 | help="The name of the model being fine-tuned.") 412 | parser.add_argument("--hparams", default="{}", 413 | help="JSON dict of model hyperparameters.") 414 | args = parser.parse_args() 415 | if args.hparams.endswith(".json"): 416 | hparams = utils.load_json(args.hparams) 417 | else: 418 | hparams = json.loads(args.hparams) 419 | tf.logging.set_verbosity(tf.logging.ERROR) 420 | train_or_eval(configure_pretraining.PretrainingConfig( 
421 | args.model_name, args.data_dir, **hparams)) 422 | 423 | 424 | if __name__ == "__main__": 425 | main() 426 | -------------------------------------------------------------------------------- /code/electra-pretrain/util/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. -------------------------------------------------------------------------------- /code/electra-pretrain/util/training_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Utilities for training the models.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import datetime 23 | import re 24 | import time 25 | import tensorflow.compat.v1 as tf 26 | 27 | from model import modeling 28 | from util import utils 29 | 30 | 31 | class ETAHook(tf.estimator.SessionRunHook): 32 | """Print out the time remaining during training/evaluation.""" 33 | 34 | def __init__(self, to_log, n_steps, iterations_per_loop, on_tpu, 35 | log_every=1, is_training=True, model_name='base'): 36 | self._to_log = to_log 37 | self._n_steps = n_steps 38 | self._iterations_per_loop = iterations_per_loop 39 | self._on_tpu = on_tpu 40 | self._log_every = log_every 41 | self._is_training = is_training 42 | self._steps_run_so_far = 0 43 | self._global_step = None 44 | self._global_step_tensor = None 45 | self._start_step = None 46 | self._start_time = None 47 | self.log_file = open('/tmp/{}_pretrain.log'.format(model_name),'w') 48 | 49 | def begin(self): 50 | self._global_step_tensor = tf.train.get_or_create_global_step() 51 | 52 | def before_run(self, run_context): 53 | if self._start_time is None: 54 | self._start_time = time.time() 55 | return tf.estimator.SessionRunArgs(self._to_log) 56 | 57 | def after_run(self, run_context, run_values): 58 | self._global_step = run_context.session.run(self._global_step_tensor) 59 | self._steps_run_so_far += self._iterations_per_loop if self._on_tpu else 1 60 | if self._start_step is None: 61 | self._start_step = self._global_step - (self._iterations_per_loop 62 | if self._on_tpu else 1) 63 | self.log(run_values) 64 | 65 | def end(self, session): 66 | self._global_step = session.run(self._global_step_tensor) 67 | self.log() 68 | self.log_file.close() 69 | 70 | def log(self, run_values=None): 71 | step = self._global_step if self._is_training else self._steps_run_so_far 72 | if step % self._log_every != 0: 73 | return 74 | msg = "{:}/{:} = {:.1f}%".format(step, self._n_steps, 75 | 100.0 * step / self._n_steps) 76 | time_elapsed = time.time() - self._start_time 77 | time_per_step = time_elapsed / ( 78 | (step - self._start_step) if self._is_training else step) 79 | msg += ", SPS: {:.1f}".format(1 / time_per_step) 80 | msg += ", ELAP: " + secs_to_str(time_elapsed) 81 | msg += ", ETA: " + secs_to_str( 82 | (self._n_steps - step) * time_per_step) 83 | if run_values is not None: 84 | for tag, value in run_values.results.items(): 85 | msg += " - " + str(tag) + (": {:.4f}".format(value)) 86 | utils.log(msg) 87 | self.log_file.write(msg + '\n') 88 | self.log_file.flush() 89 | 90 | 91 | def secs_to_str(secs): 92 | s = str(datetime.timedelta(seconds=int(round(secs)))) 93 | s = re.sub("^0:", "", s) 94 | s = re.sub("^0", "", s) 95 | s = re.sub("^0:", "", s) 96 | s = re.sub("^0", "", s) 97 | return s 98 | 99 | 100 | def get_bert_config(config): 101 | """Get model hyperparameters based on a pretraining/finetuning config""" 102 | if config.model_size == "large": 103 | args = {"hidden_size": 1024, "num_hidden_layers": 24} 104 | elif config.model_size == "base": 105 | args = {"hidden_size": 768, "num_hidden_layers": 12} 106 | elif config.model_size == "small": 107 | args = {"hidden_size": 256, "num_hidden_layers": 24} 108 | else: 109 | raise ValueError("Unknown model size", config.model_size) 110 | args["vocab_size"] = config.vocab_size 111 | args.update(**config.model_hparam_overrides) 112 | # by default the ff size and num attn heads are determined by the hidden size 113 | 
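  # e.g. "base" (hidden 768) -> 12 attention heads and intermediate size 3072;
  # "large" (hidden 1024) -> 16 heads and intermediate size 4096. The second
  # model_hparam_overrides update below means explicit overrides still win over
  # these derived values.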
args["num_attention_heads"] = max(1, args["hidden_size"] // 64) 114 | args["intermediate_size"] = 4 * args["hidden_size"] 115 | args.update(**config.model_hparam_overrides) 116 | return modeling.BertConfig.from_dict(args) 117 | -------------------------------------------------------------------------------- /code/electra-pretrain/util/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A collection of general utility functions.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import json 23 | import pickle 24 | import sys 25 | 26 | import tensorflow.compat.v1 as tf 27 | 28 | 29 | def load_json(path): 30 | with tf.io.gfile.GFile(path, "r") as f: 31 | return json.load(f) 32 | 33 | 34 | def write_json(o, path): 35 | if "/" in path: 36 | tf.io.gfile.makedirs(path.rsplit("/", 1)[0]) 37 | with tf.io.gfile.GFile(path, "w") as f: 38 | json.dump(o, f) 39 | 40 | 41 | def load_pickle(path): 42 | with tf.io.gfile.GFile(path, "rb") as f: 43 | return pickle.load(f) 44 | 45 | 46 | def write_pickle(o, path): 47 | if "/" in path: 48 | tf.io.gfile.makedirs(path.rsplit("/", 1)[0]) 49 | with tf.io.gfile.GFile(path, "wb") as f: 50 | pickle.dump(o, f, -1) 51 | 52 | 53 | def mkdir(path): 54 | if not tf.io.gfile.exists(path): 55 | tf.io.gfile.makedirs(path) 56 | 57 | 58 | def rmrf(path): 59 | if tf.io.gfile.exists(path): 60 | tf.io.gfile.rmtree(path) 61 | 62 | 63 | def rmkdir(path): 64 | rmrf(path) 65 | mkdir(path) 66 | 67 | 68 | def log(*args): 69 | msg = " ".join(map(str, args)) 70 | sys.stdout.write(msg + "\n") 71 | sys.stdout.flush() 72 | 73 | 74 | def log_config(config): 75 | for key, value in sorted(config.__dict__.items()): 76 | log(key, value) 77 | log() 78 | 79 | 80 | def heading(*args): 81 | log(80 * "=") 82 | log(*args) 83 | log(80 * "=") 84 | 85 | 86 | def nest_dict(d, prefixes, delim="_"): 87 | """Go from {prefix_key: value} to {prefix: {key: value}}.""" 88 | nested = {} 89 | for k, v in d.items(): 90 | for prefix in prefixes: 91 | if k.startswith(prefix + delim): 92 | if prefix not in nested: 93 | nested[prefix] = {} 94 | nested[prefix][k.split(delim, 1)[1]] = v 95 | else: 96 | nested[k] = v 97 | return nested 98 | 99 | 100 | def flatten_dict(d, delim="_"): 101 | """Go from {prefix: {key: value}} to {prefix_key: value}.""" 102 | flattened = {} 103 | for k, v in d.items(): 104 | if isinstance(v, dict): 105 | for k2, v2 in v.items(): 106 | flattened[k + delim + k2] = v2 107 | else: 108 | flattened[k] = v 109 | return flattened 110 | -------------------------------------------------------------------------------- /code/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, hvd=None, amp=False, accumulation_step=1,freeze_bert=False, head_lr_ratio=1.0): 25 | """Creates an optimizer training op. 26 | 27 | Args: 28 | loss: training loss 29 | init_lr: initial learning rate 30 | num_train_steps: total training steps 31 | num_warmup_steps: warmup steps 32 | hvd: whether use hvd for distribute training 33 | amp: whether use auto-mix-precision to speed up training 34 | accumulation_step: gradient accumulation steps 35 | freeze_bert: whether to freeze bert variables 36 | head_lr_ratio: bert and head should have different learning rate 37 | """ 38 | global_step = tf.train.get_or_create_global_step() 39 | 40 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 41 | 42 | # Implements linear decay of the learning rate. 43 | learning_rate = tf.train.polynomial_decay( 44 | learning_rate, 45 | global_step, 46 | num_train_steps, 47 | end_learning_rate=0.0,#if not use_swa else init_lr/2, 48 | power=1.0, 49 | cycle=False) 50 | 51 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 52 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 53 | if num_warmup_steps: 54 | global_steps_int = tf.cast(global_step, tf.int32) 55 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 56 | 57 | global_steps_float = tf.cast(global_steps_int, tf.float32) 58 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 59 | 60 | warmup_percent_done = global_steps_float / warmup_steps_float 61 | warmup_learning_rate = init_lr * warmup_percent_done 62 | 63 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 64 | learning_rate = ( 65 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 66 | 67 | # It is recommended that you use this optimizer for fine tuning, since this 68 | # is how the model was trained (note that the Adam m/v variables are NOT 69 | # loaded from init_checkpoint.) 
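  # The decay here is decoupled (AdamW-style): AdamWeightDecayOptimizer adds
  # weight_decay_rate * param to the Adam update itself instead of putting an
  # L2 term in the loss, and LayerNorm / bias variables are excluded. A
  # head_lr_ratio > 1.0 additionally scales the learning rate of the non-BERT
  # head variables (see apply_gradients below).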
70 | optimizer = AdamWeightDecayOptimizer( 71 | learning_rate=learning_rate, 72 | head_lr_ratio=head_lr_ratio, 73 | weight_decay_rate=0.01, 74 | beta_1=0.9, 75 | beta_2=0.999, 76 | epsilon=1e-6, 77 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 78 | 79 | if hvd is not None: 80 | from horovod.tensorflow.compression import Compression 81 | optimizer = hvd.DistributedOptimizer(optimizer, sparse_as_dense = True, compression=Compression.fp16 if amp else Compression.none) 82 | 83 | if amp: 84 | loss_scaler = tf.train.experimental.DynamicLossScale(initial_loss_scale=2**32, increment_period=1000, multiplier=2.0) 85 | optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer, loss_scaler) 86 | loss_scale_value = tf.identity(loss_scaler(), name="loss_scale") 87 | 88 | tvars = tf.trainable_variables() 89 | if freeze_bert: 90 | tvars = [var for var in tvars if 'bert' not in var.name] 91 | grads_and_vars = optimizer.compute_gradients(loss, tvars) 92 | 93 | if accumulation_step > 1: 94 | tf.logging.info('### Using Gradient Accumulation with {} ###'.format(accumulation_step)) 95 | 96 | local_step = tf.get_variable(name="local_step", shape=[], dtype=tf.int32, trainable=False, 97 | initializer=tf.zeros_initializer) 98 | batch_finite = tf.get_variable(name="batch_finite", shape=[], dtype=tf.bool, trainable=False, 99 | initializer=tf.ones_initializer) 100 | accum_vars = [tf.get_variable( 101 | name=tvar.name.split(":")[0] + "/accum", 102 | shape=tvar.shape.as_list(), 103 | dtype=tf.float32, 104 | trainable=False, 105 | initializer=tf.zeros_initializer()) for tvar in tf.trainable_variables()] 106 | 107 | reset_step = tf.cast(tf.math.equal(local_step % accumulation_step, 0), dtype=tf.bool) 108 | local_step = tf.cond(reset_step, lambda:local_step.assign(tf.ones_like(local_step)), lambda:local_step.assign_add(1)) 109 | 110 | grads_and_vars_and_accums = [(gv[0],gv[1],accum_vars[i]) for i, gv in enumerate(grads_and_vars) if gv[0] is not None] 111 | grads, tvars, accum_vars = list(zip(*grads_and_vars_and_accums)) 112 | 113 | all_are_finite = tf.reduce_all([tf.reduce_all(tf.is_finite(g)) for g in grads]) if amp else tf.constant(True, dtype=tf.bool) 114 | batch_finite = tf.cond(reset_step, 115 | lambda: batch_finite.assign(tf.math.logical_and(tf.constant(True, dtype=tf.bool), all_are_finite)), 116 | lambda: batch_finite.assign(tf.math.logical_and(batch_finite, all_are_finite))) 117 | 118 | # This is how the model was pre-trained. 119 | # ensure global norm is a finite number 120 | # to prevent clip_by_global_norm from having a hizzy fit. 
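        # Gradient accumulation: each micro-batch's gradients are clipped to
        # global norm 1.0 and accumulated into accum_vars (reset at the start of
        # each window); apply_gradients runs only once every `accumulation_step`
        # micro-batches, and the global step advances only on those update steps
        # when all gradients were finite (which matters under AMP loss scaling).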
121 | (clipped_grads, _) = tf.clip_by_global_norm( 122 | grads, clip_norm=1.0, 123 | use_norm=tf.cond( 124 | all_are_finite, 125 | lambda: tf.global_norm(grads), 126 | lambda: tf.constant(1.0))) 127 | 128 | accum_vars = tf.cond(reset_step, 129 | lambda: [accum_vars[i].assign(grad) for i, grad in enumerate(clipped_grads)], 130 | lambda: [accum_vars[i].assign_add(grad) for i, grad in enumerate(clipped_grads)]) 131 | 132 | def update(accum_vars): 133 | return optimizer.apply_gradients(list(zip(accum_vars, tvars))) 134 | 135 | update_step = tf.identity(tf.cast(tf.math.equal(local_step % accumulation_step, 0), dtype=tf.bool), name="update_step") 136 | update_op = tf.cond(update_step, 137 | lambda: update(accum_vars), lambda: tf.no_op()) 138 | 139 | new_global_step = tf.cond(tf.math.logical_and(update_step, 140 | tf.cast(hvd.allreduce(tf.cast(batch_finite, tf.int32)), tf.bool) if hvd is not None else batch_finite), 141 | lambda: global_step+1, 142 | lambda: global_step) 143 | new_global_step = tf.identity(new_global_step, name='step_update') 144 | train_op = tf.group(update_op, [global_step.assign(new_global_step)]) 145 | else: 146 | grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None] 147 | grads, tvars = list(zip(*grads_and_vars)) 148 | all_are_finite = tf.reduce_all( 149 | [tf.reduce_all(tf.is_finite(g)) for g in grads]) if amp else tf.constant(True, dtype=tf.bool) 150 | 151 | # This is how the model was pre-trained. 152 | # ensure global norm is a finite number 153 | # to prevent clip_by_global_norm from having a hizzy fit. 154 | (clipped_grads, _) = tf.clip_by_global_norm( 155 | grads, clip_norm=1.0, 156 | use_norm=tf.cond( 157 | all_are_finite, 158 | lambda: tf.global_norm(grads), 159 | lambda: tf.constant(1.0))) 160 | 161 | train_op = optimizer.apply_gradients( 162 | list(zip(clipped_grads, tvars))) 163 | 164 | new_global_step = tf.cond(all_are_finite, lambda: global_step + 1, lambda: global_step) 165 | new_global_step = tf.identity(new_global_step, name='step_update') 166 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 167 | return train_op, learning_rate 168 | 169 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 170 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 171 | 172 | def __init__(self, 173 | learning_rate, 174 | head_lr_ratio=1.0, 175 | weight_decay_rate=0.0, 176 | beta_1=0.9, 177 | beta_2=0.999, 178 | epsilon=1e-6, 179 | exclude_from_weight_decay=None, 180 | name="AdamWeightDecayOptimizer"): 181 | """Constructs a AdamWeightDecayOptimizer.""" 182 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 183 | 184 | self.learning_rate = learning_rate 185 | self.weight_decay_rate = weight_decay_rate 186 | self.beta_1 = beta_1 187 | self.beta_2 = beta_2 188 | self.epsilon = epsilon 189 | self.exclude_from_weight_decay = exclude_from_weight_decay 190 | self.head_lr_ratio = head_lr_ratio 191 | 192 | def _apply_gradients(self, grads_and_vars, learning_rate): 193 | """See base class.""" 194 | assignments = [] 195 | for (grad, param) in grads_and_vars: 196 | if grad is None or param is None: 197 | continue 198 | 199 | param_name = self._get_variable_name(param.name) 200 | 201 | m = tf.get_variable( 202 | name=param_name + "/adam_m", 203 | shape=param.shape.as_list(), 204 | dtype=tf.float32, 205 | trainable=False, 206 | initializer=tf.zeros_initializer()) 207 | v = tf.get_variable( 208 | name=param_name + "/adam_v", 209 | shape=param.shape.as_list(), 210 | dtype=tf.float32, 211 | trainable=False, 212 | 
initializer=tf.zeros_initializer()) 213 | 214 | # Standard Adam update. 215 | next_m = ( 216 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 217 | next_v = ( 218 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 219 | tf.square(grad))) 220 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 221 | 222 | # Just adding the square of the weights to the loss function is *not* 223 | # the correct way of using L2 regularization/weight decay with Adam, 224 | # since that will interact with the m and v parameters in strange ways. 225 | # 226 | # Instead we want ot decay the weights in a manner that doesn't interact 227 | # with the m/v parameters. This is equivalent to adding the square 228 | # of the weights to the loss with plain (non-momentum) SGD. 229 | if self.weight_decay_rate > 0: 230 | if self._do_use_weight_decay(param_name): 231 | update += self.weight_decay_rate * param 232 | 233 | update_with_lr = learning_rate * update 234 | next_param = param - update_with_lr 235 | 236 | assignments.extend( 237 | [param.assign(next_param), 238 | m.assign(next_m), 239 | v.assign(next_v)]) 240 | 241 | return assignments 242 | 243 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 244 | """See base class.""" 245 | if self.head_lr_ratio > 1.0: 246 | def is_backbone(n): 247 | return 'bert' in n 248 | assignments = [] 249 | backbone_gvs = [] 250 | head_gvs = [] 251 | for grad,var in grads_and_vars: 252 | if is_backbone(var.name): 253 | backbone_gvs.append((grad,var)) 254 | else: 255 | head_gvs.append((grad,var)) 256 | assignments += self._apply_gradients(backbone_gvs,self.learning_rate) 257 | assignments += self._apply_gradients(head_gvs,self.learning_rate * self.head_lr_ratio) 258 | else: 259 | assignments = self._apply_gradients(grads_and_vars, self.learning_rate) 260 | return tf.group(*assignments,name=name) 261 | 262 | def _do_use_weight_decay(self, param_name): 263 | """Whether to use L2 weight decay for `param_name`.""" 264 | if not self.weight_decay_rate: 265 | return False 266 | if self.exclude_from_weight_decay: 267 | for r in self.exclude_from_weight_decay: 268 | if re.search(r, param_name) is not None: 269 | return False 270 | return True 271 | 272 | def _get_variable_name(self, param_name): 273 | """Get the variable name from the tensor name.""" 274 | m = re.match("^(.*):\\d+$", param_name) 275 | if m is not None: 276 | param_name = m.group(1) 277 | return param_name 278 | -------------------------------------------------------------------------------- /code/pipeline.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import time 3 | import os 4 | import logging 5 | import copy 6 | 7 | import threading 8 | from multiprocessing import Process 9 | import multiprocessing as mp 10 | 11 | logging.basicConfig() 12 | logger = logging.getLogger('pipeline') 13 | fh = logging.FileHandler('/tmp/pipeline.log') 14 | fh.setLevel(logging.INFO) 15 | ch = logging.StreamHandler() 16 | ch.setLevel(logging.INFO) 17 | logger.setLevel(logging.INFO) 18 | logger.addHandler(fh) 19 | logger.addHandler(ch) 20 | 21 | class Timer(object): 22 | def __init__(self): 23 | self.start_time = time.time() 24 | 25 | def get_current_time(self): 26 | return (time.time() - self.start_time) / 3600 27 | 28 | def train_model(args,cmd): 29 | logger.info('start to train model {}: {}'.format(args.get('OUTPUT_DIR',''),timer.get_current_time())) 30 | os.system(cmd.format(**args)) 31 | logger.info('finish train model {}: 
{}'.format(args.get('OUTPUT_DIR',''),timer.get_current_time())) 32 | 33 | timer = Timer() 34 | 35 | # 数据准备 36 | logger.info(f'data prepare start: {timer.get_current_time()} ') 37 | 38 | prepare = subprocess.Popen( 39 | "bash prepare.sh", shell=True 40 | ) 41 | prepare.wait() 42 | logger.info(f'data prepare finished: {timer.get_current_time()}') 43 | 44 | # 预训练 45 | logger.info(f'electra pretrain start: {timer.get_current_time()} ') 46 | 47 | pretrain = subprocess.Popen( 48 | "bash pretrain.sh", shell=True 49 | ) 50 | pretrain.wait() 51 | 52 | logger.info(f'electra pretrain finished: {timer.get_current_time()}') 53 | 54 | 55 | # extra pretrain 56 | basic_base_args = { 57 | "BERT_DIR":"../user_data/electra/electra_180g_base", 58 | "CONFIG_FILE":"../user_data/electra/electra_180g_base/base_discriminator_config.json", 59 | "INIT_CHECKPOINT":"../user_data/models/base/model.ckpt-7000", 60 | "DATA_DIR":'../user_data/tcdata', 61 | "SEED":20190525, 62 | "EMBEDDING_DROPOUT":0.1, 63 | "OUTPUT_DIR":"../user_data/models/bif_extra_enhance_electra_base_pretrain" 64 | } 65 | 66 | basic_large_args = { 67 | "BERT_DIR":"../user_data/electra/electra_180g_large", 68 | "CONFIG_FILE":"../user_data/electra/electra_180g_large/large_discriminator_config.json", 69 | "INIT_CHECKPOINT":"../user_data/models/large/model.ckpt-5000", 70 | "DATA_DIR":'../user_data/tcdata', 71 | "SEED":807, 72 | "EMBEDDING_DROPOUT":0.2, 73 | "OUTPUT_DIR":"../user_data/models/bif_extra_enhance_electra_large_pretrain" 74 | } 75 | 76 | bif_extra_cmd = ''' 77 | python run_biaffine_ner.py \ 78 | --task_name=ner \ 79 | --vocab_file={BERT_DIR}/vocab.txt \ 80 | --bert_config_file={CONFIG_FILE} \ 81 | --init_checkpoint={INIT_CHECKPOINT} \ 82 | --do_lower_case=True \ 83 | --max_seq_length=64 \ 84 | --train_batch_size=32 \ 85 | --learning_rate=3e-5 \ 86 | --num_train_epochs=1.0 \ 87 | --neg_sample=1.0 \ 88 | --save_checkpoints_steps=450 \ 89 | --do_train_and_eval=true \ 90 | --do_train=false \ 91 | --do_eval=false \ 92 | --do_predict=false \ 93 | --use_fgm=true \ 94 | --fgm_epsilon=0.8 \ 95 | --fgm_loss_ratio=1.0 \ 96 | --spatial_dropout=0.3 \ 97 | --embedding_dropout={EMBEDDING_DROPOUT} \ 98 | --head_lr_ratio=1.0 \ 99 | --pooling_type=last \ 100 | --extra_pretrain=true \ 101 | --enhance_data=true \ 102 | --electra=true \ 103 | --dp_decode=true \ 104 | --amp=true \ 105 | --seed={SEED} \ 106 | --data_dir={DATA_DIR} \ 107 | --output_dir={OUTPUT_DIR} 108 | ''' 109 | 110 | pool = mp.Pool(processes = 2) 111 | 112 | # base bif 113 | pool.apply_async(train_model,(basic_base_args, bif_extra_cmd)) 114 | 115 | # large bif 116 | pool.apply_async(train_model,(basic_large_args, bif_extra_cmd)) 117 | 118 | 119 | pool.close() 120 | pool.join() 121 | 122 | 123 | # finetune 124 | bif_finetune_cmd = ''' 125 | python run_biaffine_ner.py \ 126 | --task_name=ner \ 127 | --vocab_file={BERT_DIR}/vocab.txt \ 128 | --bert_config_file={CONFIG_FILE} \ 129 | --init_checkpoint={INIT_CHECKPOINT} \ 130 | --do_lower_case=True \ 131 | --max_seq_length=64 \ 132 | --train_batch_size=16 \ 133 | --learning_rate=2e-5 \ 134 | --num_train_epochs=5.0 \ 135 | --neg_sample=1.0 \ 136 | --save_checkpoints_steps=276 \ 137 | --do_train_and_eval=true \ 138 | --do_train=false \ 139 | --do_eval=false \ 140 | --do_predict=false \ 141 | --use_fgm=true \ 142 | --fgm_epsilon=0.8 \ 143 | --fgm_loss_ratio=1.0 \ 144 | --spatial_dropout={SPATIAL_DROPOUT} \ 145 | --embedding_dropout={EMBEDDING_DROPOUT} \ 146 | --head_lr_ratio=1.0 \ 147 | --pooling_type=last \ 148 | --start_swa_step=0 \ 149 | --swa_steps=100 \ 
150 | --biaffine_size=150 \ 151 | --electra=false \ 152 | --dp_decode=true \ 153 | --amp=true \ 154 | --seed={SEED} \ 155 | --fold_id={FOLD_ID} \ 156 | --fold_num={FOLD_NUM} \ 157 | --data_dir={DATA_DIR} \ 158 | --output_dir={OUTPUT_DIR} 159 | ''' 160 | 161 | basic_base_finetune_args = { 162 | "BERT_DIR":"../user_data/electra/electra_180g_base", 163 | "CONFIG_FILE":"../user_data/electra/electra_180g_base/base_discriminator_config.json", 164 | "INIT_CHECKPOINT":"../user_data/models/bif_extra_enhance_electra_base_pretrain/export/f1_export/model.ckpt", 165 | "DATA_DIR":'../user_data/tcdata', 166 | "OUTPUT_DIR":'', 167 | "SEED": 666, 168 | "SPATIAL_DROPOUT":0.1, 169 | "EMBEDDING_DROPOUT":0.1, 170 | "FOLD_ID": 0, 171 | "FOLD_NUM": 5 172 | } 173 | 174 | basic_large_finetune_args = { 175 | "BERT_DIR":"../user_data/electra/electra_180g_large", 176 | "CONFIG_FILE":"../user_data/electra/electra_180g_large/large_discriminator_config.json", 177 | "INIT_CHECKPOINT":"../user_data/models/bif_extra_enhance_electra_large_pretrain/export/f1_export/model.ckpt", 178 | "DATA_DIR":'../user_data/tcdata', 179 | "OUTPUT_DIR":'', 180 | "SEED": 777, 181 | "SPATIAL_DROPOUT":0.2, 182 | "EMBEDDING_DROPOUT":0.2, 183 | "FOLD_ID": 0, 184 | "FOLD_NUM": 5 185 | } 186 | 187 | 188 | base_outdir_format = "../user_data/models/k-fold/bif_electra_base_pretrain_fold_{}" 189 | large_outdir_format = "../user_data/models/k-fold/bif_electra_large_pretrain_fold_{}" 190 | 191 | pool = mp.Pool(processes = 3) 192 | for i in range(5): 193 | # base 194 | # bif 195 | args = copy.deepcopy(basic_base_finetune_args) 196 | args['FOLD_ID'] = i 197 | args['OUTPUT_DIR'] = base_outdir_format.format(i) 198 | pool.apply_async(train_model,(args, bif_finetune_cmd)) 199 | 200 | pool.close() 201 | pool.join() 202 | 203 | pool = mp.Pool(processes = 1) 204 | for i in range(5): 205 | args = copy.deepcopy(basic_large_finetune_args) 206 | args['FOLD_ID'] = i 207 | args['OUTPUT_DIR'] = large_outdir_format.format(i) 208 | pool.apply_async(train_model,(args, bif_finetune_cmd)) 209 | 210 | pool.close() 211 | pool.join() 212 | 213 | bif_pred_cmd = ''' 214 | python run_biaffine_ner.py \ 215 | --task_name=ner \ 216 | --vocab_file={BERT_DIR}/vocab.txt \ 217 | --bert_config_file={CONFIG_FILE} \ 218 | --do_lower_case=True \ 219 | --max_seq_length=64 \ 220 | --do_predict=true \ 221 | --use_fgm=true \ 222 | --pooling_type=last \ 223 | --biaffine_size=150 \ 224 | --dp_decode=true \ 225 | --fake_data=true \ 226 | --fold_id={FOLD_ID} \ 227 | --fold_num={FOLD_NUM} \ 228 | --data_dir={DATA_DIR} \ 229 | --output_dir={OUTPUT_DIR}/export/f1_export 230 | ''' 231 | 232 | pool = mp.Pool(processes = 4) 233 | for i in range(5): 234 | # # base 235 | args = copy.deepcopy(basic_base_finetune_args) 236 | args['FOLD_ID'] = i 237 | args['OUTPUT_DIR'] = base_outdir_format.format(i) 238 | pool.apply_async(train_model,(args, bif_pred_cmd)) 239 | 240 | args = copy.deepcopy(basic_large_finetune_args) 241 | args['FOLD_ID'] = i 242 | args['OUTPUT_DIR'] = large_outdir_format.format(i) 243 | pool.apply_async(train_model,(args, bif_pred_cmd)) 244 | 245 | pool.close() 246 | pool.join() 247 | 248 | logger.info(f'assemble fake start: {timer.get_current_time()}') 249 | from assemble import assemble_fake 250 | assemble_fake() 251 | logger.info(f'assemble fake finish: {timer.get_current_time()}') 252 | 253 | bif_fake_cmd = ''' 254 | python run_biaffine_ner.py \ 255 | --task_name=ner \ 256 | --vocab_file={BERT_DIR}/vocab.txt \ 257 | --bert_config_file={CONFIG_FILE} \ 258 | 
--init_checkpoint={INIT_CHECKPOINT} \ 259 | --do_lower_case=True \ 260 | --max_seq_length=64 \ 261 | --train_batch_size=32 \ 262 | --learning_rate=2e-5 \ 263 | --num_train_epochs=5.0 \ 264 | --neg_sample=0.15 \ 265 | --save_checkpoints_steps=500 \ 266 | --do_train_and_eval=true \ 267 | --do_train=false \ 268 | --do_eval=false \ 269 | --do_predict=false \ 270 | --use_fgm=true \ 271 | --fgm_epsilon=0.8 \ 272 | --fgm_loss_ratio=1.0 \ 273 | --spatial_dropout=0.1 \ 274 | --embedding_dropout=0.1 \ 275 | --head_lr_ratio=1.0 \ 276 | --pooling_type=last \ 277 | --start_swa_step=0 \ 278 | --swa_steps=100 \ 279 | --biaffine_size=150 \ 280 | --electra=false \ 281 | --dp_decode=true \ 282 | --amp=true \ 283 | --fake_data=true \ 284 | --fold_id={FOLD_ID} \ 285 | --fold_num={FOLD_NUM} \ 286 | --data_dir={DATA_DIR} \ 287 | --output_dir={OUTPUT_DIR} 288 | ''' 289 | 290 | fake_outdir_format = "../user_data/models/k-fold/bif_fake_tags_fold_{}" 291 | 292 | pool = mp.Pool(processes = 3) 293 | for i in range(5): 294 | args = copy.deepcopy(basic_base_finetune_args) 295 | args['FOLD_ID'] = i 296 | args['OUTPUT_DIR'] = fake_outdir_format.format(i) 297 | pool.apply_async(train_model,(args, bif_fake_cmd)) 298 | 299 | pool.close() 300 | pool.join() 301 | 302 | fake_pred_cmd = ''' 303 | python run_biaffine_ner.py \ 304 | --task_name=ner \ 305 | --vocab_file={BERT_DIR}/vocab.txt \ 306 | --bert_config_file={CONFIG_FILE} \ 307 | --do_lower_case=True \ 308 | --max_seq_length=64 \ 309 | --do_predict=true \ 310 | --use_fgm=true \ 311 | --pooling_type=last \ 312 | --biaffine_size=150 \ 313 | --dp_decode=true \ 314 | --fake_data=false \ 315 | --fold_id={FOLD_ID} \ 316 | --fold_num={FOLD_NUM} \ 317 | --data_dir={DATA_DIR} \ 318 | --output_dir={OUTPUT_DIR}/export/f1_export 319 | ''' 320 | 321 | pool = mp.Pool(processes = 5) 322 | for i in range(5): 323 | args = copy.deepcopy(basic_base_finetune_args) 324 | args['FOLD_ID'] = i 325 | args['OUTPUT_DIR'] = fake_outdir_format.format(i) 326 | pool.apply_async(train_model,(args, fake_pred_cmd)) 327 | 328 | pool.close() 329 | pool.join() 330 | 331 | from assemble import assemble_final 332 | assemble_final() -------------------------------------------------------------------------------- /code/prepare.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | mkdir -p ../user_data/tcdata 4 | mkdir -p ../user_data/texts 5 | 6 | cp -r ../tcdata ../user_data 7 | 8 | echo "Data preprocess" 9 | python create_raw_text.py -------------------------------------------------------------------------------- /code/pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | cd electra-pretrain 4 | 5 | bash pretrain.sh 6 | 7 | cd .. 
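# NOTE: this wrapper is intentionally thin -- the actual work (building the
# pretraining tfrecords and launching run_pretraining.py, presumably for both
# the base and large checkpoints that pipeline.py later points at) happens
# inside electra-pretrain/pretrain.sh; simple_run.sh shows the equivalent
# manual steps for the base model.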
-------------------------------------------------------------------------------- /code/run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python pipeline.py -------------------------------------------------------------------------------- /code/simple_run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # 数据预处理 4 | mkdir -p ../user_data/tcdata 5 | mkdir -p ../user_data/texts 6 | 7 | cp -r ../tcdata ../user_data 8 | 9 | echo "Data preprocess" 10 | python create_raw_text.py 11 | 12 | # 预训练 13 | 14 | cd electra-pretrain 15 | 16 | export DATA_DIR=../../user_data 17 | export ELECTRA_DIR=../../user_data/electra 18 | 19 | echo 'Prepare pretraining data...' 20 | python build_pretraining_dataset.py \ 21 | --corpus-dir=${DATA_DIR}/texts \ 22 | --max-seq-length=32 \ 23 | --vocab-file=${ELECTRA_DIR}/electra_180g_base/vocab.txt \ 24 | --output-dir=${DATA_DIR}/pretrain_tfrecords 25 | 26 | echo "Pretrain base electra model ~= 1 hour on Titan RTX 24G" 27 | # 如果只有12G显存,可以将train_batch_size改成32,accumulation_step设置为3,耗时久一点,效果应该差不多 28 | python run_pretraining.py \ 29 | --data-dir=${DATA_DIR} \ 30 | --model-name=base \ 31 | --hparams='{"max_seq_length":32, "accumulation_step": 2, "use_amp": true, "learning_rate": 0.0002,"model_size": "base","eval_batch_size":128,"train_batch_size": 64, "init_checkpoint": "../../user_data/electra/electra_180g_base/electra_180g_base.ckpt", "vocab_file": "../../user_data/electra/electra_180g_base/vocab.txt"}' 32 | 33 | cd .. 34 | 35 | # 利用额外数据训练 36 | 37 | export BERT_DIR=../user_data/electra/electra_180g_base 38 | export CONFIG_FILE=../user_data/electra/electra_180g_base/base_discriminator_config.json 39 | export INIT_CHECKPOINT=../user_data/models/base/model.ckpt-7000 40 | export DATA_DIR=../user_data/tcdata 41 | export SEED=20190525 42 | export EMBEDDING_DROPOUT=0.1 43 | export OUTPUT_DIR=../user_data/models/bif_extra_enhance_electra_base_pretrain 44 | 45 | python run_biaffine_ner.py \ 46 | --task_name=ner \ 47 | --vocab_file=${BERT_DIR}/vocab.txt \ 48 | --bert_config_file=${CONFIG_FILE} \ 49 | --init_checkpoint=${INIT_CHECKPOINT} \ 50 | --do_lower_case=True \ 51 | --max_seq_length=32 \ 52 | --train_batch_size=32 \ 53 | --learning_rate=3e-5 \ 54 | --num_train_epochs=1.0 \ 55 | --neg_sample=1.0 \ 56 | --save_checkpoints_steps=450 \ 57 | --do_train_and_eval=true \ 58 | --do_train=false \ 59 | --do_eval=false \ 60 | --do_predict=false \ 61 | --use_fgm=true \ 62 | --fgm_epsilon=0.8 \ 63 | --fgm_loss_ratio=1.0 \ 64 | --spatial_dropout=0.3 \ 65 | --embedding_dropout=${EMBEDDING_DROPOUT} \ 66 | --head_lr_ratio=1.0 \ 67 | --pooling_type=last \ 68 | --extra_pretrain=true \ 69 | --enhance_data=true \ 70 | --electra=true \ 71 | --dp_decode=true \ 72 | --amp=false \ 73 | --seed=${SEED} \ 74 | --data_dir=${DATA_DIR} \ 75 | --output_dir=${OUTPUT_DIR} 76 | 77 | 78 | # finetune 79 | 80 | export INIT_CHECKPOINT=../user_data/models/bif_extra_enhance_electra_base_pretrain/export/f1_export/model.ckpt 81 | export OUTPUT_DIR=../user_data/models/k-fold/bif_electra_base_pretrain 82 | export SEED=666 83 | export SPATIAL_DROPOUT=0.1 84 | export EMBEDDING_DROPOUT=0.1 85 | 86 | python run_biaffine_ner.py \ 87 | --task_name=ner \ 88 | --vocab_file=${BERT_DIR}/vocab.txt \ 89 | --bert_config_file=${CONFIG_FILE} \ 90 | --init_checkpoint=${INIT_CHECKPOINT} \ 91 | --do_lower_case=True \ 92 | --max_seq_length=32 \ 93 | --train_batch_size=16 \ 94 | --learning_rate=2e-5 \ 95 | 
--num_train_epochs=5.0 \ 96 | --neg_sample=1.0 \ 97 | --save_checkpoints_steps=276 \ 98 | --do_train_and_eval=true \ 99 | --do_train=false \ 100 | --do_eval=false \ 101 | --do_predict=false \ 102 | --use_fgm=true \ 103 | --fgm_epsilon=0.8 \ 104 | --fgm_loss_ratio=1.0 \ 105 | --spatial_dropout=${SPATIAL_DROPOUT} \ 106 | --embedding_dropout=${EMBEDDING_DROPOUT} \ 107 | --head_lr_ratio=1.0 \ 108 | --pooling_type=last \ 109 | --start_swa_step=0 \ 110 | --swa_steps=100 \ 111 | --biaffine_size=150 \ 112 | --electra=false \ 113 | --dp_decode=true \ 114 | --amp=false \ 115 | --seed=${SEED} \ 116 | --data_dir=${DATA_DIR} \ 117 | --output_dir=${OUTPUT_DIR} -------------------------------------------------------------------------------- /code/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | import tensorflow as tf 26 | 27 | 28 | 29 | def convert_to_unicode(text): 30 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 31 | if six.PY3: 32 | if isinstance(text, str): 33 | return text 34 | elif isinstance(text, bytes): 35 | return text.decode("utf-8", "ignore") 36 | else: 37 | raise ValueError("Unsupported string type: %s" % (type(text))) 38 | elif six.PY2: 39 | if isinstance(text, str): 40 | return text.decode("utf-8", "ignore") 41 | elif isinstance(text, unicode): 42 | return text 43 | else: 44 | raise ValueError("Unsupported string type: %s" % (type(text))) 45 | else: 46 | raise ValueError("Not running on Python2 or Python 3?") 47 | 48 | 49 | def printable_text(text): 50 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 51 | 52 | # These functions want `str` for both Python2 and Python3, but in one case 53 | # it's a Unicode string and in the other it's a byte string. 
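  # In practice: under Python 3, bytes are decoded to str with errors ignored;
  # under Python 2, unicode is encoded back to a utf-8 byte string and plain
  # byte strings are returned unchanged.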
54 | if six.PY3: 55 | if isinstance(text, str): 56 | return text 57 | elif isinstance(text, bytes): 58 | return text.decode("utf-8", "ignore") 59 | else: 60 | raise ValueError("Unsupported string type: %s" % (type(text))) 61 | elif six.PY2: 62 | if isinstance(text, str): 63 | return text 64 | elif isinstance(text, unicode): 65 | return text.encode("utf-8") 66 | else: 67 | raise ValueError("Unsupported string type: %s" % (type(text))) 68 | else: 69 | raise ValueError("Not running on Python2 or Python 3?") 70 | 71 | 72 | def load_vocab(vocab_file): 73 | """Loads a vocabulary file into a dictionary.""" 74 | vocab = collections.OrderedDict() 75 | index = 0 76 | with tf.gfile.GFile(vocab_file, "r") as reader: 77 | while True: 78 | token = convert_to_unicode(reader.readline()) 79 | if not token: 80 | break 81 | token = token.strip() 82 | vocab[token] = index 83 | index += 1 84 | return vocab 85 | 86 | 87 | def convert_by_vocab(vocab, items): 88 | """Converts a sequence of [tokens|ids] using the vocab.""" 89 | output = [] 90 | for item in items: 91 | output.append(vocab[item]) 92 | return output 93 | 94 | 95 | def convert_tokens_to_ids(vocab, tokens): 96 | return convert_by_vocab(vocab, tokens) 97 | 98 | 99 | def convert_ids_to_tokens(inv_vocab, ids): 100 | return convert_by_vocab(inv_vocab, ids) 101 | 102 | 103 | def whitespace_tokenize(text): 104 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 105 | text = text.strip() 106 | if not text: 107 | return [] 108 | tokens = text.split() 109 | return tokens 110 | 111 | class SimpleTokenizer(object): 112 | def __init__(self, vocab_file, do_lower_case=True): 113 | self.vocab = load_vocab(vocab_file) 114 | self.do_lower_case = do_lower_case 115 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 116 | 117 | def tokenize(self, text): 118 | if self.do_lower_case: 119 | text = text.lower() 120 | return [token if token in self.vocab else '[UNK]' for token in text.strip()] 121 | 122 | def convert_tokens_to_ids(self, tokens): 123 | return convert_by_vocab(self.vocab, tokens) 124 | 125 | def convert_ids_to_tokens(self, ids): 126 | return convert_by_vocab(self.inv_vocab, ids) 127 | 128 | 129 | class FullTokenizer(object): 130 | """Runs end-to-end tokenziation.""" 131 | 132 | def __init__(self, vocab_file, do_lower_case=True): 133 | self.vocab = load_vocab(vocab_file) 134 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 135 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 136 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 137 | 138 | def tokenize(self, text): 139 | split_tokens = [] 140 | for token in self.basic_tokenizer.tokenize(text): 141 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 142 | split_tokens.append(sub_token) 143 | 144 | return split_tokens 145 | 146 | def convert_tokens_to_ids(self, tokens): 147 | return convert_by_vocab(self.vocab, tokens) 148 | 149 | def convert_ids_to_tokens(self, ids): 150 | return convert_by_vocab(self.inv_vocab, ids) 151 | 152 | 153 | class BasicTokenizer(object): 154 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 155 | 156 | def __init__(self, do_lower_case=True): 157 | """Constructs a BasicTokenizer. 158 | 159 | Args: 160 | do_lower_case: Whether to lower case the input. 
161 | """ 162 | self.do_lower_case = do_lower_case 163 | 164 | def tokenize(self, text): 165 | """Tokenizes a piece of text.""" 166 | text = convert_to_unicode(text) 167 | text = self._clean_text(text) 168 | 169 | # This was added on November 1st, 2018 for the multilingual and Chinese 170 | # models. This is also applied to the English models now, but it doesn't 171 | # matter since the English models were not trained on any Chinese data 172 | # and generally don't have any Chinese data in them (there are Chinese 173 | # characters in the vocabulary because Wikipedia does have some Chinese 174 | # words in the English Wikipedia.). 175 | text = self._tokenize_chinese_chars(text) 176 | 177 | orig_tokens = whitespace_tokenize(text) 178 | split_tokens = [] 179 | for token in orig_tokens: 180 | if self.do_lower_case: 181 | token = token.lower() 182 | token = self._run_strip_accents(token) 183 | split_tokens.extend(self._run_split_on_punc(token)) 184 | 185 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 186 | return output_tokens 187 | 188 | def _run_strip_accents(self, text): 189 | """Strips accents from a piece of text.""" 190 | text = unicodedata.normalize("NFD", text) 191 | output = [] 192 | for char in text: 193 | cat = unicodedata.category(char) 194 | if cat == "Mn": 195 | continue 196 | output.append(char) 197 | return "".join(output) 198 | 199 | def _run_split_on_punc(self, text): 200 | """Splits punctuation on a piece of text.""" 201 | chars = list(text) 202 | i = 0 203 | start_new_word = True 204 | output = [] 205 | while i < len(chars): 206 | char = chars[i] 207 | if _is_punctuation(char): 208 | output.append([char]) 209 | start_new_word = True 210 | else: 211 | if start_new_word: 212 | output.append([]) 213 | start_new_word = False 214 | output[-1].append(char) 215 | i += 1 216 | 217 | return ["".join(x) for x in output] 218 | 219 | def _tokenize_chinese_chars(self, text): 220 | """Adds whitespace around any CJK character.""" 221 | output = [] 222 | for char in text: 223 | cp = ord(char) 224 | if self._is_chinese_char(cp): 225 | output.append(" ") 226 | output.append(char) 227 | output.append(" ") 228 | else: 229 | output.append(char) 230 | return "".join(output) 231 | 232 | def _is_chinese_char(self, cp): 233 | """Checks whether CP is the codepoint of a CJK character.""" 234 | # This defines a "chinese character" as anything in the CJK Unicode block: 235 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 236 | # 237 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 238 | # despite its name. The modern Korean Hangul alphabet is a different block, 239 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 240 | # space-separated words, so they are not treated specially and handled 241 | # like the all of the other languages. 
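    # For example, '中' (U+4E2D) falls inside the 0x4E00-0x9FFF block below, so
    # _tokenize_chinese_chars surrounds it with spaces, while Hiragana 'あ'
    # (U+3042) matches none of these ranges and is handled like any other
    # character.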
242 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 243 | (cp >= 0x3400 and cp <= 0x4DBF) or # 244 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 245 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 246 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 247 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 248 | (cp >= 0xF900 and cp <= 0xFAFF) or # 249 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 250 | return True 251 | 252 | return False 253 | 254 | def _clean_text(self, text): 255 | """Performs invalid character removal and whitespace cleanup on text.""" 256 | output = [] 257 | for char in text: 258 | cp = ord(char) 259 | if cp == 0 or cp == 0xfffd or _is_control(char): 260 | continue 261 | if _is_whitespace(char): 262 | output.append(" ") 263 | else: 264 | output.append(char) 265 | return "".join(output) 266 | 267 | 268 | class WordpieceTokenizer(object): 269 | """Runs WordPiece tokenziation.""" 270 | 271 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 272 | self.vocab = vocab 273 | self.unk_token = unk_token 274 | self.max_input_chars_per_word = max_input_chars_per_word 275 | 276 | def tokenize(self, text): 277 | """Tokenizes a piece of text into its word pieces. 278 | 279 | This uses a greedy longest-match-first algorithm to perform tokenization 280 | using the given vocabulary. 281 | 282 | For example: 283 | input = "unaffable" 284 | output = ["un", "##aff", "##able"] 285 | 286 | Args: 287 | text: A single token or whitespace separated tokens. This should have 288 | already been passed through `BasicTokenizer. 289 | 290 | Returns: 291 | A list of wordpiece tokens. 292 | """ 293 | 294 | text = convert_to_unicode(text) 295 | 296 | output_tokens = [] 297 | for token in whitespace_tokenize(text): 298 | chars = list(token) 299 | if len(chars) > self.max_input_chars_per_word: 300 | output_tokens.append(self.unk_token) 301 | continue 302 | 303 | is_bad = False 304 | start = 0 305 | sub_tokens = [] 306 | while start < len(chars): 307 | end = len(chars) 308 | cur_substr = None 309 | while start < end: 310 | substr = "".join(chars[start:end]) 311 | if start > 0: 312 | substr = "##" + substr 313 | if substr in self.vocab: 314 | cur_substr = substr 315 | break 316 | end -= 1 317 | if cur_substr is None: 318 | is_bad = True 319 | break 320 | sub_tokens.append(cur_substr) 321 | start = end 322 | 323 | if is_bad: 324 | output_tokens.append(self.unk_token) 325 | else: 326 | output_tokens.extend(sub_tokens) 327 | return output_tokens 328 | 329 | 330 | def _is_whitespace(char): 331 | """Checks whether `chars` is a whitespace character.""" 332 | # \t, \n, and \r are technically contorl characters but we treat them 333 | # as whitespace since they are generally considered as such. 334 | if char == " " or char == "\t" or char == "\n" or char == "\r": 335 | return True 336 | cat = unicodedata.category(char) 337 | if cat == "Zs": 338 | return True 339 | return False 340 | 341 | 342 | def _is_control(char): 343 | """Checks whether `chars` is a control character.""" 344 | # These are technically control characters but we count them as whitespace 345 | # characters. 346 | if char == "\t" or char == "\n" or char == "\r": 347 | return False 348 | cat = unicodedata.category(char) 349 | if cat.startswith("C"): 350 | return True 351 | return False 352 | 353 | 354 | def _is_punctuation(char): 355 | """Checks whether `chars` is a punctuation character.""" 356 | cp = ord(char) 357 | # We treat all non-letter/number ASCII as punctuation. 
358 | # Characters such as "^", "$", and "`" are not in the Unicode 359 | # Punctuation class but we treat them as punctuation anyways, for 360 | # consistency. 361 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 362 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 363 | return True 364 | cat = unicodedata.category(char) 365 | if cat.startswith("P"): 366 | return True 367 | return False 368 | -------------------------------------------------------------------------------- /code/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import re 4 | import random 5 | import json 6 | import glob 7 | import codecs 8 | import os 9 | import shutil 10 | from conlleval import return_report 11 | from tqdm import tqdm 12 | 13 | np.random.seed(20190525) 14 | random.seed(20190525) 15 | 16 | def normalize(text): 17 | text = re.sub('[0-9]','0',text) 18 | text = re.sub('[a-zA-Z]','A',text) 19 | return text 20 | 21 | def convert_data_format(sentence): 22 | word = '' 23 | tag = '' 24 | text = '' 25 | tag_words = [] 26 | for i,(c,t) in enumerate(sentence): 27 | text += c 28 | if t[0] in ['B','S','O']: 29 | if word: 30 | tag_words.append((word,i,tag)) 31 | if t[0] == 'O': 32 | word = '' 33 | tag = '' 34 | continue 35 | word = c 36 | tag = t[2:] 37 | else: 38 | word += c 39 | if word: 40 | tag_words.append((word,i+1,tag)) 41 | 42 | entities = {} 43 | for w,i,t in tag_words: 44 | if t not in entities: 45 | entities[t] = {} 46 | if w in entities[t]: 47 | entities[t][w].append([i-len(w),i-1]) 48 | else: 49 | entities[t][w] = [[i-len(w),i-1]] 50 | 51 | return {"text":text,"label":entities} 52 | 53 | def convert_back_to_bio(entities,text): 54 | tags = ['O'] * len(text) 55 | for t,words in entities.items(): 56 | for w,spans in words.items(): 57 | for s,e in spans: 58 | tags[s:e+1] = ['I-' + t] * (e-s+1) 59 | tags[s] = 'B-' + t 60 | return tags 61 | 62 | def iobes_iob(tags): 63 | """ 64 | IOBES -> IOB 65 | """ 66 | new_tags = [] 67 | for i, tag in enumerate(tags): 68 | if tag.split('-')[0] == 'B': 69 | new_tags.append(tag) 70 | elif tag.split('-')[0] == 'I': 71 | new_tags.append(tag) 72 | elif tag.split('-')[0] == 'S': 73 | new_tags.append(tag.replace('S-', 'B-')) 74 | elif tag.split('-')[0] == 'E': 75 | new_tags.append(tag.replace('E-', 'I-')) 76 | elif tag.split('-')[0] == 'O': 77 | new_tags.append(tag) 78 | else: 79 | raise Exception('Invalid format!') 80 | return new_tags 81 | 82 | def iob_iobes(tags): 83 | """ 84 | IOB -> IOBES 85 | """ 86 | new_tags = [] 87 | for i, tag in enumerate(tags): 88 | if tag == 'O': 89 | new_tags.append(tag) 90 | elif tag.split('-')[0] == 'B': 91 | if i + 1 != len(tags) and \ 92 | tags[i + 1].split('-')[0] == 'I': 93 | new_tags.append(tag) 94 | else: 95 | new_tags.append(tag.replace('B-', 'S-')) 96 | elif tag.split('-')[0] == 'I': 97 | if i + 1 < len(tags) and \ 98 | tags[i + 1].split('-')[0] == 'I': 99 | new_tags.append(tag) 100 | else: 101 | new_tags.append(tag.replace('I-', 'E-')) 102 | else: 103 | raise Exception('Invalid IOB format!') 104 | return new_tags 105 | 106 | def read_data(fnames, zeros=False, lower=False): 107 | ''' 108 | Read all data into memory and convert to iobes tags. 109 | A line must contain at least a word and its tag. 110 | Sentences are separated by empty lines. 
111 | Args: 112 | - fnames: a list of filenames contain the data 113 | - zeros: if we need to replace digits to 0s 114 | Return: 115 | - sentences: a list of sentnences, each sentence contains a list of (word,tag) pairs 116 | ''' 117 | sentences = [] 118 | sentence = [] 119 | if not isinstance(fnames, list): 120 | fnames = [fnames] 121 | for fname in fnames: 122 | sentence_num = 0 123 | num = 0 124 | print("read data from file {0}".format(fname)) 125 | for line in codecs.open(fname, 'r', 'utf8'): 126 | num+=1 127 | line = line.rstrip() 128 | line = re.sub("\d+",'0',line) if zeros else line 129 | if not line: 130 | if len(sentence) > 0: 131 | sentences.append(sentence) 132 | sentence_num += 1 133 | sentence = [] 134 | else: 135 | # in case space is a word 136 | if line[0] == " ": 137 | line = "$" + line[1:] 138 | word = line.split() 139 | else: 140 | word= line.split(' ') 141 | assert len(word) >= 2, print(fname,num,[word[0]],line) 142 | word[0] = word[0].lower() if lower else word[0] 143 | sentence.append(word) 144 | if len(sentence) > 0: 145 | sentence_num += 1 146 | sentences.append(sentence) 147 | print("Got {0} sentences from file {1}".format(sentence_num,fname)) 148 | print("Read all the sentences from training files: {0} sentences".format(len(sentences))) 149 | return sentences 150 | 151 | def iob2(tags): 152 | """ 153 | Check that tags have a valid IOB format. 154 | Tags in IOB1 format are converted to IOB2. 155 | """ 156 | for i, tag in enumerate(tags): 157 | if tag == 'O': 158 | continue 159 | split = tag.split('-') 160 | if len(split) != 2 or split[0] not in ['I', 'B']: 161 | return False 162 | if split[0] == 'B': 163 | continue 164 | elif i == 0 or tags[i - 1] == 'O': # conversion IOB1 to IOB2 165 | tags[i] = 'B' + tag[1:] 166 | elif tags[i - 1][1:] == tag[1:]: 167 | continue 168 | else: # conversion IOB1 to IOB2 169 | tags[i] = 'B' + tag[1:] 170 | return True 171 | 172 | def update_tag_scheme(sentences, tag_scheme='iobes', convert_to_iob=False): 173 | """ 174 | Check and update sentences tagging scheme to IOB2. 175 | Only IOB1 and IOB2 schemes are accepted. 176 | """ 177 | for i, s in enumerate(sentences): 178 | tags = [w[-1] for w in s] 179 | if convert_to_iob: 180 | tags = iobes_iob(tags) 181 | # Check that tags are given in the IOB format 182 | if not iob2(tags): 183 | s_str = '\n'.join(' '.join(w) for w in s) 184 | raise Exception('Sentences should be given in IOB format! 
' + 185 | 'Please check sentence %i:\n%s' % (i, s_str)) 186 | if tag_scheme == 'iob': 187 | # If format was IOB1, we convert to IOB2 188 | # we already did that in iob2 method 189 | for word, new_tag in zip(s, tags): 190 | word[-1] = new_tag 191 | elif tag_scheme == 'iobes': 192 | new_tags = iob_iobes(tags) 193 | for word, new_tag in zip(s, new_tags): 194 | word[-1] = new_tag 195 | else: 196 | raise Exception('Unknown tagging scheme!') 197 | 198 | def eval_ner(results, path, name): 199 | """ 200 | Run perl script to evaluate model 201 | """ 202 | if not os.path.exists(path): 203 | os.mkdir(path) 204 | output_file = os.path.join(path, name + "_ner_predict.utf8") 205 | with open(output_file, "w") as f: 206 | to_write = [] 207 | for block in results: 208 | for line in block: 209 | to_write.append(line + "\n") 210 | to_write.append("\n") 211 | 212 | f.writelines(to_write) 213 | eval_lines = return_report(output_file) 214 | f1 = float(eval_lines[1].strip().split()[-1]) 215 | return eval_lines, f1 216 | 217 | def convert_to_bio(tags): 218 | for i in range(len(tags)): 219 | t = tags[i] 220 | if t[0]=='B': 221 | j = i+1 222 | while j < len(tags) and tags[j][0] not in ['E','S']: 223 | j += 1 224 | if j >= len(tags): 225 | tags[i] = 'O' 226 | elif tags[j][0] == 'S': 227 | # error 228 | tags[i:j] = ['O'] * (j-i) 229 | elif tags[j][0] == 'E': 230 | tags[i+1:j+1] = ['I-span']*(j-i) 231 | tags = ['B'+t[1:] if t[0]=='S' else t for t in tags] 232 | return tags 233 | 234 | def get_biaffine_pred_prob(text, span_scores, label_list): 235 | candidates = [] 236 | for s in range(len(text)): 237 | for e in range(s,len(text)): 238 | candidates.append((s,e)) 239 | 240 | top_spans = [] 241 | for i,tp in enumerate(np.argmax(span_scores,axis=1)): 242 | if tp > 0: 243 | s,e = candidates[i] 244 | top_spans.append((s,e,label_list[tp],float(span_scores[i][tp]))) 245 | 246 | top_spans = sorted(top_spans, key=lambda x:x[3], reverse=True) 247 | sent_pred_mentions = [] 248 | for ns,ne,t,score in top_spans: 249 | for ts,te,_,_ in sent_pred_mentions: 250 | if ns < ts <= ne < te or ts < ns <= te < ne: 251 | #for both nested and flat ner no clash is allowed 252 | break 253 | if (ns <= ts <= te <= ne or ts <= ns <= ne <= te): 254 | #for flat ner nested mentions are not allowed 255 | break 256 | else: 257 | sent_pred_mentions.append((ns,ne,t,score)) 258 | return sent_pred_mentions 259 | 260 | def get_biaffine_pred_ner(text, span_scores, is_flat_ner=True): 261 | candidates = [] 262 | for s in range(len(text)): 263 | for e in range(s,len(text)): 264 | candidates.append((s,e)) 265 | 266 | top_spans = [] 267 | for i,tp in enumerate(np.argmax(span_scores,axis=1)): 268 | if tp > 0: 269 | s,e = candidates[i] 270 | top_spans.append((s,e,tp,span_scores[i])) 271 | 272 | top_spans = sorted(top_spans, key=lambda x:x[3][x[2]], reverse=True) 273 | 274 | # if not top_spans: 275 | # # 无论如何找一个span 276 | # # 这里是因为cluener里面基本上每句话都有实体,因此这样使用 277 | # # 如果是真实的场景,可以去掉这部分 278 | # tmp_span_scores = span_scores[:,1:] 279 | # for i,tp in enumerate(np.argmax(tmp_span_scores,axis=1)): 280 | # s,e = candidates[i] 281 | # top_spans.append((s,e,tp+1,span_scores[i])) 282 | # top_spans = sorted(top_spans, key=lambda x:x[3][x[2]], reverse=True)[:1] 283 | 284 | sent_pred_mentions = [] 285 | for ns,ne,t,score in top_spans: 286 | for ts,te,_,_ in sent_pred_mentions: 287 | if ns < ts <= ne < te or ts < ns <= te < ne: 288 | #for both nested and flat ner no clash is allowed 289 | break 290 | if is_flat_ner and (ns <= ts <= te <= ne or ts <= ns <= ne <= te): 291 | #for 
flat ner nested mentions are not allowed 292 | break 293 | else: 294 | sent_pred_mentions.append((ns,ne,t,[float(x) for x in score.flat])) 295 | return sent_pred_mentions 296 | 297 | def get_biaffine_pred_ner_with_dp(text, span_scores, with_logits=True, threshold=0.1): 298 | candidates = [] 299 | for s in range(len(text)): 300 | for e in range(s,len(text)): 301 | candidates.append((s,e)) 302 | 303 | top_spans = {} 304 | for i,tp in enumerate(np.argmax(span_scores,axis=1)): 305 | if tp > 0: 306 | if not with_logits and span_scores[i][tp] < threshold: 307 | continue 308 | s,e = candidates[i] 309 | # if check_special_token(text[s:e+1]): 310 | # continue 311 | top_spans[(s,e)] = (tp,span_scores[i][tp]) 312 | 313 | 314 | if not top_spans: 315 | return [] 316 | 317 | DAG = {} 318 | for k,v in top_spans: 319 | if k not in DAG: 320 | DAG[k] = [] 321 | DAG[k].append(v) 322 | 323 | route = {} 324 | N = len(text) 325 | route[N] = (0,0) 326 | for idx in range(N-1,-1,-1): 327 | if with_logits: 328 | route[idx] = max( 329 | (top_spans.get((idx,x),[0,0])[1] + route[x+1][0],x) 330 | for x in DAG.get(idx,[idx]) 331 | ) 332 | else: 333 | route[idx] = max( 334 | ( np.log(max(top_spans.get((idx,x),[0,0])[1],1e-5)) + route[x+1][0],x) 335 | for x in DAG.get(idx,[idx]) 336 | ) 337 | 338 | start = 0 339 | spans = [] 340 | while start < N: 341 | end = route[start][1] 342 | if (start,end) in top_spans: 343 | tp,score = top_spans[(start,end)] 344 | spans.append((start,end,tp,score)) 345 | start = end + 1 346 | return spans 347 | 348 | class SWAHook(tf.train.SessionRunHook): 349 | def __init__(self, swa_steps, start_swa_step, checkpoint_path): 350 | self.swa_steps = swa_steps 351 | self.start_swa_step = start_swa_step 352 | self.checkpoint_path = checkpoint_path 353 | self.pre_save_step = 0 354 | 355 | def begin(self): 356 | global_step = tf.train.get_global_step() 357 | self._global_step_tensor = tf.identity(global_step,"global_step_read") 358 | tvars = tf.trainable_variables() 359 | self.save_num = tf.Variable(initial_value=0, name="save_num",dtype=tf.float32,trainable=False) 360 | 361 | self.swa_vars = [ 362 | tf.get_variable( 363 | name=tvar.name.split(":")[0] + "/swa", 364 | shape=tvar.shape.as_list(), 365 | dtype=tf.float32, 366 | trainable=False, 367 | initializer=tf.zeros_initializer()) for tvar in tvars] 368 | 369 | self.first_assign = tf.group([y.assign(x) for x,y in zip(tvars,self.swa_vars)] + [self.save_num.assign_add(1)]) 370 | self.update = tf.group([y.assign((y*self.save_num + x)/(self.save_num+1)) for x,y in zip(tvars,self.swa_vars)]+ [self.save_num.assign_add(1)]) 371 | to_save = {x.op.name:y for x,y in zip(tvars,self.swa_vars)} 372 | to_save[global_step.op.name] = global_step 373 | self.saver = tf.train.Saver(to_save,max_to_keep=1) 374 | 375 | def after_run(self, run_context, run_values): 376 | global_step = run_context.session.run(self._global_step_tensor) 377 | if global_step >= self.start_swa_step: 378 | if self.pre_save_step == 0: 379 | run_context.session.run(self.first_assign) 380 | self.pre_save_step = global_step 381 | elif (global_step-self.pre_save_step) % self.swa_steps == 0: 382 | tf.logging.info('update swa') 383 | run_context.session.run(self.update) 384 | self.pre_save_step = global_step 385 | 386 | def end(self, session): 387 | global_step = session.run(self._global_step_tensor) 388 | self.saver.save(session,os.path.join(self.checkpoint_path,'model.ckpt'),global_step=global_step) 389 | 390 | class BestF1Exporter(tf.estimator.Exporter): 391 | def __init__(self, input_fn, examples, 
label_list, max_seq_length, dp=False, name='f1_export'): 392 | self._name = name 393 | self.input_fn = input_fn 394 | self.predict_examples = examples 395 | self.label_list = label_list 396 | self.max_seq_length = max_seq_length 397 | self._best_eval_result = None 398 | self.dp = dp 399 | 400 | @property 401 | def name(self): 402 | return self._name 403 | 404 | def get_biaffine_result(self,estimator): 405 | final_results = [] 406 | idx = 0 407 | for i,prediction in enumerate(tqdm(estimator.predict(input_fn=self.input_fn,yield_single_examples=True))): 408 | scores = prediction['score'] 409 | offset = 0 410 | bz = prediction['batch_size'] 411 | for j in range(bz): 412 | example = self.predict_examples[idx] 413 | text = example.text 414 | pred_text = example.text[:self.max_seq_length-2] 415 | size = len(pred_text) * (len(pred_text) + 1) // 2 416 | pred_score = scores[offset:offset+size] 417 | idx += 1 418 | offset += size 419 | if self.dp: 420 | results = get_biaffine_pred_ner_with_dp(pred_text,pred_score) 421 | else: 422 | results = get_biaffine_pred_ner(pred_text,pred_score) 423 | labels = {} 424 | for s,e,t,score in results: 425 | span = text[s:e+1] 426 | label = self.label_list[t] 427 | item = [s,e] 428 | if label not in labels: 429 | labels[label] = {span:[item]} 430 | else: 431 | if span in labels[label]: 432 | labels[label][span].append(item) 433 | else: 434 | labels[label][span] = [item] 435 | tags = convert_back_to_bio(labels,text) 436 | tags = [' '.join([c,t,p]) for c,t,p in zip(text,example.label,tags)] 437 | final_results.append(tags) 438 | return final_results 439 | 440 | def export(self, estimator, export_path, checkpoint_path, eval_result, is_the_final_export): 441 | if not os.path.exists(export_path): 442 | tf.io.gfile.makedirs(export_path) 443 | 444 | final_results = self.get_biaffine_result(estimator) 445 | eval_lines, f1 = eval_ner(final_results,export_path,'eval') 446 | for line in eval_lines: 447 | tf.logging.info(line.rstrip()) 448 | if self._best_eval_result is None or f1 > self._best_eval_result: 449 | tf.logging.info('Exporting a better model ({} instead of {}), ckp-path: {}'.format( 450 | f1, self._best_eval_result,checkpoint_path)) 451 | basename = None 452 | for name in glob.glob(checkpoint_path + '.*'): 453 | parts = os.path.basename(name).split('.') 454 | if len(parts) == 3: 455 | parts[1] = parts[1].split('-')[0] 456 | filename = '.'.join(parts) 457 | basename = '.'.join(parts[:2]) 458 | shutil.copy(name, os.path.join(export_path, filename)) 459 | with open(os.path.join(export_path, "checkpoint"), 'w') as f: 460 | f.write("model_checkpoint_path: \"{}\"".format(basename)) 461 | with open(os.path.join(export_path, "best.txt"), 'w') as f: 462 | f.write('Best f1: {}, path: {}\n'.format(f1,checkpoint_path)) 463 | self._best_eval_result = f1 464 | else: 465 | tf.logging.info( 466 | 'Keeping the current best model ({} instead of {}).'.format( 467 | self._best_eval_result, f1)) --------------------------------------------------------------------------------
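
Below is a short, hedged usage sketch of the DAG-based dynamic-programming decoder defined in `code/utils.py` (`get_biaffine_pred_ner_with_dp`). It is not part of the repository: the toy `text`, the `label_list`, and the score matrix are invented purely for illustration, and it assumes the snippet is run from the `code/` directory so that `utils` is importable and that `numpy` is installed.

```python
# Minimal sketch (not from the repo): exercising the DAG-based DP decoder in
# code/utils.py on a hand-made score matrix. The label names and scores below
# are hypothetical and exist only to show the expected input/output shapes.
import numpy as np
from utils import get_biaffine_pred_ner_with_dp  # assumes cwd is code/

text = "ABCD"                       # 4 characters -> 4*5/2 = 10 candidate spans
label_list = ["O", "prov", "city"]  # made-up label inventory, index 0 = "no entity"

# Candidate spans are enumerated as (0,0),(0,1),(0,2),(0,3),(1,1),...,(3,3);
# span_scores has one row per candidate span and one column per label.
span_scores = np.zeros((10, len(label_list)))
span_scores[1][1] = 5.0  # span (0,1) scores highest as "prov"
span_scores[8][2] = 4.0  # span (2,3) scores highest as "city"

# The DP walks the DAG of candidate spans and keeps the highest-scoring
# set of non-overlapping spans (flat NER decoding).
for s, e, tp, score in get_biaffine_pred_ner_with_dp(text, span_scores):
    print(text[s:e + 1], label_list[tp], score)
# Expected output:
#   AB prov 5.0
#   CD city 4.0
```

Because the decoder keeps only the best non-overlapping path through the span DAG, raising one span's score can change which neighbouring spans survive, which is the behaviour a flat NER task expects; the greedy variant (`get_biaffine_pred_ner`) instead admits spans one by one in score order.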