├── .gitignore ├── AR_architecture.png ├── LICENSE ├── README.md ├── __init__.py ├── bert_example.py ├── chat_rephrase ├── __init__.py ├── predict_for_chat.py └── score_for_qa.txt ├── compute_lcs.py ├── configs └── lasertagger_config.json ├── curLine_file.py ├── domain_rephrase ├── __init__.py ├── predict_for_domain.py └── rephrase_for_domain.sh ├── get_pairs_chinese ├── __init__.py ├── curLine_file.py ├── get_text_pair_lcqmc.py ├── get_text_pair_shixi.py ├── get_text_pair_sv.py └── merge_split_corpus.py ├── official_transformer ├── README ├── __init__.py ├── attention_layer.py ├── beam_search.py ├── embedding_layer.py ├── ffn_layer.py ├── model_params.py ├── model_utils.py ├── tpu.py └── transformer.py ├── phrase_vocabulary_optimization.py ├── predict_main.py ├── predict_utils.py ├── prediction.txt ├── preprocess_main.py ├── qa_rephrase ├── __init__.py ├── predict_for_qa.py └── score_for_qa.txt ├── rephrase.sh ├── rephrase_for_chat.sh ├── rephrase_for_qa.sh ├── rephrase_for_skill.sh ├── rephrase_server.sh ├── rephrase_server ├── __init__.py ├── rephrase_server_flask.py └── test_server.py ├── requirements.txt ├── run_lasertagger.py ├── run_lasertagger_utils.py ├── sari_hook.py ├── score_lib.py ├── score_main.py ├── sentence_fusion_task.png ├── skill_rephrase ├── __init__.py ├── predict_for_skill.py └── score_for_skill.txt ├── tagging.py ├── tagging_converter.py ├── transformer_decoder.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *-notice.txt 6 | score.txt 7 | .idea/ 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /AR_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mleader2/text_scalpel/131922c5dcfc558a46a7590925e348c39cb24951/AR_architecture.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LaserTagger
2 | I. Overview
3 | Text paraphrasing takes a sentence or passage A and rewrites it as a text B that expresses nearly the same meaning with slightly different wording.
4 | This project adapts Google's LaserTagger model and trains it on Chinese corpora such as LCQMC to paraphrase text, i.e., to edit a piece of text while preserving its original semantics.
5 | The paraphrases can be used for data augmentation and text generalization, enlarging the corpus for a specific scenario and improving a model's ability to generalize.
6 |
7 |
8 | II. The model
9 | In "Encode, Tag, Realize: High-Precision Text Editing", Google casts text editing as sequence labeling and achieves state-of-the-art results on sentence splitting and abstractive summarization.
10 | With the same BERT encoder, this approach is more reliable than Seq2Seq methods, trains and infers faster, and its advantage grows as the training corpus gets smaller.
11 |
12 | ![LaserTagger architecture](AR_architecture.png)
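To make the editing-as-tagging idea concrete, here is a toy illustration with made-up tag semantics; the project's real tag set and realization logic live in tagging.py and tagging_converter.py.

```python
# Toy "Encode, Tag, Realize" example: each source character gets an edit tag;
# realizing the tags yields the target. Semantics here: "KEEP" copies the
# character, "DELETE" drops it, and an "|phrase" suffix inserts that phrase.
def realize(tokens, tags):
    out = []
    for token, tag in zip(tokens, tags):
        op, _, phrase = tag.partition("|")
        if op == "KEEP":
            out.append(token)
        if phrase:
            out.append(phrase)
    return "".join(out)

# Rewrites "我要去厕所" into "我想去厕所" by deleting "要" and inserting "想".
print(realize(list("我要去厕所"), ["KEEP", "DELETE|想", "KEEP", "KEEP", "KEEP"]))
```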
13 |
14 | Google has released the code for this paper, but its original tasks differ somewhat from ours, so parts of the code were modified. The main changes:
15 | A. Tokenization: the original code targets English and splits text on spaces into words; for Chinese, the text is now split into individual characters.
16 | B. Inference efficiency: the original code paraphrased one text at a time; it now paraphrases batch_size texts per call, making inference about 6x faster.
17 |
18 | III. Files and experiment steps
19 | 1. Install the Python modules
20 | See "requirements.txt" and "rephrase.sh".
21 | 2. Download a pre-trained model
22 | For inference efficiency, this project currently uses the RoBERTa-tiny-clue (Chinese) pre-trained model.
23 | Since several versions circulate online, the exact model used by this project has been uploaded to Baidu Netdisk. Link: https://pan.baidu.com/s/1yho8ihR9C6rBbY-IJjSagA extraction code: 2a97
24 | To use another pre-trained model, edit "configs/lasertagger_config.json".
25 | 3. Train and evaluate the model
26 | Adjust the two folder paths in "rephrase.sh" for your environment, then run: bash rephrase.sh HOST_NAME
27 | The HOST_NAME variable is just a convenience the author uses for setting paths; change it to fit your setup.
28 | If you only need offline batch paraphrasing, you can comment out the other parts of the script; predict_main.py alone covers that use case.
29 | 4. Start the paraphrase service (optional)
30 | Adjust the folder paths in "rephrase_server.sh" for your environment, then run "sh rephrase_server.sh" to start a paraphrasing API service.
31 | The service accepts an HTTP POST request, parses it, and paraphrases the text it contains; see "rephrase_server/rephrase_server_flask.py" for the exact interface. A sketch of a client call follows below.
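The following client sketch illustrates such a POST request. The route "/rephrase", port 5000, and the "texts" field are placeholders invented for illustration; the actual route and payload schema are defined in rephrase_server/rephrase_server_flask.py.

```python
# Hypothetical client for the paraphrase service -- adjust the route, port and
# payload keys to whatever rephrase_server_flask.py actually exposes.
import requests

def rephrase(texts, host="127.0.0.1", port=5000):
    url = "http://%s:%d/rephrase" % (host, port)  # assumed endpoint
    resp = requests.post(url, json={"texts": texts}, timeout=10)
    resp.raise_for_status()
    return resp.json()

if __name__ == "__main__":
    print(rephrase(["列车几点发车"]))
```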
32 |
33 | You must assemble the paraphrase corpus yourself from pairs of texts with the same meaning. Corpus from your own business scenario works best (as long as there is enough of it); if you have none, or too little, add the positive pairs from LCQMC or similar datasets.
34 | Then filter the pairs by the length of the longest common subsequence, because this method requires that the surface forms of source and target not differ too much; "get_text_pair_lcqmc.py" shows one way to do this.
35 | Currently my train.txt and tune.txt both have three tab("\t")-separated columns: text1, text2, lcs_score. A sketch of this filtering step is shown below.
36 |
37 | A few scripts such as rephrase_for_qa.sh, rephrase_for_chat.sh and rephrase_for_skill.sh serve the author's own work and can be ignored.
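A minimal sketch of that pair-construction step, reusing _compute_lcs from compute_lcs.py in this repository. The 0.2 threshold mirrors get_text_pair_lcqmc.py; tune it for your data.

```python
# Build train.txt/tune.txt lines of the form "text1\ttext2\tlcs_score".
from compute_lcs import _compute_lcs

def make_pair_line(source, target, min_lcs_rate=0.2):
    length = float(len(source) + len(target))
    lcs_rate = len(_compute_lcs(source, target)) / length
    if lcs_rate < min_lcs_rate:
        return None  # surface forms differ too much for tagging-based editing
    return "%s\t%s\t%f\n" % (source, target, lcs_rate)

with open("train.txt", "w") as fw:
    for src, tgt in [("我要去厕所", "我想去厕所"), ("我要去厕所", "卫生间在哪里")]:
        line = make_pair_line(src, tgt)
        if line is not None:  # the second pair is filtered out (no overlap)
            fw.write(line)
```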
38 |
39 | IV. Results
40 | 1. Reproducing the model on the public Wiki Split dataset:
41 | Wiki Split is an English corpus; the model is trained to split one sentence into two while keeping the meaning, staying grammatical, and reading coherently, as shown in the figure below.
42 |
43 | ![Sentence fusion example](sentence_fusion_task.png)
44 |
45 | Exact score=15, SARI score=61.5, KEEP score=93, ADDITION score=32, DELETION score=59,
46 | essentially matching the paper's Exact score=15.2 and SARI score=61.7 (for all of these scores, higher is better).
47 | 2. Training a paraphrase model on a self-built Chinese dataset:
48 | (1) Corpus sources
49 | (A) Part of the corpus comes from the positive pairs in LCQMC, i.e., pairs of texts with similar meaning;
50 | (B) the rest comes from questions that share the same answer in a business FAQ.
51 | Because of how the model works, texts A and B must overlap in a certain number of characters, so pairs from these two sources whose surface forms differ too much, e.g. "我要去厕所" ("I want to go to the toilet") vs. "卫生间在哪里" ("Where is the restroom"), are filtered out. The model is trained and tested on the filtered corpus.
52 | (2) Test results:
53 | Paraphrasing 25,918 text pairs and scoring them automatically gives (higher is better):
54 | Exact score=29, SARI score=64, KEEP score=84, ADDITION score=39, DELETION score=66.
55 | This took 0.5 hours on CPU, i.e., 0.72 seconds per paraphrased sentence on average.
56 | Probably because the language and task differ, the scores on Chinese paraphrasing are somewhat higher than on the public dataset.
57 |
58 | V. Some tricks
59 | 1. Certain characters or words can be protected from modification
60 | For example, when augmenting NER corpora, the model must not alter the entities;
61 | when augmenting business corpora, keywords such as dates or flight numbers can likewise be protected.
62 | Currently these positions are located with regular expressions and their location flag is set to 1; see tagging.py for the implementation, and the sketch after this section for the idea.
63 | 2. Increasing the difference between the paraphrase and the original
64 | You can first apply a random swap operation to text_a in the training corpus (setting enable_swap_tag to true in the script accordingly) and then train the model to rewrite it into text_b;
65 | at application or test time, apply the same random swap to the original text_a before letting the model rewrite it into text_b.
66 | Since text_a in the training corpus is disfluent while text_b is fluent, the paraphrases produced at application or test time remain fluent.
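A minimal sketch of trick 1's location mask, assuming the same character classes as the ad-hoc check in chat_rephrase/predict_for_chat.py (digits, Latin letters, and ".- "). The protected set is application-specific, and tagging.py is what actually consumes these flags.

```python
# Mark positions the model must not edit: "1" = protected, "0" = editable.
import re

PROTECTED = re.compile(r"[0-9A-Za-z.\- ]")  # extend with dates, flight numbers, ...

def build_location(source):
    return "".join("1" if PROTECTED.match(ch) else "0" for ch in source)

print(build_location("CA1234航班几点起飞"))  # -> 111111000000
```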
34 | """ 35 | 36 | def __init__(self, input_ids, 37 | input_mask, 38 | segment_ids, labels, 39 | labels_mask, 40 | token_start_indices, 41 | task, default_label): 42 | input_len = len(input_ids) 43 | if not (input_len == len(input_mask) and input_len == len(segment_ids) and 44 | input_len == len(labels) and input_len == len(labels_mask)): 45 | raise ValueError( 46 | 'All feature lists should have the same length ({})'.format( 47 | input_len)) 48 | 49 | self.features = collections.OrderedDict([ 50 | ('input_ids', input_ids), 51 | ('input_mask', input_mask), 52 | ('segment_ids', segment_ids), 53 | ('labels', labels), 54 | ('labels_mask', labels_mask), 55 | ]) 56 | self._token_start_indices = token_start_indices 57 | self.editing_task = task 58 | self._default_label = default_label 59 | 60 | def pad_to_max_length(self, max_seq_length, pad_token_id): 61 | """Pad the feature vectors so that they all have max_seq_length. 62 | 63 | Args: 64 | max_seq_length: The length that features will have after padding. 65 | pad_token_id: input_ids feature is padded with this ID, other features 66 | with ID 0. 67 | """ 68 | pad_len = max_seq_length - len(self.features['input_ids']) 69 | for key in self.features: 70 | pad_id = pad_token_id if key == 'input_ids' else 0 71 | self.features[key].extend([pad_id] * pad_len) 72 | if len(self.features[key]) != max_seq_length: 73 | raise ValueError('{} has length {} (should be {}).'.format( 74 | key, len(self.features[key]), max_seq_length)) 75 | 76 | def to_tf_example(self): 77 | """Returns this object as a tf.Example.""" 78 | 79 | def int_feature(values): 80 | return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 81 | 82 | tf_features = collections.OrderedDict([ 83 | (key, int_feature(val)) for key, val in self.features.items() 84 | ]) 85 | return tf.train.Example(features=tf.train.Features(feature=tf_features)) 86 | 87 | def get_token_labels(self): 88 | """Returns labels/tags for the original tokens, not for wordpieces.""" 89 | labels = [] 90 | for idx in self._token_start_indices: 91 | # For unmasked and untruncated tokens, use the label in the features, and 92 | # for the truncated tokens, use the default label. 93 | if (idx < len(self.features['labels']) and 94 | self.features['labels_mask'][idx]): 95 | current_label = self.features['labels'][idx] 96 | # if current_label >= 0: 97 | labels.append(self.features['labels'][idx]) 98 | # else: # stop 99 | # labels.append(self._default_label) 100 | else: 101 | labels.append(self._default_label) 102 | if labels[-1]<0: 103 | print(curLine(), idx, len(self.features['labels']), "mask=",self.features['labels_mask'][idx], self.features['labels'][idx], labels[-1] ) 104 | return labels 105 | 106 | 107 | class BertExampleBuilder(object): 108 | """Builder class for BertExample objects.""" 109 | 110 | def __init__(self, label_map, vocab_file, 111 | max_seq_length, do_lower_case, 112 | converter): 113 | """Initializes an instance of BertExampleBuilder. 114 | 115 | Args: 116 | label_map: Mapping from tags to tag IDs. 117 | vocab_file: Path to BERT vocabulary file. 118 | max_seq_length: Maximum sequence length. 119 | do_lower_case: Whether to lower case the input text. Should be True for 120 | uncased models and False for cased models. 121 | converter: Converter from text targets to tags. 
122 | """ 123 | self._label_map = label_map 124 | self._tokenizer = my_tokenizer_class(vocab_file, do_lower_case=do_lower_case) 125 | self._max_seq_length = max_seq_length 126 | self._converter = converter 127 | self._pad_id = self._get_pad_id() 128 | self._keep_tag_id = self._label_map['KEEP'] 129 | 130 | def build_bert_example( 131 | self, 132 | sources, 133 | target=None, 134 | use_arbitrary_target_ids_for_infeasible_examples=False, 135 | location=None 136 | ): 137 | """Constructs a BERT Example. 138 | 139 | Args: 140 | sources: List of source texts. 141 | target: Target text or None when building an example during inference. 142 | use_arbitrary_target_ids_for_infeasible_examples: Whether to build an 143 | example with arbitrary target ids even if the target can't be obtained 144 | via tagging. 145 | 146 | Returns: 147 | BertExample, or None if the conversion from text to tags was infeasible 148 | and use_arbitrary_target_ids_for_infeasible_examples == False. 149 | """ 150 | # Compute target labels. 151 | task = tagging.EditingTask(sources, location=location, tokenizer=self._tokenizer) 152 | if target is not None: 153 | tags = self._converter.compute_tags(task, target, tokenizer=self._tokenizer) 154 | if not tags: # 不可转化,取决于 use_arbitrary_target_ids_for_infeasible_examples 155 | if use_arbitrary_target_ids_for_infeasible_examples: 156 | # Create a tag sequence [KEEP, DELETE, KEEP, DELETE, ...] which is 157 | # unlikely to be predicted by chance. 158 | tags = [tagging.Tag('KEEP') if i % 2 == 0 else tagging.Tag('DELETE') 159 | for i, _ in enumerate(task.source_tokens)] 160 | else: 161 | return None 162 | else: 163 | # If target is not provided, we set all target labels to KEEP. 164 | tags = [tagging.Tag('KEEP') for _ in task.source_tokens] 165 | labels = [self._label_map[str(tag)] for tag in tags] 166 | # tokens, labels, token_start_indices = self._split_to_wordpieces( # wordpiece: tag是以word为单位的,组成word的piece的标注与这个word相同 167 | # task.source_tokens, labels) 168 | if len(task.source_tokens) > self._max_seq_length - 2: 169 | print(curLine(), "%d tokens is to long," % len(task.source_tokens), "truncate task.source_tokens:", 170 | task.source_tokens) 171 | token_start_indices = [indices+1 for indices in range(len(task.source_tokens))] 172 | 173 | # 截断到self._max_seq_length - 2 174 | tokens = self._truncate_list(task.source_tokens) 175 | labels = self._truncate_list(labels) 176 | 177 | input_tokens = ['[CLS]'] + tokens + ['[SEP]'] 178 | labels_mask = [0] + [1] * len(labels) + [0] 179 | labels = [0] + labels + [0] 180 | 181 | input_ids = self._tokenizer.convert_tokens_to_ids(input_tokens) 182 | input_mask = [1] * len(input_ids) 183 | segment_ids = [0] * len(input_ids) 184 | example = BertExample( 185 | input_ids=input_ids, 186 | input_mask=input_mask, 187 | segment_ids=segment_ids, 188 | labels=labels, 189 | labels_mask=labels_mask, 190 | token_start_indices=token_start_indices, 191 | task=task, 192 | default_label=self._keep_tag_id) 193 | example.pad_to_max_length(self._max_seq_length, self._pad_id) 194 | return example 195 | 196 | # def _split_to_wordpieces(self, tokens, labels): 197 | # """Splits tokens (and the labels accordingly) to WordPieces. 198 | # 199 | # Args: 200 | # tokens: Tokens to be split. 201 | # labels: Labels (one per token) to be split. 202 | # 203 | # Returns: 204 | # 3-tuple with the split tokens, split labels, and the indices of the 205 | # WordPieces that start a token. 206 | # """ 207 | # bert_tokens = [] # Original tokens split into wordpieces. 
208 | # bert_labels = [] # Label for each wordpiece. 209 | # # Index of each wordpiece that starts a new token. 210 | # token_start_indices = [] 211 | # for i, token in enumerate(tokens): 212 | # # '+ 1' is because bert_tokens will be prepended by [CLS] token later. 213 | # token_start_indices.append(len(bert_tokens) + 1) 214 | # pieces = self._tokenizer.tokenize(token) 215 | # bert_tokens.extend(pieces) 216 | # bert_labels.extend([labels[i]] * len(pieces)) 217 | # return bert_tokens, bert_labels, token_start_indices 218 | 219 | def _truncate_list(self, x): 220 | """Returns truncated version of x according to the self._max_seq_length.""" 221 | # Save two slots for the first [CLS] token and the last [SEP] token. 222 | return x[:self._max_seq_length - 2] 223 | 224 | def _get_pad_id(self): 225 | """Returns the ID of the [PAD] token (or 0 if it's not in the vocab).""" 226 | try: 227 | return self._tokenizer.convert_tokens_to_ids(['[PAD]'])[0] 228 | except KeyError: 229 | return 0 230 | -------------------------------------------------------------------------------- /chat_rephrase/__init__.py: -------------------------------------------------------------------------------- 1 | import bert_example 2 | import predict_utils 3 | import tagging_converter 4 | import utils -------------------------------------------------------------------------------- /chat_rephrase/predict_for_chat.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 为domain识别泛化语料 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | import json 8 | from absl import app 9 | from absl import flags 10 | import os, sys, time 11 | from termcolor import colored 12 | import tensorflow as tf 13 | 14 | block_list = os.path.realpath(__file__).split("/") 15 | path = "/".join(block_list[:-2]) 16 | sys.path.append(path) 17 | import bert_example 18 | import predict_utils 19 | import tagging_converter 20 | import utils 21 | 22 | from curLine_file import curLine 23 | 24 | FLAGS = flags.FLAGS 25 | 26 | flags.DEFINE_string( 27 | 'input_file', None, 28 | 'Path to the input file containing examples for which to compute ' 29 | 'predictions.') 30 | flags.DEFINE_enum( 31 | 'input_format', None, ['wikisplit', 'discofuse'], 32 | 'Format which indicates how to parse the input_file.') 33 | flags.DEFINE_string( 34 | 'output_file', None, 35 | 'Path to the TSV file where the predictions are written to.') 36 | flags.DEFINE_string( 37 | 'label_map_file', None, 38 | 'Path to the label map file. Either a JSON file ending with ".json", that ' 39 | 'maps each possible tag to an ID, or a text file that has one tag per ' 40 | 'line.') 41 | flags.DEFINE_string('vocab_file', None, 'Path to the BERT vocabulary file.') 42 | flags.DEFINE_integer('max_seq_length', 128, 'Maximum sequence length.') 43 | flags.DEFINE_bool( 44 | 'do_lower_case', False, 45 | 'Whether to lower case the input text. 
Should be True for uncased ' 46 | 'models and False for cased models.') 47 | flags.DEFINE_bool('enable_swap_tag', True, 'Whether to enable the SWAP tag.') 48 | flags.DEFINE_string('saved_model', None, 'Path to an exported TF model.') 49 | 50 | def predict_and_write(predictor, sources_batch, previous_line_list,context_list, writer, num_predicted, start_time, batch_num): 51 | prediction_batch = predictor.predict_batch(sources_batch=sources_batch) 52 | assert len(prediction_batch) == len(sources_batch) 53 | for id, [prediction, sources] in enumerate(zip(prediction_batch, sources_batch)): 54 | output = "" 55 | if len(prediction) > 1 and prediction != sources: # TODO ignore keep totelly and short prediction 56 | output= "%s%s" % (context_list[id], prediction) # 需要和context拼接么 57 | writer.write("%s\t%s\n" % (previous_line_list[id], output)) 58 | batch_num = batch_num + 1 59 | num_predicted += len(prediction_batch) 60 | if batch_num % 200 == 0: 61 | cost_time = (time.time() - start_time) / 3600.0 62 | print("%s batch_id=%d, predict %d examples, cost %.3fh." % 63 | (curLine(), batch_num, num_predicted, cost_time)) 64 | return num_predicted, batch_num 65 | 66 | 67 | def main(argv): 68 | if len(argv) > 1: 69 | raise app.UsageError('Too many command-line arguments.') 70 | flags.mark_flag_as_required('input_file') 71 | flags.mark_flag_as_required('input_format') 72 | flags.mark_flag_as_required('output_file') 73 | flags.mark_flag_as_required('label_map_file') 74 | flags.mark_flag_as_required('vocab_file') 75 | flags.mark_flag_as_required('saved_model') 76 | 77 | label_map = utils.read_label_map(FLAGS.label_map_file) 78 | converter = tagging_converter.TaggingConverter( 79 | tagging_converter.get_phrase_vocabulary_from_label_map(label_map), 80 | FLAGS.enable_swap_tag) 81 | builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file, 82 | FLAGS.max_seq_length, 83 | FLAGS.do_lower_case, converter) 84 | predictor = predict_utils.LaserTaggerPredictor( 85 | tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder, 86 | label_map) 87 | print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red")) 88 | sourcesA_list = [] 89 | with open(FLAGS.input_file) as f: 90 | for line in f: 91 | json_map = json.loads(line.rstrip('\n')) 92 | sourcesA_list.append(json_map["questions"]) 93 | print(curLine(), len(sourcesA_list), "sourcesA_list:", sourcesA_list[-1]) 94 | start_time = time.time() 95 | num_predicted = 0 96 | with tf.gfile.Open(FLAGS.output_file, 'w') as writer: 97 | for batch_id,sources_batch in enumerate(sourcesA_list): 98 | # sources_batch = sourcesA_list[batch_id * predict_batch_size: (batch_id + 1) * predict_batch_size] 99 | location_batch = [] 100 | for source in sources_batch: 101 | location = list() 102 | for char in source[0]: 103 | if (char>='0' and char<='9') or char in '.- ' or (char>='a' and char<='z') or (char>='A' and char<='Z'): 104 | location.append("1") # TODO TODO 105 | else: 106 | location.append("0") 107 | location_batch.append("".join(location)) 108 | prediction_batch = predictor.predict_batch(sources_batch=sources_batch, location_batch=location_batch) 109 | expand_list = [] 110 | for prediction in prediction_batch: # TODO 111 | if prediction in sources_batch: 112 | continue 113 | expand_list.append(prediction) 114 | 115 | json_map = {"questions":sources_batch, "expands":expand_list} 116 | json_str = json.dumps(json_map, ensure_ascii=False) 117 | writer.write("%s\n" % json_str) 118 | # input(curLine()) 119 | num_predicted += len(expand_list) 120 | if batch_id % 
20 == 0:
121 |         cost_time = (time.time() - start_time) / 60.0
122 |         print("%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin." %
123 |               (curLine(), batch_id + 1, len(sourcesA_list), num_predicted, num_predicted, cost_time))
124 |     cost_time = (time.time() - start_time) / 60.0
125 |     # print(curLine(), id, prediction_A, prediction_B, "target:", target, "current_batch_size=", current_batch_size)
126 |     # print(curLine(), id, "sourceA:", sourceA, "sourceB:", sourceB, "target:", target)
127 |     # logging.info(
128 |     #     f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {cost_time / num_predicted*60000}ms.')
129 |
130 | if __name__ == '__main__':
131 |     app.run(main)
--------------------------------------------------------------------------------
/chat_rephrase/score_for_qa.txt:
--------------------------------------------------------------------------------
1 | 1200
2 | Exact score: 12.843
3 | SARI score: 51.313
4 | KEEP score: 76.046
5 | ADDITION score: 20.086
6 | DELETION score: 57.808
7 | cost time 0.337063 h
8 |
9 | Fangshan: Restoring parameters from /home/wzk/Mywork/corpus/rephrase_corpus/output/models/wikisplit_experiment_name/model.ckpt-3643
10 | Exact score: 13.672
11 | SARI score: 53.450
12 | KEEP score: 77.125
13 | ADDITION score: 21.678
14 | DELETION score: 61.546
15 | cost time 0.260398 h
16 | Continued training with epochs increased from 7.0 to 30.0
17 | Saving checkpoints for 36436 into /home/wzk/Mywork/corpus/rephrase_corpus/output/models/wikisplit_experiment_name/model.ckpt.
18 | Exact score: 19.414
19 | SARI score: 56.749
20 | KEEP score: 78.622
21 | ADDITION score: 26.927
22 | DELETION score: 64.697
23 | cost time 0.260367 h
24 | model.ckpt-35643
25 | Exact score: 19.440
26 | SARI score: 56.725
27 | KEEP score: 78.620
28 | ADDITION score: 26.917
29 | DELETION score: 64.638
30 | cost time 0.259561 h
31 | Restoring parameters from /home/wzk/Mywork/corpus/rephrase_corpus/output/models/wikisplit_experiment_name/model.ckpt-62436
32 | Exact score: 20.347
33 | SARI score: 57.107
34 | KEEP score: 78.693
35 | ADDITION score: 27.379
36 | DELETION score: 65.250
37 | cost time 0.575784 h
38 |
39 |
40 | Old results
41 | After improving the case-handling rules:
42 | Exact score: 29.283
43 | SARI score: 64.002
44 | KEEP score: 84.942
45 | ADDITION score: 38.128
46 | DELETION score: 68.935
47 | cost time 0.41796 h
--------------------------------------------------------------------------------
/compute_lcs.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # compute_lcs: computes the Longest Common Subsequence (LCS) of two lists
3 |
4 | def _compute_lcs(source, target):
5 |     """Computes the Longest Common Subsequence (LCS).
6 |
7 |     Description of the dynamic programming algorithm:
8 |     https://www.algorithmist.com/index.php/Longest_Common_Subsequence
9 |
10 |     Args:
11 |       source: List of source tokens.
12 |       target: List of target tokens.
13 |
14 |     Returns:
15 |       List of tokens in the LCS.
16 | """ 17 | table = _lcs_table(source, target) 18 | return _backtrack(table, source, target, len(source), len(target)) 19 | 20 | 21 | def _lcs_table(source, target): 22 | """Returns the Longest Common Subsequence dynamic programming table.""" 23 | rows = len(source) 24 | cols = len(target) 25 | lcs_table = [[0] * (cols + 1) for _ in range(rows + 1)] 26 | for i in range(1, rows + 1): 27 | for j in range(1, cols + 1): 28 | if source[i - 1] == target[j - 1]: 29 | lcs_table[i][j] = lcs_table[i - 1][j - 1] + 1 30 | else: 31 | lcs_table[i][j] = max(lcs_table[i - 1][j], lcs_table[i][j - 1]) 32 | return lcs_table 33 | 34 | 35 | def _backtrack(table, source, target, i, j): 36 | """Backtracks the Longest Common Subsequence table to reconstruct the LCS. 37 | 38 | Args: 39 | table: Precomputed LCS table. 40 | source: List of source tokens. 41 | target: List of target tokens. 42 | i: Current row index. 43 | j: Current column index. 44 | 45 | Returns: 46 | List of tokens corresponding to LCS. 47 | """ 48 | if i == 0 or j == 0: 49 | return [] 50 | if source[i - 1] == target[j - 1]: 51 | # Append the aligned token to output. 52 | return _backtrack(table, source, target, i - 1, j - 1) + [target[j - 1]] 53 | if table[i][j - 1] > table[i - 1][j]: 54 | return _backtrack(table, source, target, i, j - 1) 55 | else: 56 | return _backtrack(table, source, target, i - 1, j) 57 | -------------------------------------------------------------------------------- /configs/lasertagger_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 312, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 1248, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 4, 11 | "type_vocab_size": 2, 12 | "vocab_size": 8021, 13 | "use_t2t_decoder": true, 14 | "decoder_num_hidden_layers": 1, 15 | "decoder_hidden_size": 768, 16 | "decoder_num_attention_heads": 4, 17 | "decoder_filter_size": 3072, 18 | "use_full_attention": false 19 | } 20 | -------------------------------------------------------------------------------- /curLine_file.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | def curLine(): 4 | file_path = sys._getframe().f_back.f_code.co_filename # 获取调用函数的路径 5 | file_name=file_path[file_path.rfind("/") + 1:] # 获取调用函数所在的文件名 6 | lineno=sys._getframe().f_back.f_lineno#当前行号 7 | str="[%s:%s] "%(file_name,lineno) 8 | return str -------------------------------------------------------------------------------- /domain_rephrase/__init__.py: -------------------------------------------------------------------------------- 1 | import bert_example 2 | import predict_utils 3 | import tagging_converter 4 | import utils -------------------------------------------------------------------------------- /domain_rephrase/predict_for_domain.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 为domain识别泛化语料 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | from absl import app 9 | from absl import flags 10 | from absl import logging 11 | import os, sys, time 12 | from termcolor import colored 13 | import tensorflow as tf 14 | 15 | block_list = os.path.realpath(__file__).split("/") 16 | path = "/".join(block_list[:-2]) 17 | 
sys.path.append(path) 18 | 19 | import bert_example 20 | import predict_utils 21 | import tagging_converter 22 | import utils 23 | 24 | from curLine_file import curLine 25 | 26 | FLAGS = flags.FLAGS 27 | 28 | flags.DEFINE_string( 29 | 'input_file', None, 30 | 'Path to the input file containing examples for which to compute ' 31 | 'predictions.') 32 | flags.DEFINE_enum( 33 | 'input_format', None, ['wikisplit', 'discofuse'], 34 | 'Format which indicates how to parse the input_file.') 35 | flags.DEFINE_string( 36 | 'output_file', None, 37 | 'Path to the TSV file where the predictions are written to.') 38 | flags.DEFINE_string( 39 | 'label_map_file', None, 40 | 'Path to the label map file. Either a JSON file ending with ".json", that ' 41 | 'maps each possible tag to an ID, or a text file that has one tag per ' 42 | 'line.') 43 | flags.DEFINE_string('vocab_file', None, 'Path to the BERT vocabulary file.') 44 | flags.DEFINE_integer('max_seq_length', 128, 'Maximum sequence length.') 45 | flags.DEFINE_bool( 46 | 'do_lower_case', False, 47 | 'Whether to lower case the input text. Should be True for uncased ' 48 | 'models and False for cased models.') 49 | flags.DEFINE_bool('enable_swap_tag', True, 'Whether to enable the SWAP tag.') 50 | flags.DEFINE_string('saved_model', None, 'Path to an exported TF model.') 51 | 52 | def predict_and_write(predictor, sources_batch, previous_line_list,context_list, writer, num_predicted, start_time, batch_num): 53 | prediction_batch = predictor.predict_batch(sources_batch=sources_batch) 54 | assert len(prediction_batch) == len(sources_batch) 55 | for id, [prediction, sources] in enumerate(zip(prediction_batch, sources_batch)): 56 | output = "" 57 | if len(prediction) > 1 and prediction != sources: # TODO ignore keep totelly and short prediction 58 | output= "%s%s" % (context_list[id], prediction) # 需要和context拼接么 59 | # print(curLine(), "prediction:", prediction, "sources:", sources, ",output:", output, prediction != sources) 60 | writer.write("%s\t%s\n" % (previous_line_list[id], output)) 61 | batch_num = batch_num + 1 62 | num_predicted += len(prediction_batch) 63 | if batch_num % 200 == 0: 64 | cost_time = (time.time() - start_time) / 3600.0 65 | print("%s batch_id=%d, predict %d examples, cost %.3fh." 
% 66 | (curLine(), batch_num, num_predicted, cost_time)) 67 | return num_predicted, batch_num 68 | 69 | def remove_p(text_pre): 70 | text = text_pre[:9].replace('\n', '').replace('\r', '').replace('news', '').replace('food','').replace('GetPath', '') \ 71 | .replace('GetYear','').replace('GetDate', '').replace('flight', '').replace('weather', '').replace('currency', '').replace('stock', '').replace('story', '') \ 72 | .replace('drama', '').replace('jokes', '').replace('other','').replace('poetry', '').replace('GetLunar', '')+text_pre[9:] 73 | if len(text) == len(text_pre) and len(text)>9: 74 | text = text.replace('stoytelling', '').replace('storytypes', '').replace( 75 | 'SetPathPlace', '').replace('GetPoetryByTitle', '').replace('GetTitles', '') \ 76 | .replace('GetSolarterm', '').replace('GetLastPhrases', '').replace('SelectPlaceIndex','').replace( 77 | 'photo.tag', '').replace('photo.rac', '').replace('GetSuitAndAvoid', '') \ 78 | .replace('GetOnePoetry', '').replace('SetPathTrans', '').replace('navigation', '').replace( 79 | 'GetNextPhrases', '').replace('crosstalk', '') 80 | if len(text) == len(text_pre) and len(text) > 5: 81 | text = text.replace('currency,SetCurrentcy', '').replace('SetCurrentcy', '').replace('GetWeekDay', '').replace('GetAuthors', '') \ 82 | .replace('trafficrestr', '').replace('GetTranslates', '').replace('GetAuthorNames','') \ 83 | .replace('MatchTitle', '').replace('WeiboFirst', '').replace('music','').replace('times', '').replace('photo', '') 84 | return text 85 | 86 | def main(argv): 87 | if len(argv) > 1: 88 | raise app.UsageError('Too many command-line arguments.') 89 | flags.mark_flag_as_required('input_file') 90 | flags.mark_flag_as_required('input_format') 91 | flags.mark_flag_as_required('output_file') 92 | flags.mark_flag_as_required('label_map_file') 93 | flags.mark_flag_as_required('vocab_file') 94 | flags.mark_flag_as_required('saved_model') 95 | 96 | label_map = utils.read_label_map(FLAGS.label_map_file) 97 | converter = tagging_converter.TaggingConverter( 98 | tagging_converter.get_phrase_vocabulary_from_label_map(label_map), 99 | FLAGS.enable_swap_tag) 100 | builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file, 101 | FLAGS.max_seq_length, 102 | FLAGS.do_lower_case, converter) 103 | predictor = predict_utils.LaserTaggerPredictor( 104 | tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder, 105 | label_map) 106 | print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red")) 107 | predict_batch_size = 64 108 | batch_num = 0 109 | num_predicted = 0 110 | with tf.gfile.Open(FLAGS.output_file, 'w') as writer: 111 | with open(FLAGS.input_file, "r") as f: 112 | sources_batch = [] 113 | previous_line_list = [] 114 | context_list = [] 115 | line_number = 0 116 | start_time = time.time() 117 | while True: 118 | line_number +=1 119 | line = f.readline().rstrip('\n').strip("\"").strip(" ") 120 | if len(line) == 0: 121 | break 122 | 123 | column_index = line.index(",") 124 | text = line[column_index+1:].strip("\"") # context and query 125 | # for charChinese_id, char in enumerate(line[column_index+1:]): 126 | # if (char>='a' and char<='z') or (char>='A' and char<='Z'): 127 | # continue 128 | # else: 129 | # break 130 | source = remove_p(text) 131 | if source not in text: # TODO ignore的就给空字符串,这样输出也是空字符串 132 | print(curLine(), "line_number=%d, ignore:%s" % (line_number, text), ",source:", len(source), source) 133 | source = "" 134 | # continue 135 | context_list.append(text[:text.index(source)]) 136 | 
previous_line_list.append(line) 137 | sources_batch.append(source) 138 | if len(sources_batch) == predict_batch_size: 139 | num_predicted, batch_num = predict_and_write(predictor, sources_batch, 140 | previous_line_list,context_list, writer, num_predicted, start_time, batch_num) 141 | sources_batch = [] 142 | previous_line_list = [] 143 | context_list = [] 144 | # if num_predicted > 1000: 145 | # break 146 | if len(context_list)>0: 147 | num_predicted, batch_num = predict_and_write(predictor, sources_batch, 148 | previous_line_list, context_list, writer, num_predicted, 149 | start_time, batch_num) 150 | cost_time = (time.time() - start_time) / 60.0 151 | logging.info( 152 | f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {cost_time / num_predicted/60} hours.') 153 | 154 | if __name__ == '__main__': 155 | app.run(main) 156 | -------------------------------------------------------------------------------- /domain_rephrase/rephrase_for_domain.sh: -------------------------------------------------------------------------------- 1 | # 扩充技能的语料 2 | # rephrase_for_skill.sh: 在rephrase.sh基础上改的 3 | # predict_for_skill.py: 在 predict_main.py基础上改的 4 | # score_for_skill.txt 结果对比 5 | 6 | 7 | # 成都 8 | # pyenv activate python373tf115 9 | # pip install -i https://pypi.douban.com/simple/ bert-tensorflow==1.0.1 10 | #pip install -i https://pypi.douban.com/simple/ tensorflow==1.15.0 11 | #python -m pip install --upgrade pip -i https://pypi.douban.com/simple 12 | 13 | # set gpu id to use 14 | export CUDA_VISIBLE_DEVICES="" 15 | 16 | # 房山 17 | # pyenv activate python363tf111 18 | # pip install bert-tensorflow==1.0.1 19 | 20 | #scp -r /home/cloudminds/PycharmProjects/lasertagger-Chinese/predict_main.py cloudminds@10.13.33.128:/home/cloudminds/PycharmProjects/lasertagger-Chinese 21 | #scp -r cloudminds@10.13.33.128:/home/wzk/Mywork/corpus/文本复述/output/models/wikisplit_experiment_name /home/cloudminds/Mywork/corpus/文本复述/output/models/ 22 | # watch -n 1 nvidia-smi 23 | 24 | start_tm=`date +%s%N`; 25 | 26 | export HOST_NAME="cloudminds" #   "wzk" # 27 | ### Optional parameters ### 28 | 29 | # If you train multiple models on the same data, change this label. 30 | EXPERIMENT=wikisplit_experiment 31 | # To quickly test that model training works, set the number of epochs to a 32 | # smaller value (e.g. 0.01). 33 | NUM_EPOCHS=10.0 34 | export TRAIN_BATCH_SIZE=256 # 512 OOM 256 OK 35 | PHRASE_VOCAB_SIZE=500 36 | MAX_INPUT_EXAMPLES=1000000 37 | SAVE_CHECKPOINT_STEPS=200 38 | export enable_swap_tag=false 39 | export output_arbitrary_targets_for_infeasible_examples=false 40 | export WIKISPLIT_DIR="/home/${HOST_NAME}/Mywork/corpus/rephrase_corpus" 41 | export OUTPUT_DIR="${WIKISPLIT_DIR}/output" 42 | 43 | #python phrase_vocabulary_optimization.py \ 44 | # --input_file=${WIKISPLIT_DIR}/train.txt \ 45 | # --input_format=wikisplit \ 46 | # --vocabulary_size=500 \ 47 | # --max_input_examples=1000000 \ 48 | # --enable_swap_tag=${enable_swap_tag} \ 49 | # --output_file=${OUTPUT_DIR}/label_map.txt 50 | 51 | 52 | export max_seq_length=40 # TODO 53 | export BERT_BASE_DIR="/home/${HOST_NAME}/Mywork/model/RoBERTa-tiny-clue" # chinese_L-12_H-768_A-12" 54 | 55 | 56 | 57 | 58 | # Check these numbers from the "*.num_examples" files created in step 2. 59 | export CONFIG_FILE=configs/lasertagger_config.json 60 | export EXPERIMENT=wikisplit_experiment_name 61 | 62 | 63 | 64 | ### 4. Prediction 65 | 66 | # Export the model. 
67 | #python run_lasertagger.py \ 68 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \ 69 | # --model_config_file=${CONFIG_FILE} \ 70 | # --output_dir=${OUTPUT_DIR}/models/${EXPERIMENT} \ 71 | # --do_export=true \ 72 | # --export_path=${OUTPUT_DIR}/models/${EXPERIMENT}/export 73 | 74 | ## Get the most recently exported model directory. 75 | TIMESTAMP=$(ls "${OUTPUT_DIR}/models/${EXPERIMENT}/export/" | \ 76 | grep -v "temp-" | sort -r | head -1) 77 | SAVED_MODEL_DIR=${OUTPUT_DIR}/models/${EXPERIMENT}/export/${TIMESTAMP} 78 | PREDICTION_FILE=${OUTPUT_DIR}/models/${EXPERIMENT}/pred.tsv 79 | 80 | python domain_rephrase/predict_for_domain.py \ 81 | --input_file=/home/${HOST_NAME}/Mywork/corpus/ner_corpus/domain_corpus/train3.csv \ 82 | --input_format=wikisplit \ 83 | --output_file=/home/${HOST_NAME}/Mywork/corpus/ner_corpus/domain_corpus/train3_expand.csv \ 84 | --label_map_file=${OUTPUT_DIR}/label_map.txt \ 85 | --vocab_file=${BERT_BASE_DIR}/vocab.txt \ 86 | --max_seq_length=${max_seq_length} \ 87 | --saved_model=${SAVED_MODEL_DIR} 88 | 89 | #### 5. Evaluation 90 | #python score_main.py --prediction_file=${PREDICTION_FILE} 91 | 92 | 93 | end_tm=`date +%s%N`; 94 | use_tm=`echo $end_tm $start_tm | awk '{ print ($1 - $2) / 1000000000 /3600}'` 95 | echo "cost time" $use_tm "h" -------------------------------------------------------------------------------- /get_pairs_chinese/__init__.py: -------------------------------------------------------------------------------- 1 | # 获取文本复述(rephrase)任务的语料 2 | import curLine_file 3 | import compute_lcs -------------------------------------------------------------------------------- /get_pairs_chinese/curLine_file.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | def curLine(): 4 | file_path = sys._getframe().f_back.f_code.co_filename # 获取调用函数的路径 5 | file_name=file_path[file_path.rfind("/") + 1:] # 获取调用函数所在的文件名 6 | lineno=sys._getframe().f_back.f_lineno#当前行号 7 | str="[%s:%s] "%(file_name,lineno) 8 | return str -------------------------------------------------------------------------------- /get_pairs_chinese/get_text_pair_lcqmc.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 利用文本匹配的语料,从正例中采样得到句子对(A,B),然后训练模型把A改写成B 3 | # 当前是针对LCQMC 改了几条语料的标注 4 | import random 5 | import os 6 | import sys 7 | sys.path.append("..") 8 | from compute_lcs import _compute_lcs 9 | from curLine_file import curLine 10 | 11 | def process(corpus_folder, raw_file_name, save_folder): 12 | corpus_list = [] 13 | for name in raw_file_name: 14 | raw_file = os.path.join(corpus_folder, name) 15 | with open(raw_file, "r") as fr: 16 | lines = fr.readlines() 17 | 18 | for i ,line in enumerate(lines): 19 | source, target, label = line.strip().split("\t") 20 | if label=="0" or source==target: 21 | continue 22 | if label != "1": 23 | input(curLine()+line.strip()) 24 | length = float(len(source) + len(target)) 25 | 26 | source_length = len(source) 27 | if source_length > 8 and source_length<38 and (i+1)%2>0: # 对50%的长句构造交换操作 28 | rand = random.uniform(0.4, 0.9) 29 | source_pre = source 30 | swag_location = int(source_length*rand) 31 | source = "%s%s" % (source[swag_location:], source[:swag_location]) 32 | lcs1 = _compute_lcs(source, target) 33 | lcs_rate= len(lcs1)/length 34 | if (lcs_rate<0.4):# 差异大,换回来 35 | source = source_pre 36 | else: 37 | print(curLine(), "source_pre:%s, source:%s, lcs_rate=%f" % (source_pre, source, lcs_rate)) 38 | 39 | lcs1 = 
_compute_lcs(source, target) 40 | lcs_rate = len(lcs1) / length 41 | if (lcs_rate<0.2): 42 | continue # 变动过大,忽略 43 | 44 | # if (lcs_rate<0.4): 45 | # continue # 变动过大,忽略 46 | # if len(source)*1.15 < len(target): 47 | # new_t = source 48 | # source = target 49 | # target = new_t 50 | # print(curLine(), source, target, ",lcs1:",lcs1 , ",lcs_rate=", lcs_rate) 51 | corpus = "%s\t%s\t%f\n" % (source, target, lcs_rate) 52 | corpus_list.append(corpus) 53 | print(curLine(), len(corpus_list), "from %s" % raw_file) 54 | save_file = os.path.join(save_folder, "lcqmc.txt") 55 | with open(save_file, "w") as fw: 56 | fw.writelines(corpus_list) 57 | print(curLine(), "have save %d to %s" % (len(corpus_list), save_file)) 58 | 59 | if __name__ == "__main__": 60 | corpus_folder = "/home/cloudminds/Mywork/corpus/Chinese_QA/LCQMC" 61 | raw_file_name = ["train.txt", "dev.txt", "test.txt"] 62 | save_folder = "/home/cloudminds/Mywork/corpus/rephrase_corpus" 63 | process(corpus_folder, raw_file_name, save_folder) 64 | 65 | -------------------------------------------------------------------------------- /get_pairs_chinese/get_text_pair_shixi.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 世西给的QA对(至少是答案相同)的句子在一行,从中采样得到句子对(A,B),然后训练模型把A改写成B 3 | # 当前是针对宝安机场 比较旧且少的数据集, 4 | import os 5 | from compute_lcs import _compute_lcs 6 | from curLine_file import curLine 7 | 8 | def process(corpus_folder, raw_file_name, save_folder): 9 | raw_file = os.path.join(corpus_folder, raw_file_name) 10 | with open(raw_file, "r") as fr: 11 | lines = fr.readlines() 12 | corpus_list = [] 13 | for line in lines: 14 | sent_list = line.strip().split("&&") 15 | sent_num = len(sent_list) 16 | for i in range(1, sent_num, 2): 17 | source= sent_list[i-1] 18 | target = sent_list[i] 19 | length = float(len(source) + len(target)) 20 | lcs1 = _compute_lcs(source, target) 21 | lcs_rate= len(lcs1)/length 22 | if (lcs_rate<0.3): 23 | continue # 变动过大,忽略 24 | if len(source)*1.15 < len(target): 25 | new_t = source 26 | source = target 27 | target = new_t 28 | corpus = "%s\t%s\t%f\n" % (source, target, lcs_rate) 29 | corpus_list.append(corpus) 30 | save_file = os.path.join(save_folder, "baoan_airport.txt") 31 | with open(save_file, "w") as fw: 32 | fw.writelines(corpus_list) 33 | print(curLine(), "have save %d to %s" % (len(corpus_list), save_file)) 34 | 35 | 36 | 37 | 38 | if __name__ == "__main__": 39 | corpus_folder = "/home/cloudminds/Mywork/corpus/Chinese_QA/baoanairport" 40 | raw_file_name = "baoan_airport_processed.txt" 41 | save_folder = "/home/cloudminds/Mywork/corpus/rephrase_corpus" 42 | process(corpus_folder, raw_file_name, save_folder) 43 | 44 | -------------------------------------------------------------------------------- /get_pairs_chinese/get_text_pair_sv.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 从SV导出后的xlsx文件,相同含义(至少是答案相同)的句子在一行,从中采样得到句子对(A,B),然后训练模型把A改写成B 3 | # 当前是针对宝安机场  4 | import sys 5 | sys.path.append("..") 6 | import os 7 | import xlrd # 引入模块 8 | import random 9 | from compute_lcs import _compute_lcs 10 | from curLine_file import curLine 11 | 12 | def process(corpus_folder, raw_file_name): 13 | raw_file = os.path.join(corpus_folder, raw_file_name) 14 | # 打开文件,获取excel文件的workbook(工作簿)对象 15 | workbook = xlrd.open_workbook(raw_file) # 文件路径 16 | 17 | # 通过sheet索引获得sheet对象 18 | worksheet = workbook.sheet_by_index(0) 19 | nrows = worksheet.nrows # 获取该表总行数 20 | ncols = worksheet.ncols # 获取该表总列数 21 | print(curLine(), 
"raw_file_name:%s, worksheet:%s nrows=%d, ncols=%d" % (raw_file_name, worksheet.name,nrows, ncols)) 22 | assert ncols == 3 23 | assert nrows > 0 24 | col_data = worksheet.col_values(0) # 获取第一列的内容 25 | corpus_list = [] 26 | for line in col_data: 27 | sent_list = line.strip().split("&&") 28 | sent_num = len(sent_list) 29 | for i in range(1, sent_num, 2): 30 | source= sent_list[i-1] 31 | target = sent_list[i] 32 | # source_length = len(source) 33 | # if source_length > 8 and (i+1)%4>0: # 对50%的长句随机删除 34 | # rand = random.uniform(0.1, 0.9) 35 | # source_pre = source 36 | # swag_location = int(source_length*rand) 37 | # source = "%s%s" % (source[:swag_location], source[swag_location+1:]) 38 | # print(curLine(), "source_pre:%s, source:%s" % (source_pre, source)) 39 | 40 | length = float(len(source) + len(target)) 41 | lcs1 = _compute_lcs(source, target) 42 | lcs_rate= len(lcs1)/length 43 | if (lcs_rate<0.2): 44 | continue # 变动过大,忽略 45 | 46 | # if (lcs_rate<0.3): 47 | # continue # 变动过大,忽略 48 | # if len(source)*1.15 < len(target): 49 | # new_t = source 50 | # source = target 51 | # target = new_t 52 | corpus = "%s\t%s\t%f\n" % (source, target, lcs_rate) 53 | corpus_list.append(corpus) 54 | return corpus_list 55 | 56 | def main(corpus_folder, save_folder): 57 | fileList = os.listdir(corpus_folder) 58 | corpus_list_total = [] 59 | for raw_file_name in fileList: 60 | corpus_list = process(corpus_folder, raw_file_name) 61 | print(curLine(), raw_file_name, len(corpus_list)) 62 | corpus_list_total.extend(corpus_list) 63 | save_file = os.path.join(save_folder, "baoan_airport_from_xlsx.txt") 64 | with open(save_file, "w") as fw: 65 | fw.writelines(corpus_list_total) 66 | print(curLine(), "have save %d to %s" % (len(corpus_list_total), save_file)) 67 | 68 | 69 | 70 | 71 | if __name__ == "__main__": 72 | corpus_folder = "/home/cloudminds/Mywork/corpus/Chinese_QA/baoanairport/agent842-3月2日" 73 | 74 | save_folder = "/home/cloudminds/Mywork/corpus/rephrase_corpus" 75 | raw_file_name = "专业知识导出记录.xlsx" 76 | main(corpus_folder, save_folder) 77 | 78 | -------------------------------------------------------------------------------- /get_pairs_chinese/merge_split_corpus.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 对不同来源的语料融合,然后再划分出train,dev,test 3 | import os 4 | import numpy as np 5 | from curLine_file import curLine 6 | 7 | def merge(raw_file_name_list, save_folder): 8 | corpus_list = [] 9 | for raw_file_name in raw_file_name_list: 10 | raw_file = os.path.join(save_folder, "%s.txt" % raw_file_name) 11 | with open(raw_file) as fr: 12 | lines = fr.readlines() 13 | corpus_list.extend(lines) 14 | if "baoan" in raw_file_name: 15 | corpus_list.extend(lines) # TODO 16 | return corpus_list 17 | 18 | def split(corpus_list, save_folder, trainRate=0.8): 19 | corpusNum = len(corpus_list) 20 | shuffle_indices = list(np.random.permutation(range(corpusNum))) 21 | indexTrain = int(trainRate * corpusNum) 22 | # indexDev= int((trainRate + devRate) * corpusNum) 23 | corpusTrain = [] 24 | for i in shuffle_indices[:indexTrain]: 25 | corpusTrain.append(corpus_list[i]) 26 | save_file = os.path.join(save_folder, "train.txt") 27 | with open(save_file, "w") as fw: 28 | fw.writelines(corpusTrain) 29 | print(curLine(), "have save %d to %s" % (len(corpusTrain), save_file)) 30 | 31 | corpusDev = [] 32 | for i in shuffle_indices[indexTrain:]: # TODO all corpus 33 | corpusDev.append(corpus_list[i]) 34 | save_file = os.path.join(save_folder, "tune.txt") 35 | with open(save_file, "w") 
as fw: 36 | fw.writelines(corpusDev) 37 | print(curLine(), "have save %d to %s" % (len(corpusDev), save_file)) 38 | 39 | 40 | save_file = os.path.join(save_folder, "test.txt") 41 | with open(save_file, "w") as fw: 42 | fw.writelines(corpusDev) 43 | print(curLine(), "have save %d to %s" % (len(corpusDev), save_file)) 44 | 45 | 46 | 47 | 48 | if __name__ == "__main__": 49 | raw_file_name = ["baoan_airport", "lcqmc", "baoan_airport_from_xlsx"] 50 | save_folder = "/home/cloudminds/Mywork/corpus/rephrase_corpus" 51 | corpus_list = merge(raw_file_name, save_folder) 52 | split(corpus_list, save_folder, trainRate=0.8) 53 | 54 | -------------------------------------------------------------------------------- /official_transformer/README: -------------------------------------------------------------------------------- 1 | This directory contains Transformer related code files. These are copied from: 2 | https://github.com/tensorflow/models/tree/master/official/transformer 3 | to make it possible to install all LaserTagger dependencies with pip (the 4 | official Transformer implementation doesn't support pip installation). -------------------------------------------------------------------------------- /official_transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mleader2/text_scalpel/131922c5dcfc558a46a7590925e348c39cb24951/official_transformer/__init__.py -------------------------------------------------------------------------------- /official_transformer/attention_layer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Implementation of multiheaded attention and self-attention layers.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | 24 | 25 | class Attention(tf.layers.Layer): 26 | """Multi-headed attention layer.""" 27 | 28 | def __init__(self, hidden_size, num_heads, attention_dropout, train): 29 | if hidden_size % num_heads != 0: 30 | raise ValueError("Hidden size must be evenly divisible by the number of " 31 | "heads.") 32 | 33 | super(Attention, self).__init__() 34 | self.hidden_size = hidden_size 35 | self.num_heads = num_heads 36 | self.attention_dropout = attention_dropout 37 | self.train = train 38 | 39 | # Layers for linearly projecting the queries, keys, and values. 
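# Each projection below maps [batch_size, length, hidden_size] to a tensor of
# the same shape; the heads are only split out later in split_heads(). As an
# illustrative sketch (not part of the original code), with the BASE_PARAMS
# values hidden_size=512 and num_heads=8, each head ends up attending over
# depth = 512 // 8 = 64 dimensions:
#   q = self.q_dense_layer(x)   # [batch, length_x, 512]
#   q = self.split_heads(q)     # [batch, 8, length_x, 64]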
40 | self.q_dense_layer = tf.layers.Dense(hidden_size, use_bias=False, name="q") 41 | self.k_dense_layer = tf.layers.Dense(hidden_size, use_bias=False, name="k") 42 | self.v_dense_layer = tf.layers.Dense(hidden_size, use_bias=False, name="v") 43 | 44 | self.output_dense_layer = tf.layers.Dense(hidden_size, use_bias=False, 45 | name="output_transform") 46 | 47 | def split_heads(self, x): 48 | """Split x into different heads, and transpose the resulting value. 49 | 50 | The tensor is transposed to insure the inner dimensions hold the correct 51 | values during the matrix multiplication. 52 | 53 | Args: 54 | x: A tensor with shape [batch_size, length, hidden_size] 55 | 56 | Returns: 57 | A tensor with shape [batch_size, num_heads, length, hidden_size/num_heads] 58 | """ 59 | with tf.name_scope("split_heads"): 60 | batch_size = tf.shape(x)[0] 61 | length = tf.shape(x)[1] 62 | 63 | # Calculate depth of last dimension after it has been split. 64 | depth = (self.hidden_size // self.num_heads) 65 | 66 | # Split the last dimension 67 | x = tf.reshape(x, [batch_size, length, self.num_heads, depth]) 68 | 69 | # Transpose the result 70 | return tf.transpose(x, [0, 2, 1, 3]) 71 | 72 | def combine_heads(self, x): 73 | """Combine tensor that has been split. 74 | 75 | Args: 76 | x: A tensor [batch_size, num_heads, length, hidden_size/num_heads] 77 | 78 | Returns: 79 | A tensor with shape [batch_size, length, hidden_size] 80 | """ 81 | with tf.name_scope("combine_heads"): 82 | batch_size = tf.shape(x)[0] 83 | length = tf.shape(x)[2] 84 | x = tf.transpose(x, [0, 2, 1, 3]) # --> [batch, length, num_heads, depth] 85 | return tf.reshape(x, [batch_size, length, self.hidden_size]) 86 | 87 | def call(self, x, y, bias, cache=None): 88 | """Apply attention mechanism to x and y. 89 | 90 | Args: 91 | x: a tensor with shape [batch_size, length_x, hidden_size] 92 | y: a tensor with shape [batch_size, length_y, hidden_size] 93 | bias: attention bias that will be added to the result of the dot product. 94 | cache: (Used during prediction) dictionary with tensors containing results 95 | of previous attentions. The dictionary must have the items: 96 | {"k": tensor with shape [batch_size, i, key_channels], 97 | "v": tensor with shape [batch_size, i, value_channels]} 98 | where i is the current decoded length. 99 | 100 | Returns: 101 | Attention layer output with shape [batch_size, length_x, hidden_size] 102 | """ 103 | # Linearly project the query (q), key (k) and value (v) using different 104 | # learned projections. This is in preparation of splitting them into 105 | # multiple heads. Multi-head attention uses multiple queries, keys, and 106 | # values rather than regular attention (which uses a single q, k, v). 107 | q = self.q_dense_layer(x) 108 | k = self.k_dense_layer(y) 109 | v = self.v_dense_layer(y) 110 | 111 | if cache is not None: 112 | # Combine cached keys and values with new keys and values. 113 | k = tf.concat([cache["k"], k], axis=1) 114 | v = tf.concat([cache["v"], v], axis=1) 115 | 116 | # Update cache 117 | cache["k"] = k 118 | cache["v"] = v 119 | 120 | # Split q, k, v into heads. 121 | q = self.split_heads(q) 122 | k = self.split_heads(k) 123 | v = self.split_heads(v) 124 | 125 | # Scale q to prevent the dot product between q and k from growing too large. 
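# Illustrative note (not part of the original code): this is the scaled
# dot-product attention of "Attention is All You Need",
#   Attention(q, k, v) = softmax(q k^T / sqrt(depth) + bias) v,
# with the 1/sqrt(depth) factor folded into q before the matmul below; e.g.
# for hidden_size=512 and num_heads=8, q is scaled by 64 ** -0.5 = 0.125.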
126 | depth = (self.hidden_size // self.num_heads) 127 | q *= depth ** -0.5 128 | 129 | # Calculate dot product attention 130 | logits = tf.matmul(q, k, transpose_b=True) 131 | logits += bias 132 | weights = tf.nn.softmax(logits, name="attention_weights") 133 | if self.train: 134 | weights = tf.nn.dropout(weights, 1.0 - self.attention_dropout) 135 | attention_output = tf.matmul(weights, v) 136 | 137 | # Recombine heads --> [batch_size, length, hidden_size] 138 | attention_output = self.combine_heads(attention_output) 139 | 140 | # Run the combined outputs through another linear projection layer. 141 | attention_output = self.output_dense_layer(attention_output) 142 | return attention_output 143 | 144 | 145 | class SelfAttention(Attention): 146 | """Multiheaded self-attention layer.""" 147 | 148 | def call(self, x, bias, cache=None): 149 | return super(SelfAttention, self).call(x, x, bias, cache) 150 | -------------------------------------------------------------------------------- /official_transformer/embedding_layer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Implementation of embedding layer with shared weights.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf # pylint: disable=g-bad-import-order 23 | 24 | from official_transformer import tpu as tpu_utils 25 | 26 | 27 | class EmbeddingSharedWeights(tf.layers.Layer): 28 | """Calculates input embeddings and pre-softmax linear with shared weights.""" 29 | 30 | def __init__(self, vocab_size, hidden_size, method="gather"): 31 | """Specify characteristic parameters of embedding layer. 32 | 33 | Args: 34 | vocab_size: Number of tokens in the embedding. (Typically ~32,000) 35 | hidden_size: Dimensionality of the embedding. (Typically 512 or 1024) 36 | method: Strategy for performing embedding lookup. "gather" uses tf.gather 37 | which performs well on CPUs and GPUs, but very poorly on TPUs. "matmul" 38 | one-hot encodes the indicies and formulates the embedding as a sparse 39 | matrix multiplication. The matmul formulation is wasteful as it does 40 | extra work, however matrix multiplication is very fast on TPUs which 41 | makes "matmul" considerably faster than "gather" on TPUs. 42 | """ 43 | super(EmbeddingSharedWeights, self).__init__() 44 | self.vocab_size = vocab_size 45 | self.hidden_size = hidden_size 46 | if method not in ("gather", "matmul"): 47 | raise ValueError("method {} must be 'gather' or 'matmul'".format(method)) 48 | self.method = method 49 | 50 | def build(self, _): 51 | with tf.variable_scope("embedding_and_softmax", reuse=tf.AUTO_REUSE): 52 | # Create and initialize weights. 
The random normal initializer was chosen 53 | # randomly, and works well. 54 | self.shared_weights = tf.get_variable( 55 | "weights", [self.vocab_size, self.hidden_size], 56 | initializer=tf.random_normal_initializer( 57 | 0., self.hidden_size ** -0.5)) 58 | 59 | self.built = True 60 | 61 | def call(self, x): 62 | """Get token embeddings of x. 63 | 64 | Args: 65 | x: An int64 tensor with shape [batch_size, length] 66 | Returns: 67 | embeddings: float32 tensor with shape [batch_size, length, embedding_size] 68 | padding: float32 tensor with shape [batch_size, length] indicating the 69 | locations of the padding tokens in x. 70 | """ 71 | with tf.name_scope("embedding"): 72 | # Create binary mask of size [batch_size, length] 73 | mask = tf.to_float(tf.not_equal(x, 0)) 74 | 75 | if self.method == "gather": 76 | embeddings = tf.gather(self.shared_weights, x) 77 | embeddings *= tf.expand_dims(mask, -1) 78 | else: # matmul 79 | embeddings = tpu_utils.embedding_matmul( 80 | embedding_table=self.shared_weights, 81 | values=tf.cast(x, dtype=tf.int32), 82 | mask=mask 83 | ) 84 | # embedding_matmul already zeros out masked positions, so 85 | # `embeddings *= tf.expand_dims(mask, -1)` is unnecessary. 86 | 87 | 88 | # Scale embedding by the sqrt of the hidden size 89 | embeddings *= self.hidden_size ** 0.5 90 | 91 | return embeddings 92 | 93 | 94 | def linear(self, x): 95 | """Computes logits by running x through a linear layer. 96 | 97 | Args: 98 | x: A float32 tensor with shape [batch_size, length, hidden_size] 99 | Returns: 100 | float32 tensor with shape [batch_size, length, vocab_size]. 101 | """ 102 | with tf.name_scope("presoftmax_linear"): 103 | batch_size = tf.shape(x)[0] 104 | length = tf.shape(x)[1] 105 | 106 | x = tf.reshape(x, [-1, self.hidden_size]) 107 | logits = tf.matmul(x, self.shared_weights, transpose_b=True) 108 | 109 | return tf.reshape(logits, [batch_size, length, self.vocab_size]) 110 | -------------------------------------------------------------------------------- /official_transformer/ffn_layer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | """Implementation of fully connected network.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tensorflow as tf 23 | 24 | 25 | class FeedFowardNetwork(tf.layers.Layer): 26 | """Fully connected feedforward network.""" 27 | 28 | def __init__(self, hidden_size, filter_size, relu_dropout, train, allow_pad): 29 | super(FeedFowardNetwork, self).__init__() 30 | self.hidden_size = hidden_size 31 | self.filter_size = filter_size 32 | self.relu_dropout = relu_dropout 33 | self.train = train 34 | self.allow_pad = allow_pad 35 | 36 | self.filter_dense_layer = tf.layers.Dense( 37 | filter_size, use_bias=True, activation=tf.nn.relu, name="filter_layer") 38 | self.output_dense_layer = tf.layers.Dense( 39 | hidden_size, use_bias=True, name="output_layer") 40 | 41 | def call(self, x, padding=None): 42 | """Return outputs of the feedforward network. 43 | 44 | Args: 45 | x: tensor with shape [batch_size, length, hidden_size] 46 | padding: (optional) If set, the padding values are temporarily removed 47 | from x (provided self.allow_pad is set). The padding values are placed 48 | back in the output tensor in the same locations. 49 | shape [batch_size, length] 50 | 51 | Returns: 52 | Output of the feedforward network. 53 | tensor with shape [batch_size, length, hidden_size] 54 | """ 55 | padding = None if not self.allow_pad else padding 56 | 57 | # Retrieve dynamically known shapes 58 | batch_size = tf.shape(x)[0] 59 | length = tf.shape(x)[1] 60 | 61 | if padding is not None: 62 | with tf.name_scope("remove_padding"): 63 | # Flatten padding to [batch_size*length] 64 | pad_mask = tf.reshape(padding, [-1]) 65 | 66 | nonpad_ids = tf.to_int32(tf.where(pad_mask < 1e-9)) 67 | 68 | # Reshape x to [batch_size*length, hidden_size] to remove padding 69 | x = tf.reshape(x, [-1, self.hidden_size]) 70 | x = tf.gather_nd(x, indices=nonpad_ids) 71 | 72 | # Reshape x from 2 dimensions to 3 dimensions. 73 | x.set_shape([None, self.hidden_size]) 74 | x = tf.expand_dims(x, axis=0) 75 | 76 | output = self.filter_dense_layer(x) 77 | if self.train: 78 | output = tf.nn.dropout(output, 1.0 - self.relu_dropout) 79 | output = self.output_dense_layer(output) 80 | 81 | if padding is not None: 82 | with tf.name_scope("re_add_padding"): 83 | output = tf.squeeze(output, axis=0) 84 | output = tf.scatter_nd( 85 | indices=nonpad_ids, 86 | updates=output, 87 | shape=[batch_size * length, self.hidden_size] 88 | ) 89 | output = tf.reshape(output, [batch_size, length, self.hidden_size]) 90 | return output 91 | -------------------------------------------------------------------------------- /official_transformer/model_params.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Defines Transformer model parameters.""" 17 | 18 | from collections import defaultdict 19 | 20 | 21 | BASE_PARAMS = defaultdict( 22 | lambda: None, # Set default value to None. 23 | 24 | # Input params 25 | default_batch_size=2048, # Maximum number of tokens per batch of examples. 26 | default_batch_size_tpu=32768, 27 | max_length=256, # Maximum number of tokens per example. 28 | 29 | # Model params 30 | initializer_gain=1.0, # Used in trainable variable initialization. 31 | vocab_size=33708, # Number of tokens defined in the vocabulary file. 32 | hidden_size=512, # Model dimension in the hidden layers. 33 | num_hidden_layers=6, # Number of layers in the encoder and decoder stacks. 34 | num_heads=8, # Number of heads to use in multi-headed attention. 35 | filter_size=2048, # Inner layer dimension in the feedforward network. 36 | 37 | # Dropout values (only used when training) 38 | layer_postprocess_dropout=0.1, 39 | attention_dropout=0.1, 40 | relu_dropout=0.1, 41 | 42 | # Training params 43 | label_smoothing=0.1, 44 | learning_rate=2.0, 45 | learning_rate_decay_rate=1.0, 46 | learning_rate_warmup_steps=16000, 47 | 48 | # Optimizer params 49 | optimizer_adam_beta1=0.9, 50 | optimizer_adam_beta2=0.997, 51 | optimizer_adam_epsilon=1e-09, 52 | 53 | # Default prediction params 54 | extra_decode_length=50, 55 | beam_size=4, 56 | alpha=0.6, # used to calculate length normalization in beam search 57 | 58 | # TPU specific parameters 59 | use_tpu=False, 60 | static_batch=False, 61 | allow_ffn_pad=True, 62 | ) 63 | 64 | BIG_PARAMS = BASE_PARAMS.copy() 65 | BIG_PARAMS.update( 66 | default_batch_size=4096, 67 | 68 | # default batch size is smaller than for BASE_PARAMS due to memory limits. 69 | default_batch_size_tpu=16384, 70 | 71 | hidden_size=1024, 72 | filter_size=4096, 73 | num_heads=16, 74 | ) 75 | 76 | # Parameters for running the model in multi gpu. These should not change the 77 | # params that modify the model shape (such as the hidden_size or num_heads). 78 | BASE_MULTI_GPU_PARAMS = BASE_PARAMS.copy() 79 | BASE_MULTI_GPU_PARAMS.update( 80 | learning_rate_warmup_steps=8000 81 | ) 82 | 83 | BIG_MULTI_GPU_PARAMS = BIG_PARAMS.copy() 84 | BIG_MULTI_GPU_PARAMS.update( 85 | layer_postprocess_dropout=0.3, 86 | learning_rate_warmup_steps=8000 87 | ) 88 | 89 | # Parameters for testing the model 90 | TINY_PARAMS = BASE_PARAMS.copy() 91 | TINY_PARAMS.update( 92 | default_batch_size=1024, 93 | default_batch_size_tpu=1024, 94 | hidden_size=32, 95 | num_heads=4, 96 | filter_size=256, 97 | ) 98 | -------------------------------------------------------------------------------- /official_transformer/model_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Transformer model helper methods.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import math 23 | 24 | import numpy as np 25 | import tensorflow as tf 26 | 27 | # Very low numbers to represent -infinity. We do not actually use -Inf, since we 28 | # want to be able to multiply these values by zero to get zero. (-Inf * 0 = NaN) 29 | _NEG_INF_FP32 = -1e9 30 | _NEG_INF_FP16 = np.finfo(np.float16).min 31 | 32 | 33 | def get_position_encoding( 34 | length, hidden_size, min_timescale=1.0, max_timescale=1.0e2): # TODO previous max_timescale=1.0e4 35 | """Return positional encoding. 36 | 37 | Calculates the position encoding as a mix of sine and cosine functions with 38 | geometrically increasing wavelengths. 39 | Defined and formulized in Attention is All You Need, section 3.5. 40 | 41 | Args: 42 | length: Sequence length. 43 | hidden_size: Size of the hidden dimension of the encoding. 44 | min_timescale: Minimum scale that will be applied at each position 45 | max_timescale: Maximum scale that will be applied at each position 46 | 47 | Returns: 48 | Tensor with shape [length, hidden_size] 49 | """ 50 | # We compute the positional encoding in float32 even if the model uses 51 | # float16, as many of the ops used, like log and exp, are numerically unstable 52 | # in float16. 53 | position = tf.cast(tf.range(length), tf.float32) 54 | num_timescales = hidden_size // 2 55 | log_timescale_increment = ( 56 | math.log(float(max_timescale) / float(min_timescale)) / 57 | (tf.cast(num_timescales, tf.float32) - 1)) 58 | inv_timescales = min_timescale * tf.exp( 59 | tf.cast(tf.range(num_timescales), tf.float32) * -log_timescale_increment) 60 | scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0) 61 | signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1) 62 | return signal 63 | 64 | 65 | def get_decoder_self_attention_bias(length, dtype=tf.float32): 66 | """Calculate bias for decoder that maintains model's autoregressive property. 67 | 68 | Creates a tensor that masks out locations that correspond to illegal 69 | connections, so prediction at position i cannot draw information from future 70 | positions. 71 | 72 | Args: 73 | length: int length of sequences in batch. 74 | dtype: The dtype of the return value. 75 | 76 | Returns: 77 | float tensor of shape [1, 1, length, length] 78 | """ 79 | neg_inf = _NEG_INF_FP16 if dtype == tf.float16 else _NEG_INF_FP32 80 | with tf.name_scope("decoder_self_attention_bias"): 81 | valid_locs = tf.linalg.band_part(input=tf.ones([length, length], dtype=dtype), # the two lengths are the maximum input and output lengths 82 | num_lower=-1, num_upper=0) # Lower triangular part is 1, other is 0 83 | valid_locs = tf.reshape(valid_locs, [1, 1, length, length]) 84 | decoder_bias = neg_inf * (1.0 - valid_locs) 85 | return decoder_bias 86 | 87 | 88 | def get_padding(x, padding_value=0, dtype=tf.float32): 89 | """Return float tensor representing the padding values in x. 90 | 91 | Args: 92 | x: int tensor with any shape 93 | padding_value: int value that marks the padding positions in x 94 | dtype: The dtype of the return value. 95 | 96 | Returns: 97 | float tensor with same shape as x containing values 0 or 1. 
98 | 0 -> non-padding, 1 -> padding 99 | """ 100 | with tf.name_scope("padding"): 101 | return tf.cast(tf.equal(x, padding_value), dtype) 102 | 103 | 104 | def get_padding_bias(x): 105 | """Calculate bias tensor from padding values in tensor. 106 | 107 | Bias tensor that is added to the pre-softmax multi-headed attention logits, 108 | which has shape [batch_size, num_heads, length, length]. The tensor is zero at 109 | non-padding locations, and -1e9 (negative infinity) at padding locations. 110 | 111 | Args: 112 | x: int tensor with shape [batch_size, length] 113 | 114 | Returns: 115 | Attention bias tensor of shape [batch_size, 1, 1, length]. 116 | """ 117 | with tf.name_scope("attention_bias"): 118 | padding = get_padding(x) 119 | attention_bias = padding * _NEG_INF_FP32 120 | attention_bias = tf.expand_dims( 121 | tf.expand_dims(attention_bias, axis=1), axis=1) 122 | return attention_bias 123 | -------------------------------------------------------------------------------- /official_transformer/tpu.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Functions specific to running TensorFlow on TPUs.""" 17 | 18 | import tensorflow as tf 19 | 20 | 21 | # "local" is a magic word in the TPU cluster resolver; it informs the resolver 22 | # to use the local CPU as the compute device. This is useful for testing and 23 | # debugging; the code flow is ostensibly identical, but without the need to 24 | # actually have a TPU on the other end. 25 | LOCAL = "local" 26 | 27 | 28 | def construct_scalar_host_call(metric_dict, model_dir, prefix=""): 29 | """Construct a host call to log scalars when training on TPU. 30 | 31 | Args: 32 | metric_dict: A dict of the tensors to be logged. 33 | model_dir: The location to write the summary. 34 | prefix: The prefix (if any) to prepend to the metric names. 35 | 36 | Returns: 37 | A tuple of (function, args_to_be_passed_to_said_function) 38 | """ 39 | # type: (dict, str) -> (function, list) 40 | metric_names = list(metric_dict.keys()) 41 | 42 | def host_call_fn(global_step, *args): 43 | """Training host call. Creates scalar summaries for training metrics. 44 | 45 | This function is executed on the CPU and should not directly reference 46 | any Tensors in the rest of the `model_fn`. To pass Tensors from the 47 | model to the `metric_fn`, provide as part of the `host_call`. See 48 | https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec 49 | for more information. 50 | 51 | Arguments should match the list of `Tensor` objects passed as the second 52 | element in the tuple passed to `host_call`. 53 | 54 | Args: 55 | global_step: `Tensor with shape `[batch]` for the global_step 56 | *args: Remaining tensors to log. 
57 | 58 | Returns: 59 | List of summary ops to run on the CPU host. 60 | """ 61 | step = global_step[0] 62 | with tf.contrib.summary.create_file_writer( 63 | logdir=model_dir, filename_suffix=".host_call").as_default(): 64 | with tf.contrib.summary.always_record_summaries(): 65 | for i, name in enumerate(metric_names): 66 | tf.contrib.summary.scalar(prefix + name, args[i][0], step=step) 67 | 68 | return tf.contrib.summary.all_summary_ops() 69 | 70 | # To log the current learning rate, and gradient norm for Tensorboard, the 71 | # summary op needs to be run on the host CPU via host_call. host_call 72 | # expects [batch_size, ...] Tensors, thus reshape to introduce a batch 73 | # dimension. These Tensors are implicitly concatenated to 74 | # [params['batch_size']]. 75 | global_step_tensor = tf.reshape( 76 | tf.compat.v1.train.get_or_create_global_step(), [1]) 77 | other_tensors = [tf.reshape(metric_dict[key], [1]) for key in metric_names] 78 | 79 | return host_call_fn, [global_step_tensor] + other_tensors 80 | 81 | 82 | def embedding_matmul(embedding_table, values, mask, name="embedding_matmul"): 83 | """Performs embedding lookup via a matmul. 84 | 85 | The matrix to be multiplied by the embedding table Tensor is constructed 86 | via an implementation of scatter based on broadcasting embedding indices 87 | and performing an equality comparison against a broadcasted 88 | range(num_embedding_table_rows). All masked positions will produce an 89 | embedding vector of zeros. 90 | 91 | Args: 92 | embedding_table: Tensor of embedding table. 93 | Rank 2 (table_size x embedding dim) 94 | values: Tensor of embedding indices. Rank 2 (batch x n_indices) 95 | mask: Tensor of mask / weights. Rank 2 (batch x n_indices) 96 | name: Optional name scope for created ops 97 | 98 | Returns: 99 | Rank 3 tensor of embedding vectors. 100 | """ 101 | 102 | with tf.name_scope(name): 103 | n_embeddings = embedding_table.get_shape().as_list()[0] 104 | batch_size, padded_size = values.shape.as_list() 105 | 106 | emb_idcs = tf.tile( 107 | tf.reshape(values, (batch_size, padded_size, 1)), (1, 1, n_embeddings)) 108 | emb_weights = tf.tile( 109 | tf.reshape(mask, (batch_size, padded_size, 1)), (1, 1, n_embeddings)) 110 | col_idcs = tf.tile( 111 | tf.reshape(tf.range(n_embeddings), (1, 1, n_embeddings)), 112 | (batch_size, padded_size, 1)) 113 | one_hot = tf.where( 114 | tf.equal(emb_idcs, col_idcs), emb_weights, 115 | tf.zeros((batch_size, padded_size, n_embeddings))) 116 | 117 | return tf.tensordot(one_hot, embedding_table, 1) 118 | -------------------------------------------------------------------------------- /phrase_vocabulary_optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Optimizes the vocabulary of phrases that can be added by LaserTagger. 
18 | 19 | The goal is to find a fixed-size set of phrases that cover as many training 20 | examples as possible. Based on the phrases, saves a file containing all possible 21 | tags to be predicted and another file reporting the percentage of covered 22 | training examples with different vocabulary sizes. 23 | """ 24 | 25 | from __future__ import absolute_import 26 | from __future__ import division 27 | 28 | from __future__ import print_function 29 | 30 | import collections 31 | from typing import Sequence, Text 32 | 33 | from absl import app 34 | from absl import flags 35 | from absl import logging 36 | 37 | import utils 38 | 39 | import numpy as np 40 | import scipy.sparse 41 | import tensorflow as tf 42 | from compute_lcs import _compute_lcs 43 | from utils import my_tokenizer_class 44 | from curLine_file import curLine 45 | 46 | FLAGS = flags.FLAGS 47 | 48 | flags.DEFINE_string( 49 | 'input_file', None, 50 | 'Path to the input file containing source-target pairs from which the ' 51 | 'vocabulary is optimized (see `input_format` flag and utils.py for ' 52 | 'documentation).') 53 | flags.DEFINE_enum( 54 | 'input_format', None, ['wikisplit', 'discofuse'], 55 | 'Format which indicates how to parse the `input_file`. See utils.py for ' 56 | 'documentation on the different formats.') 57 | flags.DEFINE_integer( 58 | 'max_input_examples', 50000, 59 | 'At most this many examples from the `input_file` are used for optimizing ' 60 | 'the vocabulary.') 61 | flags.DEFINE_string( 62 | 'output_file', None, 63 | 'Path to the resulting file with all possible tags. Coverage numbers will ' 64 | 'be written to a separate file which has the same path but ".log" appended ' 65 | 'to it.') 66 | flags.DEFINE_bool('enable_swap_tag', True, 'Whether to enable the SWAP tag.') 67 | flags.DEFINE_integer('vocabulary_size', 500, 68 | 'Number of phrases to include in the vocabulary.') 69 | flags.DEFINE_integer( 70 | 'num_extra_statistics', 100, 71 | 'Number of extra phrases that are not included in the vocabulary but for ' 72 | 'which we compute the coverage numbers. These numbers help determining ' 73 | 'whether the vocabulary size should have been larger.') 74 | 75 | flags.DEFINE_string('vocab_file', None, 'Path to the BERT vocabulary file.') 76 | flags.DEFINE_bool( 77 | 'do_lower_case', True, 78 | 'Whether to lower case the input text. Should be True for uncased ' 79 | 'models and False for cased models.') 80 | 81 | def _get_added_phrases(source: Text, target: Text, tokenizer) -> Sequence[Text]: 82 | """Computes the phrases that need to be added to the source to get the target. 83 | 84 | This is done by aligning each token in the LCS to the first match in the 85 | target and checking which phrases in the target remain unaligned. 86 | 87 | TODO(b/142853960): The LCS tokens should ideally be aligned to consecutive 88 | target tokens whenever possible, instead of aligning them always to the first 89 | match. This should result in a more meaningful phrase vocabulary with a higher 90 | coverage. 91 | 92 | Note that the algorithm is case-insensitive and the resulting phrases are 93 | always lowercase. 94 | 95 | Args: 96 | source: Source text. 97 | target: Target text. 98 | 99 | Returns: 100 | List of added phrases. 
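Example (illustrative, assuming the character-level tokenization used
for Chinese): source "你好世界" and target "你好呀世界" have the LCS
"你好世界", so the only unaligned target token is "呀" and the function
returns ["呀"].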
101 | """ 102 | # English is split into words with sep=' '; Chinese is split into characters with sep='' 103 | sep = '' 104 | source_tokens = tokenizer.tokenize(source) 105 | target_tokens = tokenizer.tokenize(target) 106 | 107 | kept_tokens = _compute_lcs(source_tokens, target_tokens) 108 | added_phrases = [] 109 | # Index of the `kept_tokens` element that we are currently looking for. 110 | kept_idx = 0 111 | phrase = [] 112 | for token in target_tokens: 113 | if kept_idx < len(kept_tokens) and token == kept_tokens[kept_idx]: 114 | kept_idx += 1 115 | if phrase: 116 | added_phrases.append(sep.join(phrase)) 117 | phrase = [] 118 | else: 119 | phrase.append(token) 120 | if phrase: 121 | added_phrases.append(sep.join(phrase)) 122 | return added_phrases 123 | 124 | 125 | def _added_token_counts(data_iterator, try_swapping, max_input_examples=10000, tokenizer=None): 126 | """Computes how many times different phrases have to be added. 127 | 128 | Args: 129 | data_iterator: Iterator to yield source lists and targets. See function 130 | yield_sources_and_targets in utils.py for the available iterators. The 131 | strings in the source list will be concatenated, possibly after swapping 132 | their order if swapping is enabled. 133 | try_swapping: Whether to try if swapping sources results in less added text. 134 | max_input_examples: Maximum number of examples to be read from the iterator. 135 | 136 | Returns: 137 | Tuple (collections.Counter for phrases, added phrases for each example). 138 | """ 139 | phrase_counter = collections.Counter() 140 | num_examples = 0 141 | all_added_phrases = [] 142 | max_seq_length = 0 143 | for sources, target in data_iterator: 144 | # sources may contain several sentences, which are joined with spaces below 145 | if num_examples >= max_input_examples: 146 | break 147 | source_merge = ' '.join(sources) 148 | if len(source_merge) > max_seq_length: 149 | print(curLine(), "max_seq_length=%d, len(source_merge)=%d,source_merge:%s" % 150 | (max_seq_length, len(source_merge), source_merge)) 151 | max_seq_length = len(source_merge) 152 | logging.log_every_n(logging.INFO, f'{num_examples} examples processed.', 10000) 153 | added_phrases = _get_added_phrases(source_merge, target, tokenizer) 154 | if try_swapping and len(sources) == 2: 155 | added_phrases_swap = _get_added_phrases(' '.join(sources[::-1]), target, tokenizer) 156 | # If we can align more and have to add less after swapping, we assume that 157 | # the sources would be swapped during conversion. 158 | if len(''.join(added_phrases_swap)) < len(''.join(added_phrases)): 159 | added_phrases = added_phrases_swap 160 | for phrase in added_phrases: 161 | phrase_counter[phrase] += 1 162 | all_added_phrases.append(added_phrases) 163 | num_examples += 1 164 | logging.info(f'{num_examples} examples processed.\n') 165 | return phrase_counter, all_added_phrases, max_seq_length 166 | 167 | 168 | def _construct_added_phrases_matrix(all_added_phrases, phrase_counter): 169 | """Constructs a sparse phrase occurrence matrix. 170 | 171 | Examples are on rows and phrases on columns. 172 | 173 | Args: 174 | all_added_phrases: List of lists of added phrases (one list per example). 175 | phrase_counter: Frequency of each unique added phrase. 176 | 177 | Returns: 178 | Sparse boolean matrix whose element (i, j) indicates whether example i 179 | contains the added phrase j. Columns start from the most frequent phrase. 
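Example (illustrative, not from the original docstring): for
all_added_phrases = [["呀"], ["呀", "吗"]], the phrase "呀" occurs twice
and "吗" once, so column 0 corresponds to "呀" and the resulting matrix
is [[True, False], [True, True]].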
180 | """ 181 | phrase_2_idx = { 182 | tup[0]: i for i, tup in enumerate(phrase_counter.most_common()) 183 | } 184 | matrix = scipy.sparse.dok_matrix((len(all_added_phrases), len(phrase_2_idx)), 185 | dtype=np.bool) 186 | for i, added_phrases in enumerate(all_added_phrases): 187 | for phrase in added_phrases: 188 | phrase_idx = phrase_2_idx[phrase] 189 | matrix[i, phrase_idx] = True 190 | # Convert to CSC format to support more efficient column slicing. 191 | return matrix.tocsc() 192 | 193 | 194 | def _count_covered_examples(matrix, vocabulary_size): 195 | """Returns the number of examples whose added phrases are in the vocabulary. 196 | 197 | This assumes the vocabulary is created simply by selecting the 198 | `vocabulary_size` most frequent phrases. 199 | 200 | Args: 201 | matrix: Phrase occurrence matrix with the most frequent phrases on the 202 | left-most columns. 203 | vocabulary_size: Number of most frequent phrases to include in the 204 | vocabulary. 205 | """ 206 | # Ignore the `vocabulary_size` most frequent (i.e. leftmost) phrases (i.e. 207 | # columns) and count the rows with zero added phrases. 208 | return (matrix[:, vocabulary_size:].sum(axis=1) == 0).sum() 209 | 210 | 211 | def main(argv): 212 | if len(argv) > 1: 213 | raise app.UsageError('Too many command-line arguments.') 214 | flags.mark_flag_as_required('input_file') 215 | flags.mark_flag_as_required('input_format') 216 | flags.mark_flag_as_required('output_file') 217 | 218 | data_iterator = utils.yield_sources_and_targets(FLAGS.input_file, 219 | FLAGS.input_format) 220 | tokenizer = my_tokenizer_class(FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 221 | phrase_counter, all_added_phrases, max_seq_length = _added_token_counts( 222 | data_iterator, FLAGS.enable_swap_tag, FLAGS.max_input_examples, tokenizer) 223 | matrix = _construct_added_phrases_matrix(all_added_phrases, phrase_counter) 224 | num_examples = len(all_added_phrases) 225 | 226 | statistics_file = FLAGS.output_file + '.log' 227 | with tf.io.gfile.GFile(FLAGS.output_file, 'w') as writer: 228 | with tf.io.gfile.GFile(statistics_file, 'w') as stats_writer: 229 | stats_writer.write('Idx\tFrequency\tCoverage (%)\tPhrase\n') 230 | writer.write('KEEP\n') 231 | writer.write('DELETE\n') 232 | if FLAGS.enable_swap_tag: 233 | writer.write('SWAP\n') 234 | for i, (phrase, count) in enumerate( 235 | phrase_counter.most_common(FLAGS.vocabulary_size + 236 | FLAGS.num_extra_statistics)): 237 | # Write tags. 238 | if i < FLAGS.vocabulary_size: # TODO why is every phrase written with both KEEP and DELETE?? 239 | writer.write(f'KEEP|{phrase}\n') 240 | writer.write(f'DELETE|{phrase}\n') 241 | # Write statistics. 242 | coverage = 100.0 * _count_covered_examples(matrix, i + 1) / num_examples # fraction of the corpus covered by the i+1 most frequent phrases 243 | stats_writer.write(f'{i + 1}\t{count}\t{coverage:.2f}\t{phrase}\n') 244 | logging.info(f'Wrote tags to: {FLAGS.output_file}') 245 | logging.info(f'Wrote coverage numbers to: {statistics_file}') 246 | print(curLine(), "max_seq_length=", max_seq_length) 247 | 248 | 249 | if __name__ == '__main__': 250 | app.run(main) 251 | -------------------------------------------------------------------------------- /predict_main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Compute realized predictions for a dataset.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | from absl import app 24 | from absl import flags 25 | from absl import logging 26 | import math, time 27 | from termcolor import colored 28 | import tensorflow as tf 29 | 30 | import bert_example 31 | import predict_utils 32 | import tagging_converter 33 | import utils 34 | from utils import my_tokenizer_class 35 | from curLine_file import curLine 36 | 37 | FLAGS = flags.FLAGS 38 | 39 | flags.DEFINE_string( 40 | 'input_file', None, 41 | 'Path to the input file containing examples for which to compute ' 42 | 'predictions.') 43 | flags.DEFINE_enum( 44 | 'input_format', None, ['wikisplit', 'discofuse'], 45 | 'Format which indicates how to parse the input_file.') 46 | flags.DEFINE_string( 47 | 'output_file', None, 48 | 'Path to the TSV file where the predictions are written to.') 49 | flags.DEFINE_string( 50 | 'label_map_file', None, 51 | 'Path to the label map file. Either a JSON file ending with ".json", that ' 52 | 'maps each possible tag to an ID, or a text file that has one tag per ' 53 | 'line.') 54 | flags.DEFINE_string('vocab_file', None, 'Path to the BERT vocabulary file.') 55 | flags.DEFINE_integer('max_seq_length', 128, 'Maximum sequence length.') 56 | flags.DEFINE_bool( 57 | 'do_lower_case', True, 58 | 'Whether to lower case the input text. 
Should be True for uncased ' 59 | 'models and False for cased models.') 60 | flags.DEFINE_bool('enable_swap_tag', True, 'Whether to enable the SWAP tag.') 61 | flags.DEFINE_string('saved_model', None, 'Path to an exported TF model.') 62 | 63 | 64 | def main(argv): 65 | if len(argv) > 1: 66 | raise app.UsageError('Too many command-line arguments.') 67 | flags.mark_flag_as_required('input_file') 68 | flags.mark_flag_as_required('input_format') 69 | flags.mark_flag_as_required('output_file') 70 | flags.mark_flag_as_required('label_map_file') 71 | flags.mark_flag_as_required('vocab_file') 72 | flags.mark_flag_as_required('saved_model') 73 | 74 | label_map = utils.read_label_map(FLAGS.label_map_file) 75 | tokenizer = my_tokenizer_class(FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 76 | converter = tagging_converter.TaggingConverter( 77 | tagging_converter.get_phrase_vocabulary_from_label_map(label_map), 78 | FLAGS.enable_swap_tag, tokenizer) 79 | builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file, 80 | FLAGS.max_seq_length, 81 | FLAGS.do_lower_case, converter) 82 | predictor = predict_utils.LaserTaggerPredictor( 83 | tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder, 84 | label_map) 85 | print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red")) 86 | sources_list = [] 87 | target_list = [] 88 | with tf.io.gfile.GFile(FLAGS.input_file) as f: 89 | for line in f: 90 | sources, target, lcs_rate = line.rstrip('\n').split('\t') 91 | sources_list.append([sources]) 92 | target_list.append(target) 93 | number = len(sources_list) # total number of examples 94 | predict_batch_size = min(96, number) 95 | batch_num = math.ceil(float(number) / predict_batch_size) 96 | 97 | start_time = time.time() 98 | num_predicted = 0 99 | with tf.gfile.Open(FLAGS.output_file, 'w') as writer: 100 | writer.write('source\tprediction\ttarget\n') 101 | for batch_id in range(batch_num): 102 | sources_batch = sources_list[batch_id * predict_batch_size: (batch_id + 1) * predict_batch_size] 103 | prediction_batch = predictor.predict_batch(sources_batch=sources_batch) 104 | assert len(prediction_batch) == len(sources_batch) 105 | num_predicted += len(prediction_batch) 106 | for id, [prediction, sources] in enumerate(zip(prediction_batch, sources_batch)): 107 | target = target_list[batch_id * predict_batch_size + id] 108 | writer.write(f'{"".join(sources)}\t{prediction}\t{target}\n') 109 | if batch_id % 20 == 0: 110 | cost_time = (time.time() - start_time) / 60.0 111 | print("%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin." % 112 | (curLine(), batch_id + 1, batch_num, num_predicted, number, cost_time)) 113 | cost_time = (time.time() - start_time) / 60.0 114 | logging.info( 115 | f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {cost_time / num_predicted} min.') 116 | 117 | 118 | if __name__ == '__main__': 119 | app.run(main) 120 | -------------------------------------------------------------------------------- /predict_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Utility functions for running inference with a LaserTagger model.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | 22 | from __future__ import print_function 23 | from collections import defaultdict 24 | import tagging 25 | from curLine_file import curLine 26 | 27 | class LaserTaggerPredictor(object): 28 | """Class for computing and realizing predictions with LaserTagger.""" 29 | 30 | def __init__(self, tf_predictor, 31 | example_builder, 32 | label_map): 33 | """Initializes an instance of LaserTaggerPredictor. 34 | 35 | Args: 36 | tf_predictor: Loaded Tensorflow model. 37 | example_builder: BERT example builder. 38 | label_map: Mapping from tags to tag IDs. 39 | """ 40 | self._predictor = tf_predictor 41 | self._example_builder = example_builder 42 | self._id_2_tag = { 43 | tag_id: tagging.Tag(tag) for tag, tag_id in label_map.items() 44 | } 45 | 46 | def predict_batch(self, sources_batch, location_batch=None): # adapted from the original predict() to handle a whole batch 47 | """Returns realized predictions for the given sources.""" 48 | # Predict tag IDs. 49 | keys = ['input_ids', 'input_mask', 'segment_ids'] 50 | input_info = defaultdict(list) 51 | example_list = [] 52 | location = None 53 | for id, sources in enumerate(sources_batch): 54 | if location_batch is not None: 55 | location = location_batch[id] # marks whether each position may be modified 56 | example = self._example_builder.build_bert_example(sources, location=location) 57 | if example is None: 58 | raise ValueError("Example couldn't be built.") 59 | for key in keys: 60 | input_info[key].append(example.features[key]) 61 | example_list.append(example) 62 | 63 | out = self._predictor(input_info) 64 | prediction_list = [] 65 | for output, outputs_outputs, example in zip(out['pred'], out['outputs_outputs'], example_list): 66 | predicted_ids = output.tolist() 67 | # Realize output. 68 | example.features['labels'] = predicted_ids 69 | # Mask out the begin and the end token. 70 | example.features['labels_mask'] = [0] + [1] * (len(predicted_ids) - 2) + [0] 71 | labels = [ 72 | self._id_2_tag[label_id] for label_id in example.get_token_labels() 73 | ] 74 | prediction = example.editing_task.realize_output(labels) 75 | prediction_list.append(prediction) 76 | return prediction_list 77 | -------------------------------------------------------------------------------- /preprocess_main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Convert a dataset into the TFRecord format. 18 | 19 | The resulting TFRecord file will be used when training a LaserTagger model. 20 | """ 21 | 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | 25 | from __future__ import print_function 26 | 27 | from typing import Text 28 | 29 | from absl import app 30 | from absl import flags 31 | from absl import logging 32 | 33 | import bert_example 34 | import tagging_converter 35 | import utils 36 | 37 | import tensorflow as tf 38 | from curLine_file import curLine 39 | 40 | FLAGS = flags.FLAGS 41 | 42 | flags.DEFINE_string( 43 | 'input_file', None, 44 | 'Path to the input file containing examples to be converted to ' 45 | 'tf.Examples.') 46 | flags.DEFINE_enum( 47 | 'input_format', None, ['wikisplit', 'discofuse'], 48 | 'Format which indicates how to parse the input_file.') 49 | flags.DEFINE_string('output_tfrecord', None, 50 | 'Path to the resulting TFRecord file.') 51 | flags.DEFINE_string( 52 | 'label_map_file', None, 53 | 'Path to the label map file. Either a JSON file ending with ".json", that ' 54 | 'maps each possible tag to an ID, or a text file that has one tag per ' 55 | 'line.') 56 | flags.DEFINE_string('vocab_file', None, 'Path to the BERT vocabulary file.') 57 | flags.DEFINE_integer('max_seq_length', 128, 'Maximum sequence length.') 58 | flags.DEFINE_bool( 59 | 'do_lower_case', True, 60 | 'Whether to lower case the input text. Should be True for uncased ' 61 | 'models and False for cased models.') 62 | flags.DEFINE_bool('enable_swap_tag', True, 'Whether to enable the SWAP tag.') 63 | flags.DEFINE_bool( 64 | 'output_arbitrary_targets_for_infeasible_examples', False, 65 | 'Set this to True when preprocessing the development set. Determines ' 66 | 'whether to output a TF example also for sources that can not be converted ' 67 | 'to target via the available tagging operations. In these cases, the ' 68 | 'target ids will correspond to the tag sequence KEEP-DELETE-KEEP-DELETE... ' 69 | 'which should be very unlikely to be predicted by chance. This will be ' 70 | 'useful for getting more accurate eval scores during training.') 71 | 72 | 73 | def _write_example_count(count: int) -> Text: 74 | """Saves the number of converted examples to a file. 75 | 76 | This count is used when determining the number of training steps. 77 | 78 | Args: 79 | count: The number of converted examples. 80 | 81 | Returns: 82 | The filename to which the count is saved. 
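Example (illustrative, not from the original docstring): with
--output_tfrecord=train.tf_record, _write_example_count(12000) writes
"12000" to train.tf_record.num_examples.txt and returns that path.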
83 | """ 84 | count_fname = FLAGS.output_tfrecord + '.num_examples.txt' 85 | with tf.io.gfile.GFile(count_fname, 'w') as count_writer: 86 | count_writer.write(str(count)) 87 | return count_fname 88 | 89 | 90 | def main(argv): 91 | if len(argv) > 1: 92 | raise app.UsageError('Too many command-line arguments.') 93 | flags.mark_flag_as_required('input_file') 94 | flags.mark_flag_as_required('input_format') 95 | flags.mark_flag_as_required('output_tfrecord') 96 | flags.mark_flag_as_required('label_map_file') 97 | flags.mark_flag_as_required('vocab_file') 98 | 99 | label_map = utils.read_label_map(FLAGS.label_map_file) 100 | tokenizer = utils.my_tokenizer_class(FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) 101 | converter = tagging_converter.TaggingConverter( 102 | tagging_converter.get_phrase_vocabulary_from_label_map(label_map), # phrase_vocabulary set 103 | FLAGS.enable_swap_tag, tokenizer=tokenizer) 104 | # print(curLine(), len(label_map), "label_map:", label_map, converter._max_added_phrase_length) 105 | builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file, 106 | FLAGS.max_seq_length, 107 | FLAGS.do_lower_case, converter) 108 | 109 | num_converted = 0 110 | with tf.io.TFRecordWriter(FLAGS.output_tfrecord) as writer: 111 | for i, (sources, target) in enumerate(utils.yield_sources_and_targets( 112 | FLAGS.input_file, FLAGS.input_format)): 113 | logging.log_every_n( 114 | logging.INFO, 115 | f'{i} examples processed, {num_converted} converted to tf.Example.', 116 | 10000) 117 | example = builder.build_bert_example( 118 | sources, target, 119 | FLAGS.output_arbitrary_targets_for_infeasible_examples) 120 | if example is None: 121 | continue # depending on output_arbitrary_targets_for_infeasible_examples, infeasible examples are either skipped or given random targets; random ones are also counted in num_converted 122 | writer.write(example.to_tf_example().SerializeToString()) 123 | num_converted += 1 124 | logging.info(f'Done. 
{num_converted} examples converted to tf.Example.')
125 |   count_fname = _write_example_count(num_converted)
126 |   logging.info(f'Wrote:\n{FLAGS.output_tfrecord}\n{count_fname}')
127 | 
128 | 
129 | if __name__ == '__main__':
130 |   app.run(main)
131 | 
--------------------------------------------------------------------------------
/qa_rephrase/__init__.py:
--------------------------------------------------------------------------------
1 | import bert_example
2 | import predict_utils
3 | import tagging_converter
4 | import utils
--------------------------------------------------------------------------------
/qa_rephrase/predict_for_qa.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Generate generalized (paraphrased) corpora for domain recognition / QA matching
3 | 
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 | 
8 | from absl import app
9 | from absl import flags
10 | from absl import logging
11 | import os, sys, time
12 | from termcolor import colored
13 | import math
14 | import tensorflow as tf
15 | 
16 | block_list = os.path.realpath(__file__).split("/")
17 | path = "/".join(block_list[:-2])
18 | sys.path.append(path)
19 | from compute_lcs import _compute_lcs
20 | import bert_example
21 | import predict_utils
22 | import tagging_converter
23 | import utils
24 | 
25 | from curLine_file import curLine
26 | 
27 | FLAGS = flags.FLAGS
28 | 
29 | flags.DEFINE_string(
30 |     'input_file', None,
31 |     'Path to the input file containing examples for which to compute '
32 |     'predictions.')
33 | flags.DEFINE_enum(
34 |     'input_format', None, ['wikisplit', 'discofuse'],
35 |     'Format which indicates how to parse the input_file.')
36 | flags.DEFINE_string(
37 |     'output_file', None,
38 |     'Path to the TSV file where the predictions are written to.')
39 | flags.DEFINE_string(
40 |     'label_map_file', None,
41 |     'Path to the label map file. Either a JSON file ending with ".json", that '
42 |     'maps each possible tag to an ID, or a text file that has one tag per '
43 |     'line.')
44 | flags.DEFINE_string('vocab_file', None, 'Path to the BERT vocabulary file.')
45 | flags.DEFINE_integer('max_seq_length', 128, 'Maximum sequence length.')
46 | flags.DEFINE_bool(
47 |     'do_lower_case', False,
48 |     'Whether to lower case the input text. Should be True for uncased '
49 |     'models and False for cased models.')
50 | flags.DEFINE_bool('enable_swap_tag', True, 'Whether to enable the SWAP tag.')
51 | flags.DEFINE_string('saved_model', None, 'Path to an exported TF model.')
52 | 
53 | def predict_and_write(predictor, sources_batch, previous_line_list, context_list, writer, num_predicted, start_time, batch_num):
54 |     prediction_batch = predictor.predict_batch(sources_batch=sources_batch)
55 |     assert len(prediction_batch) == len(sources_batch)
56 |     for id, [prediction, sources] in enumerate(zip(prediction_batch, sources_batch)):
57 |         output = ""
58 |         if len(prediction) > 1 and prediction != sources:  # TODO: also ignore predictions that keep the source entirely unchanged, and overly short predictions
59 |             output = "%s%s" % (context_list[id], prediction)  # whether the context should be prepended is still an open question
60 |             # print(curLine(), "prediction:", prediction, "sources:", sources, ",output:", output, prediction != sources)
61 |         writer.write("%s\t%s\n" % (previous_line_list[id], output))
62 |     batch_num = batch_num + 1
63 |     num_predicted += len(prediction_batch)
64 |     if batch_num % 200 == 0:
65 |         cost_time = (time.time() - start_time) / 3600.0
66 |         print("%s batch_id=%d, predict %d examples, cost %.3fh." %
67 |               (curLine(), batch_num, num_predicted, cost_time))
68 |     return num_predicted, batch_num
69 | 
70 | 
71 | def main(argv):
72 |     if len(argv) > 1:
73 |         raise app.UsageError('Too many command-line arguments.')
74 |     flags.mark_flag_as_required('input_file')
75 |     flags.mark_flag_as_required('input_format')
76 |     flags.mark_flag_as_required('output_file')
77 |     flags.mark_flag_as_required('label_map_file')
78 |     flags.mark_flag_as_required('vocab_file')
79 |     flags.mark_flag_as_required('saved_model')
80 | 
81 |     label_map = utils.read_label_map(FLAGS.label_map_file)
82 |     converter = tagging_converter.TaggingConverter(
83 |         tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
84 |         FLAGS.enable_swap_tag)
85 |     builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
86 |                                               FLAGS.max_seq_length,
87 |                                               FLAGS.do_lower_case, converter)
88 |     predictor = predict_utils.LaserTaggerPredictor(
89 |         tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
90 |         label_map)
91 |     print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
92 |     sourcesA_list = []
93 |     sourcesB_list = []
94 |     target_list = []
95 |     with tf.io.gfile.GFile(FLAGS.input_file) as f:
96 |         for line in f:
97 |             sourceA, sourceB, label = line.rstrip('\n').split('\t')
98 |             sourcesA_list.append([sourceA.strip(".")])
99 |             sourcesB_list.append([sourceB.strip(".")])
100 |             target_list.append(label)
101 | 
102 | 
103 | 
104 |     number = len(sourcesA_list)  # total number of examples
105 |     predict_batch_size = min(32, number)
106 |     batch_num = math.ceil(float(number) / predict_batch_size)
107 | 
108 |     start_time = time.time()
109 |     num_predicted = 0
110 |     prediction_list = []
111 |     with tf.gfile.Open(FLAGS.output_file, 'w') as writer:
112 |         for batch_id in range(batch_num):
113 |             sources_batch = sourcesA_list[batch_id * predict_batch_size: (batch_id + 1) * predict_batch_size]
114 |             batch_b = sourcesB_list[batch_id * predict_batch_size: (batch_id + 1) * predict_batch_size]
115 |             location_batch = []
116 |             sources_batch.extend(batch_b)
117 |             for source in sources_batch:
118 |                 location = list()
119 |                 for char in source[0]:
120 |                     if (char >= '0' and char <= '9') or char in '.- ' or (char >= 'a' and char <= 'z') or (char >= 'A' and char <= 'Z'):
121 |                         location.append("1")  # TODO: positions of digits, ASCII letters and ".- " are flagged with "1"
122 |                     else:
123 |                         location.append("0")
124 |                 location_batch.append("".join(location))
125 |             prediction_batch = predictor.predict_batch(sources_batch=sources_batch, location_batch=location_batch)
126 |             current_batch_size = int(len(sources_batch) / 2)
127 |             assert len(prediction_batch) == current_batch_size * 2
128 | 
129 |             for id in range(0, current_batch_size):
130 |                 target = target_list[num_predicted + id]
131 |                 prediction_A = prediction_batch[id]
132 |                 prediction_B = prediction_batch[current_batch_size + id]
133 |                 sourceA = "".join(sources_batch[id])
134 |                 sourceB = "".join(sources_batch[current_batch_size + id])
135 |                 if prediction_A == prediction_B:  # replace one of the two with its source
136 |                     lcsA = len(_compute_lcs(sourceA, prediction_A))
137 |                     if lcsA < 8:  # A changed a lot
138 |                         prediction_B = sourceB
139 |                     else:
140 |                         lcsB = len(_compute_lcs(sourceB, prediction_B))
141 |                         if lcsA <= lcsB:  # A changed more than B
142 |                             prediction_B = sourceB
143 |                         else:
144 |                             prediction_A = sourceA
145 |                             print(curLine(), batch_id, prediction_A, prediction_B, "target:", target, "current_batch_size=",
146 |                                   current_batch_size, "lcsA=%d,lcsB=%d" % (lcsA, lcsB))
147 |                 writer.write(f'{prediction_A}\t{prediction_B}\t{target}\n')
148 | 
149 |                 prediction_list.append("%s\t%s\n" % (sourceA, prediction_A))
150 |                 # print(curLine(), id, "sourceA:", sourceA, "sourceB:", sourceB, "target:", target)
151 |                 prediction_list.append("%s\t%s\n" % (sourceB, prediction_B))
152 |             num_predicted += current_batch_size
153 |             if batch_id % 20 == 0:
154 |                 cost_time = (time.time() - start_time) / 60.0
155 |                 print(curLine(), id, prediction_A, prediction_B, "target:", target, "current_batch_size=", current_batch_size)
156 |                 print(curLine(), id, "sourceA:", sourceA, "sourceB:", sourceB, "target:", target)
157 |                 print("%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin." %
158 |                       (curLine(), batch_id + 1, batch_num, num_predicted, number, cost_time))
159 |     with open("prediction.txt", "w") as prediction_file:
160 |         prediction_file.writelines(prediction_list)
161 |     print(curLine(), "saved to prediction.txt.")
162 |     cost_time = (time.time() - start_time) / 60.0
163 |     print(curLine(), id, prediction_A, prediction_B, "target:", target, "current_batch_size=", current_batch_size)
164 |     print(curLine(), id, "sourceA:", sourceA, "sourceB:", sourceB, "target:", target)
165 |     logging.info(
166 |         f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {cost_time / num_predicted * 60000}ms.')
167 | 
168 | if __name__ == '__main__':
169 |     app.run(main)
170 | 
--------------------------------------------------------------------------------
/qa_rephrase/score_for_qa.txt:
--------------------------------------------------------------------------------
1 | 1200
2 | Exact score: 12.843
3 | SARI score: 51.313
4 | KEEP score: 76.046
5 | ADDITION score: 20.086
6 | DELETION score: 57.808
7 | cost time 0.337063 h
8 | 
9 | 房山 Restoring parameters from /home/wzk/Mywork/corpus/rephrase_corpus/output/models/wikisplit_experiment_name/model.ckpt-3643
10 | Exact score: 13.672
11 | SARI score: 53.450
12 | KEEP score: 77.125
13 | ADDITION score: 21.678
14 | DELETION score: 61.546
15 | cost time 0.260398 h
16 | Continued training with num_train_epochs raised from 7.0 to 30.0
17 | Saving checkpoints for 36436 into /home/wzk/Mywork/corpus/rephrase_corpus/output/models/wikisplit_experiment_name/model.ckpt.
18 | Exact score: 19.414
19 | SARI score: 56.749
20 | KEEP score: 78.622
21 | ADDITION score: 26.927
22 | DELETION score: 64.697
23 | cost time 0.260367 h
24 | model.ckpt-35643
25 | Exact score: 19.440
26 | SARI score: 56.725
27 | KEEP score: 78.620
28 | ADDITION score: 26.917
29 | DELETION score: 64.638
30 | cost time 0.259561 h
31 | Restoring parameters from /home/wzk/Mywork/corpus/rephrase_corpus/output/models/wikisplit_experiment_name/model.ckpt-62436
32 | Exact score: 20.347
33 | SARI score: 57.107
34 | KEEP score: 78.693
35 | ADDITION score: 27.379
36 | DELETION score: 65.250
37 | cost time 0.575784 h
38 | 
39 | 
40 | Old results:
41 | After refining the case-handling rules
42 | Exact score: 29.283
43 | SARI score: 64.002
44 | KEEP score: 84.942
45 | ADDITION score: 38.128
46 | DELETION score: 68.935
47 | cost time 0.41796 h
--------------------------------------------------------------------------------
/rephrase.sh:
--------------------------------------------------------------------------------
1 | # Augment text-matching corpora: the text rephrasing (paraphrase) task
2 | export HOST_NAME=$1
3 | # set gpu id to use
4 | if [[ "wzk" == "$HOST_NAME" ]]
5 | then
6 |     # set gpu id to use
7 |     export CUDA_VISIBLE_DEVICES=0
8 | else
9 |     # do not use a gpu
10 |     export CUDA_VISIBLE_DEVICES=""
11 | fi
12 | 
13 | start_tm=`date +%s%N`;
14 | 
15 | 
16 | ### Optional parameters ###
17 | # If you train multiple models on the same data, change this label.
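# Each label gets its own tree under ${OUTPUT_DIR}/models (paths are set a few
# lines below); e.g. with a hypothetical label "scalpel_roberta_tiny":
#   checkpoints  -> ${OUTPUT_DIR}/models/scalpel_roberta_tiny/
#   SavedModels  -> ${OUTPUT_DIR}/models/scalpel_roberta_tiny/export/<timestamp>/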
18 | export EXPERIMENT="scapal_model" #"wikisplit_experiment_name_BertTiny" #  19 | 20 | # To quickly test that model training works, set the number of epochs to a 21 | # smaller value (e.g. 0.01). 22 | export num_train_epochs=30 23 | export TRAIN_BATCH_SIZE=256 24 | export PHRASE_VOCAB_SIZE=500 25 | export MAX_INPUT_EXAMPLES=1000000 26 | export SAVE_CHECKPOINT_STEPS=2000 27 | export keep_checkpoint_max=8 28 | export enable_swap_tag=false 29 | export output_arbitrary_targets_for_infeasible_examples=false 30 | export REPHRASE_DIR="/home/${HOST_NAME}/Mywork/corpus/rephrase_corpus" 31 | export BERT_BASE_DIR="/home/${HOST_NAME}/Mywork/model/RoBERTa-tiny-clue" 32 | export OUTPUT_DIR="${REPHRASE_DIR}/output" 33 | 34 | export max_seq_length=40 35 | # Check these numbers from the "*.num_examples" files created in step 2. 36 | export NUM_TRAIN_EXAMPLES=310922 37 | export NUM_EVAL_EXAMPLES=5000 38 | export CONFIG_FILE=configs/lasertagger_config.json 39 | 40 | ### Get the most recently exported model directory. 41 | TIMESTAMP=$(ls "${OUTPUT_DIR}/models/${EXPERIMENT}/export/" | \ 42 | grep -v "temp-" | sort -r | head -1) 43 | SAVED_MODEL_DIR=${OUTPUT_DIR}/models/${EXPERIMENT}/export/${TIMESTAMP} 44 | PREDICTION_FILE=${OUTPUT_DIR}/models/${EXPERIMENT}/pred.tsv 45 | 46 | 47 | #python phrase_vocabulary_optimization.py \ 48 | # --input_file=${REPHRASE_DIR}/train.txt \ 49 | # --input_format=wikisplit \ 50 | # --vocabulary_size=${PHRASE_VOCAB_SIZE} \ 51 | # --max_input_examples=1000000 \ 52 | # --enable_swap_tag=${enable_swap_tag} \ 53 | # --output_file=${OUTPUT_DIR}/label_map.txt \ 54 | # --vocab_file=${BERT_BASE_DIR}/vocab.txt 55 | # 56 | #python preprocess_main.py \ 57 | # --input_file=${REPHRASE_DIR}/tune.txt \ 58 | # --input_format=wikisplit \ 59 | # --output_tfrecord=${OUTPUT_DIR}/tune.tf_record \ 60 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \ 61 | # --vocab_file=${BERT_BASE_DIR}/vocab.txt \ 62 | # --max_seq_length=${max_seq_length} \ 63 | # --enable_swap_tag=${enable_swap_tag} \ 64 | # --output_arbitrary_targets_for_infeasible_examples=${output_arbitrary_targets_for_infeasible_examples} 65 | #python preprocess_main.py \ 66 | # --input_file=${REPHRASE_DIR}/train.txt \ 67 | # --input_format=wikisplit \ 68 | # --output_tfrecord=${OUTPUT_DIR}/train.tf_record \ 69 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \ 70 | # --vocab_file=${BERT_BASE_DIR}/vocab.txt \ 71 | # --max_seq_length=${max_seq_length} \ 72 | # --enable_swap_tag=${enable_swap_tag} \ 73 | # --output_arbitrary_targets_for_infeasible_examples=${output_arbitrary_targets_for_infeasible_examples} 74 | 75 | 76 | #python run_lasertagger.py \ 77 | # --training_file=${OUTPUT_DIR}/train.tf_record \ 78 | # --eval_file=${OUTPUT_DIR}/tune.tf_record \ 79 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \ 80 | # --model_config_file=${CONFIG_FILE} \ 81 | # --output_dir=${OUTPUT_DIR}/models/${EXPERIMENT} \ 82 | # --init_checkpoint=${BERT_BASE_DIR}/bert_model.ckpt \ 83 | # --do_train=true \ 84 | # --do_eval=true \ 85 | # --num_train_epochs=${num_train_epochs} \ 86 | # --train_batch_size=${TRAIN_BATCH_SIZE} \ 87 | # --save_checkpoints_steps=${SAVE_CHECKPOINT_STEPS} \ 88 | # --keep_checkpoint_max=${keep_checkpoint_max} \ 89 | # --max_seq_length=${max_seq_length} \ 90 | # --num_train_examples=${NUM_TRAIN_EXAMPLES} \ 91 | # --num_eval_examples=${NUM_EVAL_EXAMPLES} 92 | 93 | 94 | ## Export the model. 95 | echo "Export the model." 
96 | python run_lasertagger.py \ 97 | --label_map_file=${OUTPUT_DIR}/label_map.txt \ 98 | --model_config_file=${CONFIG_FILE} \ 99 | --output_dir=${OUTPUT_DIR}/models/${EXPERIMENT} \ 100 | --do_export=true \ 101 | --export_path=${OUTPUT_DIR}/models/${EXPERIMENT}/export 102 | # 103 | # 104 | # 105 | ### 4. Prediction 106 | #echo "predict" 107 | python predict_main.py \ 108 | --input_file=${REPHRASE_DIR}/test.txt \ 109 | --input_format=wikisplit \ 110 | --output_file=${PREDICTION_FILE} \ 111 | --label_map_file=${OUTPUT_DIR}/label_map.txt \ 112 | --vocab_file=${BERT_BASE_DIR}/vocab.txt \ 113 | --max_seq_length=${max_seq_length} \ 114 | --enable_swap_tag=${enable_swap_tag} \ 115 | --saved_model=${SAVED_MODEL_DIR} 116 | 117 | ### 5. Evaluation 118 | python score_main.py --prediction_file=${PREDICTION_FILE} 119 | 120 | 121 | end_tm=`date +%s%N`; 122 | use_tm=`echo $end_tm $start_tm | awk '{ print ($1 - $2) / 1000000000 /3600}'` 123 | echo "cost time" $use_tm "h" -------------------------------------------------------------------------------- /rephrase_for_chat.sh: -------------------------------------------------------------------------------- 1 | # 为闲聊的文本匹配语料做数据增强 2 | # 扩充文本匹配的语料 文本复述任务 3 | 4 | # set gpu id to use 5 | export CUDA_VISIBLE_DEVICES="" 6 | start_tm=`date +%s%N`; 7 | 8 | export HOST_NAME="wzk" 9 | ### Optional parameters ### 10 | 11 | # If you train multiple models on the same data, change this label. 12 | EXPERIMENT=wikisplit_experiment 13 | # To quickly test that model training works, set the number of epochs to a 14 | # smaller value (e.g. 0.01). 15 | NUM_EPOCHS=60.0 16 | export TRAIN_BATCH_SIZE=256 17 | export PHRASE_VOCAB_SIZE=500 18 | export MAX_INPUT_EXAMPLES=1000000 19 | export SAVE_CHECKPOINT_STEPS=200 20 | export enable_swap_tag=true 21 | export output_arbitrary_targets_for_infeasible_examples=false 22 | export WIKISPLIT_DIR="/home/${HOST_NAME}/Mywork/corpus/rephrase_corpus" 23 | export OUTPUT_DIR="${WIKISPLIT_DIR}/output" 24 | 25 | #python phrase_vocabulary_optimization.py \ 26 | # --input_file=${WIKISPLIT_DIR}/train.txt \ 27 | # --input_format=wikisplit \ 28 | # --vocabulary_size=500 \ 29 | # --max_input_examples=1000000 \ 30 | # --enable_swap_tag=${enable_swap_tag} \ 31 | # --output_file=${OUTPUT_DIR}/label_map.txt 32 | 33 | 34 | export max_seq_length=40 # TODO 35 | export BERT_BASE_DIR="/home/${HOST_NAME}/Mywork/model/RoBERTa-tiny-clue" # chinese_L-12_H-768_A-12" 36 | 37 | #python preprocess_main.py \ 38 | # --input_file=${WIKISPLIT_DIR}/tune.txt \ 39 | # --input_format=wikisplit \ 40 | # --output_tfrecord=${OUTPUT_DIR}/tune.tf_record \ 41 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \ 42 | # --vocab_file=${BERT_BASE_DIR}/vocab.txt \ 43 | # --max_seq_length=${max_seq_length} \ 44 | # --enable_swap_tag=${enable_swap_tag} \ 45 | # --output_arbitrary_targets_for_infeasible_examples=${output_arbitrary_targets_for_infeasible_examples} # TODO true 46 | # 47 | #python preprocess_main.py \ 48 | # --input_file=${WIKISPLIT_DIR}/train.txt \ 49 | # --input_format=wikisplit \ 50 | # --output_tfrecord=${OUTPUT_DIR}/train.tf_record \ 51 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \ 52 | # --vocab_file=${BERT_BASE_DIR}/vocab.txt \ 53 | # --max_seq_length=${max_seq_length} \ 54 | # --enable_swap_tag=${enable_swap_tag} \ 55 | # --output_arbitrary_targets_for_infeasible_examples=${output_arbitrary_targets_for_infeasible_examples} # TODO false 56 | 57 | 58 | 59 | # Check these numbers from the "*.num_examples" files created in step 2. 
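# A minimal sketch for reading those counts automatically instead of
# hardcoding them below (preprocess_main.py saves the example count to
# "<output_tfrecord>.num_examples.txt"):
#   export NUM_TRAIN_EXAMPLES=$(cat ${OUTPUT_DIR}/train.tf_record.num_examples.txt)
#   export NUM_EVAL_EXAMPLES=$(cat ${OUTPUT_DIR}/tune.tf_record.num_examples.txt)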
60 | export NUM_TRAIN_EXAMPLES=310922 61 | export NUM_EVAL_EXAMPLES=5000 62 | export CONFIG_FILE=configs/lasertagger_config.json 63 | export EXPERIMENT=wikisplit_experiment_name 64 | 65 | 66 | 67 | #python run_lasertagger.py \ 68 | # --training_file=${OUTPUT_DIR}/train.tf_record \ 69 | # --eval_file=${OUTPUT_DIR}/tune.tf_record \ 70 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \ 71 | # --model_config_file=${CONFIG_FILE} \ 72 | # --output_dir=${OUTPUT_DIR}/models/${EXPERIMENT} \ 73 | # --init_checkpoint=${BERT_BASE_DIR}/bert_model.ckpt \ 74 | # --do_train=true \ 75 | # --do_eval=true \ 76 | # --train_batch_size=${TRAIN_BATCH_SIZE} \ 77 | # --save_checkpoints_steps=200 \ 78 | # --max_seq_length=${max_seq_length} \ 79 | # --num_train_examples=${NUM_TRAIN_EXAMPLES} \ 80 | # --num_eval_examples=${NUM_EVAL_EXAMPLES} 81 | 82 | #CUDA_VISIBLE_DEVICES="" nohup python run_lasertagger.py \ 83 | # --training_file=${OUTPUT_DIR}/train.tf_record \ 84 | # --eval_file=${OUTPUT_DIR}/tune.tf_record \ 85 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \ 86 | # --model_config_file=${CONFIG_FILE} \ 87 | # --output_dir=${OUTPUT_DIR}/models/${EXPERIMENT} \ 88 | # --init_checkpoint=${BERT_BASE_DIR}/bert_model.ckpt \ 89 | # --do_train=true \ 90 | # --do_eval=true \ 91 | # --train_batch_size=${TRAIN_BATCH_SIZE} \ 92 | # --save_checkpoints_steps=${SAVE_CHECKPOINT_STEPS} \ 93 | # --num_train_epochs=${NUM_EPOCHS} \ 94 | # --max_seq_length=${max_seq_length} \ 95 | # --num_train_examples=${NUM_TRAIN_EXAMPLES} \ 96 | # --num_eval_examples=${NUM_EVAL_EXAMPLES} > log.txt 2>&1 & 97 | 98 | 99 | ### 4. Prediction 100 | 101 | # Export the model. 102 | python run_lasertagger.py \ 103 | --label_map_file=${OUTPUT_DIR}/label_map.txt \ 104 | --model_config_file=${CONFIG_FILE} \ 105 | --output_dir=${OUTPUT_DIR}/models/${EXPERIMENT} \ 106 | --do_export=true \ 107 | --export_path=${OUTPUT_DIR}/models/${EXPERIMENT}/export 108 | # 109 | ### Get the most recently exported model directory. 110 | TIMESTAMP=$(ls "${OUTPUT_DIR}/models/${EXPERIMENT}/export/" | \ 111 | grep -v "temp-" | sort -r | head -1) 112 | SAVED_MODEL_DIR=${OUTPUT_DIR}/models/${EXPERIMENT}/export/${TIMESTAMP} 113 | PREDICTION_FILE=${OUTPUT_DIR}/models/${EXPERIMENT}/expand_simq_v4.0.json 114 | 115 | python chat_rephrase/predict_for_chat.py \ 116 | --input_file=/home/${HOST_NAME}/Mywork/corpus/闲聊/simq_v4.0.json \ 117 | --input_format=wikisplit \ 118 | --output_file=${PREDICTION_FILE} \ 119 | --label_map_file=${OUTPUT_DIR}/label_map.txt \ 120 | --vocab_file=${BERT_BASE_DIR}/vocab.txt \ 121 | --max_seq_length=${max_seq_length} \ 122 | --enable_swap_tag=${enable_swap_tag} \ 123 | --saved_model=${SAVED_MODEL_DIR} 124 | 125 | # downloag file of pred_qa.tsv, give shixi 126 | 127 | #[predict_for_qa.py:166] 238766 predictions saved to:/home/wzk/Mywork/corpus/rephrase_corpus/output/models/wikisplit_experiment_name/pred_qa.tsv, cost 332.37216806411743 min, ave 83.52248680233805ms. 128 | #cost time 5.54552 h 129 | 130 | 131 | 132 | #### 5. 
Evaluation 133 | #python score_main.py --prediction_file=${PREDICTION_FILE} 134 | 135 | 136 | end_tm=`date +%s%N`; 137 | use_tm=`echo $end_tm $start_tm | awk '{ print ($1 - $2) / 1000000000 /3600}'` 138 | echo "cost time" $use_tm "h" -------------------------------------------------------------------------------- /rephrase_for_qa.sh: -------------------------------------------------------------------------------- 1 | # 扩充文本匹配的语料 文本复述任务 2 | # set gpu id to use 3 | export CUDA_VISIBLE_DEVICES="" 4 | 5 | 6 | start_tm=`date +%s%N`; 7 | 8 | export HOST_NAME="wzk" #"cloudminds" #    9 | ### Optional parameters ### 10 | 11 | # If you train multiple models on the same data, change this label. 12 | EXPERIMENT=wikisplit_experiment 13 | # To quickly test that model training works, set the number of epochs to a 14 | # smaller value (e.g. 0.01). 15 | NUM_EPOCHS=60.0 16 | export TRAIN_BATCH_SIZE=256 # 512 OOM 256 OK 17 | PHRASE_VOCAB_SIZE=500 18 | MAX_INPUT_EXAMPLES=1000000 19 | SAVE_CHECKPOINT_STEPS=200 20 | export enable_swap_tag=true 21 | export output_arbitrary_targets_for_infeasible_examples=false 22 | export WIKISPLIT_DIR="/home/${HOST_NAME}/Mywork/corpus/rephrase_corpus" 23 | export OUTPUT_DIR="${WIKISPLIT_DIR}/output" 24 | 25 | #python phrase_vocabulary_optimization.py \ 26 | # --input_file=${WIKISPLIT_DIR}/train.txt \ 27 | # --input_format=wikisplit \ 28 | # --vocabulary_size=500 \ 29 | # --max_input_examples=1000000 \ 30 | # --enable_swap_tag=${enable_swap_tag} \ 31 | # --output_file=${OUTPUT_DIR}/label_map.txt 32 | 33 | 34 | export max_seq_length=40 # TODO 35 | export BERT_BASE_DIR="/home/${HOST_NAME}/Mywork/model/RoBERTa-tiny-clue" # chinese_L-12_H-768_A-12" 36 | 37 | #python preprocess_main.py \ 38 | # --input_file=${WIKISPLIT_DIR}/tune.txt \ 39 | # --input_format=wikisplit \ 40 | # --output_tfrecord=${OUTPUT_DIR}/tune.tf_record \ 41 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \ 42 | # --vocab_file=${BERT_BASE_DIR}/vocab.txt \ 43 | # --max_seq_length=${max_seq_length} \ 44 | # --enable_swap_tag=${enable_swap_tag} \ 45 | # --output_arbitrary_targets_for_infeasible_examples=${output_arbitrary_targets_for_infeasible_examples} # TODO true 46 | # 47 | #python preprocess_main.py \ 48 | # --input_file=${WIKISPLIT_DIR}/train.txt \ 49 | # --input_format=wikisplit \ 50 | # --output_tfrecord=${OUTPUT_DIR}/train.tf_record \ 51 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \ 52 | # --vocab_file=${BERT_BASE_DIR}/vocab.txt \ 53 | # --max_seq_length=${max_seq_length} \ 54 | # --enable_swap_tag=${enable_swap_tag} \ 55 | # --output_arbitrary_targets_for_infeasible_examples=${output_arbitrary_targets_for_infeasible_examples} # TODO false 56 | 57 | 58 | 59 | # Check these numbers from the "*.num_examples" files created in step 2. 
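# predict_for_qa.py (invoked below) expects each input line to be an
# LCQMC-style triple "sentence_A<TAB>sentence_B<TAB>label"; it rephrases both
# sentences and writes them back with the original label. A hypothetical line:
#   怎么办理护照<TAB>护照如何办理<TAB>1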
60 | export NUM_TRAIN_EXAMPLES=310922 61 | export NUM_EVAL_EXAMPLES=5000 62 | export CONFIG_FILE=configs/lasertagger_config.json 63 | export EXPERIMENT=wikisplit_experiment_name 64 | 65 | 66 | 67 | #python run_lasertagger.py \ 68 | # --training_file=${OUTPUT_DIR}/train.tf_record \ 69 | # --eval_file=${OUTPUT_DIR}/tune.tf_record \ 70 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \ 71 | # --model_config_file=${CONFIG_FILE} \ 72 | # --output_dir=${OUTPUT_DIR}/models/${EXPERIMENT} \ 73 | # --init_checkpoint=${BERT_BASE_DIR}/bert_model.ckpt \ 74 | # --do_train=true \ 75 | # --do_eval=true \ 76 | # --train_batch_size=${TRAIN_BATCH_SIZE} \ 77 | # --save_checkpoints_steps=200 \ 78 | # --max_seq_length=${max_seq_length} \ 79 | # --num_train_examples=${NUM_TRAIN_EXAMPLES} \ 80 | # --num_eval_examples=${NUM_EVAL_EXAMPLES} 81 | 82 | #CUDA_VISIBLE_DEVICES="" nohup python run_lasertagger.py \ 83 | # --training_file=${OUTPUT_DIR}/train.tf_record \ 84 | # --eval_file=${OUTPUT_DIR}/tune.tf_record \ 85 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \ 86 | # --model_config_file=${CONFIG_FILE} \ 87 | # --output_dir=${OUTPUT_DIR}/models/${EXPERIMENT} \ 88 | # --init_checkpoint=${BERT_BASE_DIR}/bert_model.ckpt \ 89 | # --do_train=true \ 90 | # --do_eval=true \ 91 | # --train_batch_size=${TRAIN_BATCH_SIZE} \ 92 | # --save_checkpoints_steps=${SAVE_CHECKPOINT_STEPS} \ 93 | # --num_train_epochs=${NUM_EPOCHS} \ 94 | # --max_seq_length=${max_seq_length} \ 95 | # --num_train_examples=${NUM_TRAIN_EXAMPLES} \ 96 | # --num_eval_examples=${NUM_EVAL_EXAMPLES} > log.txt 2>&1 & 97 | 98 | 99 | ### 4. Prediction 100 | 101 | # Export the model. 102 | python run_lasertagger.py \ 103 | --label_map_file=${OUTPUT_DIR}/label_map.txt \ 104 | --model_config_file=${CONFIG_FILE} \ 105 | --output_dir=${OUTPUT_DIR}/models/${EXPERIMENT} \ 106 | --do_export=true \ 107 | --export_path=${OUTPUT_DIR}/models/${EXPERIMENT}/export 108 | # 109 | ### Get the most recently exported model directory. 110 | TIMESTAMP=$(ls "${OUTPUT_DIR}/models/${EXPERIMENT}/export/" | \ 111 | grep -v "temp-" | sort -r | head -1) 112 | SAVED_MODEL_DIR=${OUTPUT_DIR}/models/${EXPERIMENT}/export/${TIMESTAMP} 113 | PREDICTION_FILE=${OUTPUT_DIR}/models/${EXPERIMENT}/pred_qa.txt 114 | 115 | python qa_rephrase/predict_for_qa.py \ 116 | --input_file=/home/${HOST_NAME}/Mywork/corpus/Chinese_QA/LCQMC/train.txt \ 117 | --input_format=wikisplit \ 118 | --output_file=${PREDICTION_FILE} \ 119 | --label_map_file=${OUTPUT_DIR}/label_map.txt \ 120 | --vocab_file=${BERT_BASE_DIR}/vocab.txt \ 121 | --max_seq_length=${max_seq_length} \ 122 | --enable_swap_tag=${enable_swap_tag} \ 123 | --saved_model=${SAVED_MODEL_DIR} 124 | 125 | # downloag file of pred_qa.tsv, give shixi 126 | 127 | #[predict_for_qa.py:166] 238766 predictions saved to:/home/wzk/Mywork/corpus/rephrase_corpus/output/models/wikisplit_experiment_name/pred_qa.tsv, cost 332.37216806411743 min, ave 83.52248680233805ms. 128 | #cost time 5.54552 h 129 | 130 | 131 | 132 | #### 5. 
Evaluation 133 | #python score_main.py --prediction_file=${PREDICTION_FILE} 134 | 135 | 136 | end_tm=`date +%s%N`; 137 | use_tm=`echo $end_tm $start_tm | awk '{ print ($1 - $2) / 1000000000 /3600}'` 138 | echo "cost time" $use_tm "h" -------------------------------------------------------------------------------- /rephrase_for_skill.sh: -------------------------------------------------------------------------------- 1 | # 扩充技能的语料 2 | # rephrase_for_skill.sh: 在rephrase.sh基础上改的 3 | # predict_for_skill.py: 在 predict_main.py基础上改的 4 | # score_for_skill.txt 结果对比 5 | 6 | # set gpu id to use 7 | export CUDA_VISIBLE_DEVICES="" 8 | 9 | start_tm=`date +%s%N`; 10 | 11 | export HOST_NAME="wzk" 12 | ### Optional parameters ### 13 | 14 | # If you train multiple models on the same data, change this label. 15 | EXPERIMENT=wikisplit_experiment 16 | # To quickly test that model training works, set the number of epochs to a 17 | # smaller value (e.g. 0.01). 18 | NUM_EPOCHS=10.0 19 | export TRAIN_BATCH_SIZE=256 20 | export PHRASE_VOCAB_SIZE=500 21 | export MAX_INPUT_EXAMPLES=1000000 22 | export SAVE_CHECKPOINT_STEPS=200 23 | export enable_swap_tag=false 24 | export output_arbitrary_targets_for_infeasible_examples=false 25 | export WIKISPLIT_DIR="/home/${HOST_NAME}/Mywork/corpus/rephrase_corpus" 26 | export OUTPUT_DIR="${WIKISPLIT_DIR}/output" 27 | 28 | 29 | 30 | export max_seq_length=40 # TODO 31 | export BERT_BASE_DIR="/home/${HOST_NAME}/Mywork/model/RoBERTa-tiny-clue" # chinese_L-12_H-768_A-12" 32 | 33 | 34 | 35 | 36 | # Check these numbers from the "*.num_examples" files created in step 2. 37 | export CONFIG_FILE=configs/lasertagger_config.json 38 | export EXPERIMENT=wikisplit_experiment_name 39 | 40 | 41 | 42 | ### 4. Prediction 43 | 44 | # Export the model. 45 | python run_lasertagger.py \ 46 | --label_map_file=${OUTPUT_DIR}/label_map.txt \ 47 | --model_config_file=${CONFIG_FILE} \ 48 | --output_dir=${OUTPUT_DIR}/models/${EXPERIMENT} \ 49 | --do_export=true \ 50 | --export_path=${OUTPUT_DIR}/models/${EXPERIMENT}/export 51 | 52 | ## Get the most recently exported model directory. 53 | TIMESTAMP=$(ls "${OUTPUT_DIR}/models/${EXPERIMENT}/export/" | \ 54 | grep -v "temp-" | sort -r | head -1) 55 | SAVED_MODEL_DIR=${OUTPUT_DIR}/models/${EXPERIMENT}/export/${TIMESTAMP} 56 | PREDICTION_FILE=${OUTPUT_DIR}/models/${EXPERIMENT}/pred.tsv 57 | export domain_name=times 58 | python skill_rephrase/predict_for_skill.py \ 59 | --input_file=/home/${HOST_NAME}/Mywork/corpus/ner_corpus/times_corpus/slot_times.txt \ 60 | --input_format=wikisplit \ 61 | --output_file=/home/${HOST_NAME}/Mywork/corpus/ner_corpus/times_corpus/slot_times_expandexpand.json \ 62 | --label_map_file=${OUTPUT_DIR}/label_map.txt \ 63 | --vocab_file=${BERT_BASE_DIR}/vocab.txt \ 64 | --max_seq_length=${max_seq_length} \ 65 | --saved_model=${SAVED_MODEL_DIR} 66 | 67 | #### 5. Evaluation 68 | #python score_main.py --prediction_file=${PREDICTION_FILE} 69 | 70 | 71 | end_tm=`date +%s%N`; 72 | use_tm=`echo $end_tm $start_tm | awk '{ print ($1 - $2) / 1000000000 /3600}'` 73 | echo "cost time" $use_tm "h" -------------------------------------------------------------------------------- /rephrase_server.sh: -------------------------------------------------------------------------------- 1 | # 文本复述(rephrase)的服务 2 | start_tm=`date +%s%N`; 3 | 4 | export HOST_NAME="wzk" 5 | ### Optional parameters ### 6 | 7 | # If you train multiple models on the same data, change this label. 
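# Once the Flask service launched at the bottom of this script is up, it can
# be smoke-tested from another shell (a sketch; host/port as exported below,
# sample sentence hypothetical):
#   curl -X POST "http://0.0.0.0:6000/rephrase" \
#        -H "Content-Type: application/json" \
#        -d '{"text_list": ["北京今天天气怎么样"]}'
# A successful reply has the form {"status": 0, "message": "Success", "data": {"output": [...]}}.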
8 | EXPERIMENT=wikisplit_experiment 9 | # To quickly test that model training works, set the number of epochs to a 10 | # smaller value (e.g. 0.01). 11 | 12 | export TRAIN_BATCH_SIZE=256 13 | export PHRASE_VOCAB_SIZE=500 14 | export MAX_INPUT_EXAMPLES=1000000 15 | export SAVE_CHECKPOINT_STEPS=200 16 | export enable_swap_tag=false 17 | export output_arbitrary_targets_for_infeasible_examples=false 18 | export WIKISPLIT_DIR="/home/${HOST_NAME}/Mywork/corpus/rephrase_corpus" 19 | export BERT_BASE_DIR="/home/${HOST_NAME}/Mywork/model/RoBERTa-tiny-clue" # chinese_L-12_H-768_A-12" 20 | export OUTPUT_DIR="${WIKISPLIT_DIR}/output" 21 | 22 | #python phrase_vocabulary_optimization.py \ 23 | # --input_file=${WIKISPLIT_DIR}/train.txt \ 24 | # --input_format=wikisplit \ 25 | # --vocabulary_size=500 \ 26 | # --max_input_examples=1000000 \ 27 | # --enable_swap_tag=${enable_swap_tag} \ 28 | # --output_file=${OUTPUT_DIR}/label_map.txt 29 | 30 | 31 | export max_seq_length=40 32 | 33 | #python preprocess_main.py \ 34 | # --input_file=${WIKISPLIT_DIR}/tune.txt \ 35 | # --input_format=wikisplit \ 36 | # --output_tfrecord=${OUTPUT_DIR}/tune.tf_record \ 37 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \ 38 | # --vocab_file=${BERT_BASE_DIR}/vocab.txt \ 39 | # --max_seq_length=${max_seq_length} \ 40 | # --output_arbitrary_targets_for_infeasible_examples=${output_arbitrary_targets_for_infeasible_examples} # TODO true 41 | # 42 | #python preprocess_main.py \ 43 | # --input_file=${WIKISPLIT_DIR}/train.txt \ 44 | # --input_format=wikisplit \ 45 | # --output_tfrecord=${OUTPUT_DIR}/train.tf_record \ 46 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \ 47 | # --vocab_file=${BERT_BASE_DIR}/vocab.txt \ 48 | # --max_seq_length=${max_seq_length} \ 49 | # --output_arbitrary_targets_for_infeasible_examples=${output_arbitrary_targets_for_infeasible_examples} # TODO false 50 | 51 | 52 | 53 | # Check these numbers from the "*.num_examples" files created in step 2. 54 | export NUM_TRAIN_EXAMPLES=310922 55 | export NUM_EVAL_EXAMPLES=5000 56 | export CONFIG_FILE=configs/lasertagger_config.json 57 | export EXPERIMENT=wikisplit_experiment_name 58 | 59 | 60 | ### 4. Prediction 61 | 62 | # Export the model. 63 | #python run_lasertagger.py \ 64 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \ 65 | # --model_config_file=${CONFIG_FILE} \ 66 | # --output_dir=${OUTPUT_DIR}/models/${EXPERIMENT} \ 67 | # --do_export=true \ 68 | # --export_path=${OUTPUT_DIR}/models/${EXPERIMENT}/export 69 | 70 | ## Get the most recently exported model directory. 
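# Each export is a timestamped SavedModel directory, e.g. (hypothetical)
# export/1588572410/saved_model.pb; "grep -v temp-" skips half-written
# exports and "sort -r | head -1" picks the newest one.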
71 | TIMESTAMP=$(ls "${OUTPUT_DIR}/models/${EXPERIMENT}/export/" | \ 72 | grep -v "temp-" | sort -r | head -1) 73 | SAVED_MODEL_DIR=${OUTPUT_DIR}/models/${EXPERIMENT}/export/${TIMESTAMP} 74 | label_map_file=${OUTPUT_DIR}/label_map.txt 75 | export host="0.0.0.0" 76 | export port=6000 77 | 78 | # start server 79 | python rephrase_server/rephrase_server_flask.py \ 80 | --input_file=${WIKISPLIT_DIR}/test.txt \ 81 | --input_format=wikisplit \ 82 | --output_file=${PREDICTION_FILE} \ 83 | --label_map_file=${OUTPUT_DIR}/label_map.txt \ 84 | --vocab_file=${BERT_BASE_DIR}/vocab.txt \ 85 | --max_seq_length=${max_seq_length} \ 86 | --saved_model=${SAVED_MODEL_DIR} \ 87 | --host=${host} \ 88 | --port=${port} 89 | 90 | end_tm=`date +%s%N`; 91 | use_tm=`echo $end_tm $start_tm | awk '{ print ($1 - $2) / 1000000000 /3600}'` 92 | echo "cost time" $use_tm "h" -------------------------------------------------------------------------------- /rephrase_server/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mleader2/text_scalpel/131922c5dcfc558a46a7590925e348c39cb24951/rephrase_server/__init__.py -------------------------------------------------------------------------------- /rephrase_server/rephrase_server_flask.py: -------------------------------------------------------------------------------- 1 | # 文本复述服务 基于tensorflow框架 2 | import os, sys 3 | from absl import flags 4 | from absl import logging 5 | import numpy as np 6 | import json 7 | import logging 8 | # http 接口 9 | from flask import Flask, jsonify, request 10 | app = Flask(__name__) 11 | 12 | logger = logging.getLogger('log') 13 | logger.setLevel(logging.DEBUG) 14 | 15 | while logger.hasHandlers(): 16 | for i in logger.handlers: 17 | logger.removeHandler(i) 18 | 19 | user_name = "" # wzk/ 20 | version="1.0.0.0" 21 | 22 | 23 | block_list = os.path.realpath(__file__).split("/") 24 | path = "/".join(block_list[:-2]) 25 | sys.path.append(path) 26 | 27 | 28 | import bert_example 29 | import predict_utils 30 | import tagging_converter 31 | import utils 32 | import tensorflow as tf 33 | # FLAGS = flags.FLAGS 34 | FLAGS = tf.app.flags.FLAGS 35 | flags.DEFINE_string( 36 | 'input_file', None, 37 | 'Path to the input file containing examples for which to compute ' 38 | 'predictions.') 39 | flags.DEFINE_enum( 40 | 'input_format', None, ['wikisplit', 'discofuse'], 41 | 'Format which indicates how to parse the input_file.') 42 | flags.DEFINE_string( 43 | 'output_file', None, 44 | 'Path to the TSV file where the predictions are written to.') 45 | flags.DEFINE_string( 46 | 'label_map_file', None, 47 | 'Path to the label map file. Either a JSON file ending with ".json", that ' 48 | 'maps each possible tag to an ID, or a text file that has one tag per ' 49 | 'line.') 50 | flags.DEFINE_string('vocab_file', None, 'Path to the BERT vocabulary file.') 51 | flags.DEFINE_integer('max_seq_length', 40, 'Maximum sequence length.') 52 | flags.DEFINE_bool( 53 | 'do_lower_case', True, 54 | 'Whether to lower case the input text. 
Should be True for uncased ' 55 | 'models and False for cased models.') 56 | flags.DEFINE_bool('enable_swap_tag', True, 'Whether to enable the SWAP tag.') 57 | flags.DEFINE_string('saved_model', None, 'Path to an exported TF model.') 58 | 59 | flags.DEFINE_string('host', None, 'host address.') 60 | flags.DEFINE_integer('port', None, 'port.') 61 | 62 | 63 | class RequestHandler(): 64 | def __init__(self): 65 | # if len(argv) > 1: 66 | # raise app.UsageError('Too many command-line arguments.') 67 | flags.mark_flag_as_required('input_file') 68 | flags.mark_flag_as_required('input_format') 69 | flags.mark_flag_as_required('output_file') 70 | flags.mark_flag_as_required('label_map_file') 71 | flags.mark_flag_as_required('vocab_file') 72 | flags.mark_flag_as_required('saved_model') 73 | 74 | label_map = utils.read_label_map(FLAGS.label_map_file) 75 | converter = tagging_converter.TaggingConverter( 76 | tagging_converter.get_phrase_vocabulary_from_label_map(label_map), 77 | FLAGS.enable_swap_tag) 78 | builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file, 79 | FLAGS.max_seq_length, 80 | FLAGS.do_lower_case, converter) 81 | self.predictor = predict_utils.LaserTaggerPredictor( 82 | tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder, 83 | label_map) 84 | 85 | def infer(self, sources_batch): 86 | prediction_batch = self.predictor.predict_batch(sources_batch=sources_batch) 87 | assert len(prediction_batch) == len(sources_batch) 88 | 89 | return prediction_batch 90 | 91 | class MyEncoder(json.JSONEncoder): 92 | def default(self, obj): 93 | if isinstance(obj, np.integer): 94 | return int(obj) 95 | elif isinstance(obj, np.floating): 96 | return float(obj) 97 | elif isinstance(obj, np.ndarray): 98 | return obj.tolist() 99 | else: 100 | return super(MyEncoder, self).default(obj) 101 | 102 | # http POST 接口 103 | @app.route('/rephrase', methods=['POST']) # 推理 104 | def returnPost(): 105 | data_json = {'version': version} 106 | try: 107 | question_raw_batch = request.json['text_list'] 108 | except TypeError: 109 | data_json['status'] = -1 110 | data_json['message'] = "Fail:input information type error" 111 | data_json['data'] = [] 112 | return json.dumps(data_json, cls=MyEncoder, ensure_ascii=False) 113 | question_batch = [] 114 | for question_raw in question_raw_batch: 115 | question_batch.append([question_raw.strip()]) 116 | decoded_output = rHandler.infer(question_batch) 117 | if len(decoded_output) == 0: 118 | data_json['status'] = -1 119 | data_json['message'] = "Fail: fail to get the retell of the text." 
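        # Even on failure the client receives a well-formed JSON envelope:
        # {"version": ..., "status": -1, "message": "Fail: ...", "data": []}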
120 | data_json['data'] = [] 121 | return json.dumps(data_json, cls=MyEncoder, ensure_ascii=False) 122 | data_json['status'] = 0 123 | data_json['message'] = "Success" 124 | data_json['data'] = {} 125 | data_json['data']['output'] = decoded_output 126 | return json.dumps(data_json, cls=MyEncoder, ensure_ascii=False) 127 | 128 | rHandler = RequestHandler() 129 | 130 | if __name__ == '__main__': 131 | 132 | 133 | app.run(host=FLAGS.host, port=FLAGS.port) 134 | -------------------------------------------------------------------------------- /rephrase_server/test_server.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #调用服务 3 | import time, math 4 | import json 5 | import requests 6 | import tensorflow as tf 7 | from curLine_file import curLine 8 | 9 | user_name="wzk" 10 | test_num_max=300000 11 | language='Chinese' 12 | # 此URL中IP地址为参赛者的服务器地址,应为可达的公网IP地址,端口默认21628 13 | url = "http://0.0.0.0:6000/rephrase" 14 | 15 | # 调用API接口 16 | def http_post(sources_batch): 17 | parameter = {'text_list': sources_batch} 18 | headers = {'Content-type': 'application/json'} 19 | status = -1 20 | output = None 21 | try: 22 | r = requests.post(url, data=json.dumps( 23 | parameter), headers=headers, timeout=10.5) 24 | if r.status_code == 200: 25 | result = r.json() 26 | # print(curLine(),result) 27 | status = result['status'] 28 | version = result['version'] 29 | if status == 0: 30 | data = result["data"] 31 | output = data['output'] 32 | else: 33 | print(curLine(), "version:%s, status=%d, message:%s" % (version, status, result['message'])) 34 | else: 35 | print("%sraise wrong,status_code: " % (curLine()), r.status_code) 36 | except Exception as e: 37 | print(curLine(), Exception, ' : ', e) 38 | input(curLine()) 39 | return status, output 40 | 41 | def test(): 42 | """ 43 | 此函数为测试函数,将sh运行在服务器端后,用该程序在另一网络测试 44 | This function is a test function. 45 | Run this function for test in a network while ServerDemo.py is running on a server in a different network 46 | """ 47 | sources_list = [] 48 | target_list = [] 49 | output_file = "/home/cloudminds/Mywork/corpus/rephrase_corpus/pred.tsv" 50 | input_file= "/home/cloudminds/Mywork/corpus/rephrase_corpus/test.txt" 51 | with tf.io.gfile.GFile(input_file) as f: 52 | for line in f: 53 | sources, target, lcs_rate = line.rstrip('\n').split('\t') 54 | sources_list.append(sources) # [sources]) 55 | target_list.append(target) 56 | number = len(target_list) # 总样本数 57 | predict_batch_size = min(64, number) # TODO 58 | batch_num = math.ceil(float(number)/predict_batch_size) 59 | num_predicted = 0 60 | with open(output_file, 'w') as writer: 61 | writer.write(f'source\tprediction\ttarget\n') 62 | start_time = time.time() 63 | for batch_id in range(batch_num): 64 | sources_batch = sources_list[batch_id * predict_batch_size: (batch_id + 1)*predict_batch_size] 65 | # prediction_batch = predictor.predict_batch(sources_batch=sources_batch) 66 | status, prediction_batch = http_post(sources_batch) 67 | assert len(prediction_batch) == len(sources_batch) 68 | num_predicted += len(prediction_batch) 69 | for id,[prediction,sources] in enumerate(zip(prediction_batch, sources_batch)): 70 | target = target_list[batch_id * predict_batch_size + id] 71 | writer.write(f'{"".join(sources)}\t{prediction}\t{target}\n') 72 | if batch_id % 20 == 0: 73 | cost_time = (time.time()-start_time)/60.0 74 | print("%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin." 
% 75 | (curLine(), batch_id+1, batch_num, num_predicted,number, cost_time)) 76 | cost_time = (time.time() - start_time) / 60.0 77 | print(curLine(), "%d predictions saved to %s, cost %f min, ave %f min." 78 | % (num_predicted, output_file, cost_time, cost_time / num_predicted)) 79 | 80 | 81 | if __name__ == '__main__': 82 | rougeL_ave=test() 83 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow==1.15.0 # CPU Version of TensorFlow. 2 | # tensorflow-gpu==1.15.0 # GPU version of TensorFlow. 3 | absl-py==0.8.1 4 | astor==0.8.0 5 | bert-tensorflow==1.0.1 6 | gast==0.2.2 7 | google-pasta==0.1.7 8 | grpcio==1.24.3 9 | h5py==2.10.0 10 | Keras-Applications==1.0.8 11 | Keras-Preprocessing==1.1.0 12 | Markdown==3.1.1 13 | numpy==1.17.3 14 | pkg-resources==0.0.0 15 | protobuf==3.10.0 16 | scipy==1.3.1 17 | six==1.12.0 18 | tensorboard==1.15.0 19 | tensorflow-estimator==1.15.1 20 | termcolor==1.1.0 21 | Werkzeug==0.16.0 22 | wrapt==1.11.2 23 | flask 24 | -------------------------------------------------------------------------------- /sari_hook.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """SARI score for evaluating paraphrasing and other text generation models. 17 | 18 | The score is introduced in the following paper: 19 | 20 | Optimizing Statistical Machine Translation for Text Simplification 21 | Wei Xu, Courtney Napoles, Ellie Pavlick, Quanze Chen and Chris Callison-Burch 22 | In Transactions of the Association for Computational Linguistics (TACL) 2015 23 | http://cs.jhu.edu/~napoles/res/tacl2016-optimizing.pdf 24 | 25 | This implementation has two differences with the GitHub [1] implementation: 26 | (1) Define 0/0=1 instead of 0 to give higher scores for predictions that match 27 | a target exactly. 28 | (2) Fix an alleged bug [2] in the deletion score computation. 29 | 30 | [1] https://github.com/cocoxu/simplification/blob/master/SARI.py 31 | (commit 0210f15) 32 | [2] https://github.com/cocoxu/simplification/issues/6 33 | """ 34 | 35 | from __future__ import absolute_import 36 | from __future__ import division 37 | from __future__ import print_function 38 | 39 | import collections 40 | 41 | import numpy as np 42 | import tensorflow as tf 43 | 44 | # The paper that intoduces the SARI score uses only the precision of the deleted 45 | # tokens (i.e. beta=0). To give more emphasis on recall, you may set, e.g., 46 | # beta=1. 47 | BETA_FOR_SARI_DELETION_F_MEASURE = 0 48 | 49 | 50 | def _get_ngram_counter(ids, n): 51 | """Get a Counter with the ngrams of the given ID list. 52 | 53 | Args: 54 | ids: np.array or a list corresponding to a single sentence 55 | n: n-gram size 56 | 57 | Returns: 58 | collections.Counter with ID tuples as keys and 1s as values. 
59 | """ 60 | # Remove zero IDs used to pad the sequence. 61 | ids = [token_id for token_id in ids if token_id != 0] 62 | ngram_list = [tuple(ids[i:i + n]) for i in range(len(ids) + 1 - n)] 63 | ngrams = set(ngram_list) 64 | counts = collections.Counter() 65 | for ngram in ngrams: 66 | counts[ngram] = 1 67 | return counts 68 | 69 | 70 | def _get_fbeta_score(true_positives, selected, relevant, beta=1): 71 | """Compute Fbeta score. 72 | 73 | Args: 74 | true_positives: Number of true positive ngrams. 75 | selected: Number of selected ngrams. 76 | relevant: Number of relevant ngrams. 77 | beta: 0 gives precision only, 1 gives F1 score, and Inf gives recall only. 78 | 79 | Returns: 80 | Fbeta score. 81 | """ 82 | precision = 1 83 | if selected > 0: 84 | precision = true_positives / selected 85 | if beta == 0: 86 | return precision 87 | recall = 1 88 | if relevant > 0: 89 | recall = true_positives / relevant 90 | if precision > 0 and recall > 0: 91 | beta2 = beta * beta 92 | return (1 + beta2) * precision * recall / (beta2 * precision + recall) 93 | else: 94 | return 0 95 | 96 | 97 | def get_addition_score(source_counts, prediction_counts, target_counts): 98 | """Compute the addition score (Equation 4 in the paper).""" 99 | added_to_prediction_counts = prediction_counts - source_counts 100 | true_positives = sum((added_to_prediction_counts & target_counts).values()) 101 | selected = sum(added_to_prediction_counts.values()) 102 | # Note that in the paper the summation is done over all the ngrams in the 103 | # output rather than the ngrams in the following set difference. Since the 104 | # former does not make as much sense we compute the latter, which is also done 105 | # in the GitHub implementation. 106 | relevant = sum((target_counts - source_counts).values()) 107 | return _get_fbeta_score(true_positives, selected, relevant) 108 | 109 | 110 | def get_keep_score(source_counts, prediction_counts, target_counts): 111 | """Compute the keep score (Equation 5 in the paper).""" 112 | source_and_prediction_counts = source_counts & prediction_counts 113 | source_and_target_counts = source_counts & target_counts 114 | true_positives = sum((source_and_prediction_counts & 115 | source_and_target_counts).values()) 116 | selected = sum(source_and_prediction_counts.values()) 117 | relevant = sum(source_and_target_counts.values()) 118 | return _get_fbeta_score(true_positives, selected, relevant) 119 | 120 | 121 | def get_deletion_score(source_counts, prediction_counts, target_counts, beta=0): 122 | """Compute the deletion score (Equation 6 in the paper).""" 123 | source_not_prediction_counts = source_counts - prediction_counts 124 | source_not_target_counts = source_counts - target_counts 125 | true_positives = sum((source_not_prediction_counts & 126 | source_not_target_counts).values()) 127 | selected = sum(source_not_prediction_counts.values()) 128 | relevant = sum(source_not_target_counts.values()) 129 | return _get_fbeta_score(true_positives, selected, relevant, beta=beta) 130 | 131 | 132 | def get_sari_score(source_ids, prediction_ids, list_of_targets, 133 | max_gram_size=4, beta_for_deletion=0): 134 | """Compute the SARI score for a single prediction and one or more targets. 135 | 136 | Args: 137 | source_ids: a list / np.array of SentencePiece IDs 138 | prediction_ids: a list / np.array of SentencePiece IDs 139 | list_of_targets: a list of target ID lists / np.arrays 140 | max_gram_size: int. largest n-gram size we care about (e.g. 
3 for unigrams, 141 | bigrams, and trigrams) 142 | beta_for_deletion: beta for deletion F score. 143 | 144 | Returns: 145 | the SARI score and its three components: add, keep, and deletion scores 146 | """ 147 | addition_scores = [] 148 | keep_scores = [] 149 | deletion_scores = [] 150 | for n in range(1, max_gram_size + 1): 151 | source_counts = _get_ngram_counter(source_ids, n) 152 | prediction_counts = _get_ngram_counter(prediction_ids, n) 153 | # All ngrams in the targets with count 1. 154 | target_counts = collections.Counter() 155 | # All ngrams in the targets with count r/num_targets, where r is the number 156 | # of targets where the ngram occurs. 157 | weighted_target_counts = collections.Counter() 158 | num_nonempty_targets = 0 159 | for target_ids_i in list_of_targets: 160 | target_counts_i = _get_ngram_counter(target_ids_i, n) 161 | if target_counts_i: 162 | weighted_target_counts += target_counts_i 163 | num_nonempty_targets += 1 164 | for gram in weighted_target_counts.keys(): 165 | weighted_target_counts[gram] /= num_nonempty_targets 166 | target_counts[gram] = 1 167 | keep_scores.append(get_keep_score(source_counts, prediction_counts, 168 | weighted_target_counts)) 169 | deletion_scores.append(get_deletion_score(source_counts, prediction_counts, 170 | weighted_target_counts, 171 | beta_for_deletion)) 172 | addition_scores.append(get_addition_score(source_counts, prediction_counts, 173 | target_counts)) 174 | 175 | avg_keep_score = sum(keep_scores) / max_gram_size 176 | avg_addition_score = sum(addition_scores) / max_gram_size 177 | avg_deletion_score = sum(deletion_scores) / max_gram_size 178 | sari = (avg_keep_score + avg_addition_score + avg_deletion_score) / 3.0 179 | return sari, avg_keep_score, avg_addition_score, avg_deletion_score 180 | 181 | 182 | def get_sari(source_ids, prediction_ids, target_ids, max_gram_size=4): 183 | """Computes the SARI scores from the given source, prediction and targets. 184 | 185 | Args: 186 | source_ids: A 2D tf.Tensor of size (batch_size , sequence_length) 187 | prediction_ids: A 2D tf.Tensor of size (batch_size, sequence_length) 188 | target_ids: A 3D tf.Tensor of size (batch_size, number_of_targets, 189 | sequence_length) 190 | max_gram_size: int. largest n-gram size we care about (e.g. 3 for unigrams, 191 | bigrams, and trigrams) 192 | 193 | Returns: 194 | A 4-tuple of 1D float Tensors of size (batch_size) for the SARI score and 195 | the keep, addition and deletion scores. 196 | """ 197 | 198 | def get_sari_numpy(source_ids, prediction_ids, target_ids): 199 | """Iterate over elements in the batch and call the SARI function.""" 200 | sari_scores = [] 201 | keep_scores = [] 202 | add_scores = [] 203 | deletion_scores = [] 204 | # Iterate over elements in the batch. 
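    # Each target_ids_i is itself a list of reference sequences for one batch
    # element (target_ids is batch_size x number_of_targets x sequence_length),
    # so get_sari_score evaluates the prediction against all references.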
205 | for source_ids_i, prediction_ids_i, target_ids_i in zip( 206 | source_ids, prediction_ids, target_ids): 207 | sari, keep, add, deletion = get_sari_score( 208 | source_ids_i, prediction_ids_i, target_ids_i, max_gram_size, 209 | BETA_FOR_SARI_DELETION_F_MEASURE) 210 | sari_scores.append(sari) 211 | keep_scores.append(keep) 212 | add_scores.append(add) 213 | deletion_scores.append(deletion) 214 | return (np.asarray(sari_scores), np.asarray(keep_scores), 215 | np.asarray(add_scores), np.asarray(deletion_scores)) 216 | 217 | sari, keep, add, deletion = tf.py_func( 218 | get_sari_numpy, 219 | [source_ids, prediction_ids, target_ids], 220 | [tf.float64, tf.float64, tf.float64, tf.float64]) 221 | return sari, keep, add, deletion 222 | 223 | 224 | def sari_score(predictions, labels, features, **unused_kwargs): 225 | """Computes the SARI scores from the given source, prediction and targets. 226 | 227 | An approximate SARI scoring method since we do not glue word pieces or 228 | decode the ids and tokenize the output. By default, we use ngram order of 4. 229 | Also, this does not have beam search. 230 | 231 | Args: 232 | predictions: tensor, model predictions. 233 | labels: tensor, gold output. 234 | features: dict, containing inputs. 235 | 236 | Returns: 237 | sari: int, approx sari score 238 | """ 239 | if "inputs" not in features: 240 | raise ValueError("sari_score requires inputs feature") 241 | 242 | # Convert the inputs and outputs to a [batch_size, sequence_length] tensor. 243 | inputs = tf.squeeze(features["inputs"], axis=[-1, -2]) 244 | outputs = tf.to_int32(tf.argmax(predictions, axis=-1)) 245 | outputs = tf.squeeze(outputs, axis=[-1, -2]) 246 | 247 | # Convert the labels to a [batch_size, 1, sequence_length] tensor. 248 | labels = tf.squeeze(labels, axis=[-1, -2]) 249 | labels = tf.expand_dims(labels, axis=1) 250 | 251 | score, _, _, _ = get_sari(inputs, outputs, labels) 252 | return score, tf.constant(1.0) 253 | -------------------------------------------------------------------------------- /score_lib.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Utility functions for computing evaluation metrics.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | 22 | from __future__ import print_function 23 | 24 | import re 25 | from nltk.translate import bleu_score 26 | import numpy as np 27 | import tensorflow as tf 28 | import sari_hook 29 | import utils 30 | from curLine_file import curLine 31 | 32 | def read_data( 33 | path, 34 | lowercase): 35 | """Reads data from prediction TSV file. 36 | 37 | The prediction file should contain 3 or more columns: 38 | 1: sources (concatenated) 39 | 2: prediction 40 | 3-n: targets (1 or more) 41 | 42 | Args: 43 | path: Path to the prediction file. 
44 | lowercase: Whether to lowercase the data (to compute case insensitive 45 | scores). 46 | 47 | Returns: 48 | Tuple (list of sources, list of predictions, list of target lists) 49 | """ 50 | sources = [] 51 | predictions = [] 52 | target_lists = [] 53 | with tf.gfile.GFile(path) as f: 54 | for line_id, line in enumerate(f): 55 | if line_id == 0: 56 | continue 57 | source, pred, *targets = line.rstrip('\n').split('\t') 58 | if lowercase: 59 | source = source.lower() 60 | pred = pred.lower() 61 | targets = [t.lower() for t in targets] 62 | sources.append(source) 63 | predictions.append(pred) 64 | target_lists.append(targets) 65 | return sources, predictions, target_lists 66 | 67 | 68 | def compute_exact_score(predictions, 69 | target_lists): 70 | """Computes the Exact score (accuracy) of the predictions. 71 | 72 | Exact score is defined as the percentage of predictions that match at least 73 | one of the targets. 74 | 75 | Args: 76 | predictions: List of predictions. 77 | target_lists: List of targets (1 or more per prediction). 78 | 79 | Returns: 80 | Exact score between [0, 1]. 81 | """ 82 | num_matches = sum(any(pred == target for target in targets) 83 | for pred, targets in zip(predictions, target_lists)) 84 | return num_matches / max(len(predictions), 0.1) # Avoids 0/0. 85 | 86 | 87 | def bleu(hyps, refs_list): 88 | """ 89 | calculate bleu1, bleu2, bleu3 90 | """ 91 | bleu_1 = [] 92 | bleu_2 = [] 93 | 94 | for hyp, refs in zip(hyps, refs_list): 95 | if len(hyp) <= 1: 96 | # print("ignore hyp:%s, refs:" % hyp, refs) 97 | bleu_1.append(0.0) 98 | bleu_2.append(0.0) 99 | continue 100 | 101 | score = bleu_score.sentence_bleu( 102 | refs, hyp, 103 | smoothing_function=None, # bleu_score.SmoothingFunction().method7, 104 | weights=[1, 0, 0, 0]) 105 | # input(curLine()) 106 | if score > 1.0: 107 | print(curLine(), refs, hyp) 108 | print(curLine(), "score=", score) 109 | input(curLine()) 110 | bleu_1.append(score) 111 | 112 | score = bleu_score.sentence_bleu( 113 | refs, hyp, 114 | smoothing_function=None, # bleu_score.SmoothingFunction().method7, 115 | weights=[0.5, 0.5, 0, 0]) 116 | bleu_2.append(score) 117 | bleu_1 = np.average(bleu_1) 118 | bleu_2 = np.average(bleu_2) 119 | bleu_average_score = (bleu_1 + bleu_2) * 0.5 120 | print("bleu_1=%f, bleu_2=%f, bleu_average_score=%f" % (bleu_1, bleu_2, bleu_average_score)) 121 | return bleu_average_score 122 | 123 | 124 | def compute_sari_scores( 125 | sources, 126 | predictions, 127 | target_lists, 128 | ignore_wikisplit_separators=True 129 | ): 130 | """Computes SARI scores. 131 | 132 | Wraps the t2t implementation of SARI computation. 133 | 134 | Args: 135 | sources: List of sources. 136 | predictions: List of predictions. 137 | target_lists: List of targets (1 or more per prediction). 138 | ignore_wikisplit_separators: Whether to ignore "<::::>" tokens, used as 139 | sentence separators in Wikisplit, when evaluating. For the numbers 140 | reported in the paper, we accidentally ignored those tokens. Ignoring them 141 | does not affect the Exact score (since there's usually always a period 142 | before the separator to indicate sentence break), but it decreases the 143 | SARI score (since the Addition score goes down as the model doesn't get 144 | points for correctly adding <::::> anymore). 145 | 146 | Returns: 147 | Tuple (SARI score, keep score, addition score, deletion score). 
148 | """ 149 | sari_sum = 0 150 | keep_sum = 0 151 | add_sum = 0 152 | del_sum = 0 153 | for source, pred, targets in zip(sources, predictions, target_lists): 154 | if ignore_wikisplit_separators: 155 | source = re.sub(' <::::> ', ' ', source) 156 | pred = re.sub(' <::::> ', ' ', pred) 157 | targets = [re.sub(' <::::> ', ' ', t) for t in targets] 158 | source_ids = list(source) # utils.get_token_list(source) 159 | pred_ids = list(pred) # utils.get_token_list(pred) 160 | list_of_targets = [list(t) for t in targets] 161 | sari, keep, addition, deletion = sari_hook.get_sari_score( 162 | source_ids, pred_ids, list_of_targets, beta_for_deletion=1) 163 | sari_sum += sari 164 | keep_sum += keep 165 | add_sum += addition 166 | del_sum += deletion 167 | n = max(len(sources), 0.1) # Avoids 0/0. 168 | return (sari_sum / n, keep_sum / n, add_sum / n, del_sum / n) 169 | -------------------------------------------------------------------------------- /score_main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Calculates evaluation scores for a prediction TSV file. 

The prediction file is produced by predict_main.py and should contain 3 or more
columns:
1: sources (concatenated)
2: prediction
3-n: targets (1 or more)
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags
from absl import logging

import score_lib

FLAGS = flags.FLAGS

flags.DEFINE_string(
    'prediction_file', None,
    'TSV file containing source, prediction, and target columns.')
flags.DEFINE_bool(
    'case_insensitive', True,
    'Whether score computation should be case insensitive (in the LaserTagger '
    'paper this was set to True).')


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  flags.mark_flag_as_required('prediction_file')

  sources, predictions, target_lists = score_lib.read_data(
      FLAGS.prediction_file, FLAGS.case_insensitive)
  logging.info(f'Read file: {FLAGS.prediction_file}')
  exact = score_lib.compute_exact_score(predictions, target_lists)
  bleu = score_lib.bleu(predictions, target_lists)
  sari, keep, addition, deletion = score_lib.compute_sari_scores(
      sources, predictions, target_lists)
  print("num=", len(predictions))
  print(f'Exact score: {100 * exact:.3f}')
  print(f'Bleu score: {100 * bleu:.3f}')
  print(f'SARI score: {100 * sari:.3f}')
  print(f' KEEP score: {100 * keep:.3f}')
  print(f' ADDITION score: {100 * addition:.3f}')
  print(f' DELETION score: {100 * deletion:.3f}')


if __name__ == '__main__':
  app.run(main)
--------------------------------------------------------------------------------
/sentence_fusion_task.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mleader2/text_scalpel/131922c5dcfc558a46a7590925e348c39cb24951/sentence_fusion_task.png
--------------------------------------------------------------------------------
/skill_rephrase/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mleader2/text_scalpel/131922c5dcfc558a46a7590925e348c39cb24951/skill_rephrase/__init__.py
--------------------------------------------------------------------------------
/skill_rephrase/predict_for_skill.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Augments task-oriented corpora via rephrasing, for intent and slot
# recognition.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags
from absl import logging
import math, json
import os, sys, time
from termcolor import colored
import tensorflow as tf

# Make the repository root importable when this script is run directly.
block_list = os.path.realpath(__file__).split("/")
path = "/".join(block_list[:-2])
sys.path.append(path)

import bert_example
import predict_utils
import tagging_converter
import utils

from curLine_file import curLine

FLAGS = flags.FLAGS

flags.DEFINE_string(
    'input_file', None,
    'Path to the input file containing examples for which to compute '
    'predictions.')
flags.DEFINE_enum(
    'input_format', None, ['wikisplit', 'discofuse'],
    'Format which indicates how to parse the input_file.')
flags.DEFINE_string(
    'output_file', None,
    'Path to the TSV file where the predictions are written to.')
flags.DEFINE_string(
    'label_map_file', None,
    'Path to the label map file. Either a JSON file ending with ".json", that '
    'maps each possible tag to an ID, or a text file that has one tag per '
    'line.')
flags.DEFINE_string('vocab_file', None, 'Path to the BERT vocabulary file.')
flags.DEFINE_integer('max_seq_length', 128, 'Maximum sequence length.')
flags.DEFINE_bool(
    'do_lower_case', False,
    'Whether to lower case the input text. Should be True for uncased '
    'models and False for cased models.')
flags.DEFINE_bool('enable_swap_tag', True, 'Whether to enable the SWAP tag.')
flags.DEFINE_string('saved_model', None, 'Path to an exported TF model.')


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  flags.mark_flag_as_required('input_file')
  flags.mark_flag_as_required('input_format')
  flags.mark_flag_as_required('output_file')
  flags.mark_flag_as_required('label_map_file')
  flags.mark_flag_as_required('vocab_file')
  flags.mark_flag_as_required('saved_model')

  label_map = utils.read_label_map(FLAGS.label_map_file)
  converter = tagging_converter.TaggingConverter(
      tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
      FLAGS.enable_swap_tag)
  builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                            FLAGS.max_seq_length,
                                            FLAGS.do_lower_case, converter)
  predictor = predict_utils.LaserTaggerPredictor(
      tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
      label_map)
  print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
  num_predicted = 0

  sources_list = []
  location_list = []
  corpus_id_list = []
  entity_list = []
  domainname_list = []
  intentname_list = []
  context_list = []
  template_id_list = []
  with open(FLAGS.input_file, "r") as f:
    corpus_json_list = json.load(f)
    # corpus_json_list = corpus_json_list[:100]
    for corpus_json in corpus_json_list:
      sources_list.append([corpus_json["oriText"]])
      location_list.append(corpus_json["location"])
      corpus_id_list.append(corpus_json["corpus_id"])
      entity_list.append(corpus_json["entity"])
      domainname_list.append(corpus_json["domainname"])
      intentname_list.append(corpus_json["intentname"])
      context_list.append(corpus_json["context"])
      template_id_list.append(corpus_json["template_id"])
  number = len(sources_list)  # total number of examples
  predict_batch_size = min(64, number)
  batch_num = math.ceil(float(number) / predict_batch_size)
  start_time = time.time()
  index = 0
  for batch_id in range(batch_num):
    sources_batch = sources_list[batch_id * predict_batch_size: (batch_id + 1) * predict_batch_size]
    location_batch = location_list[batch_id * predict_batch_size: (batch_id + 1) * predict_batch_size]
    prediction_batch = predictor.predict_batch(sources_batch=sources_batch, location_batch=location_batch)
    assert len(prediction_batch) == len(sources_batch)
    num_predicted += len(prediction_batch)
    for id, [prediction, sources] in enumerate(zip(prediction_batch, sources_batch)):
      index = batch_id * predict_batch_size + id
      output_json = {"corpus_id": corpus_id_list[index], "oriText": prediction, "sources": sources[0],
                     "entity": entity_list[index], "location": location_list[index],
                     "domainname": domainname_list[index], "intentname": intentname_list[index],
                     "context": context_list[index], "template_id": template_id_list[index]}
      corpus_json_list[index] = output_json
    if batch_id % 20 == 0:
      cost_time = (time.time() - start_time) / 60.0
      print("%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin." %
            (curLine(), batch_id + 1, batch_num, num_predicted, number, cost_time))
  assert len(corpus_json_list) == index + 1
  with open(FLAGS.output_file, 'w', encoding='utf-8') as writer:
    json.dump(corpus_json_list, writer, ensure_ascii=False, indent=4)
  cost_time = (time.time() - start_time) / 60.0
  logging.info(
      f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {cost_time / num_predicted} min.')


if __name__ == '__main__':
  app.run(main)
--------------------------------------------------------------------------------
/skill_rephrase/score_for_skill.txt:
--------------------------------------------------------------------------------
# Before augmenting the "times" corpus
[trainer.py:182] f1_score_slot= 99.93185946253503 f1_score_intent= 99.05839393642326 f1_score= 99.78628187484973
[trainer.py:199] The best dev: epoch=3, score=99.99
[trainer.py:200] The corresponding test: epoch=3, score=100.00
[trainer.py:201] Finished archiving the models, start final testing.
[trainer.py:203] model_file: /home/cloudminds/Mywork/corpus/ner_corpus/times_model/lstm_crf.m
[trainer.py:253] cost_time=0.031036s, ave_time=0.341051ms
[trainer.py:263] type:number P:100.00%, R:100.00%, F1:100.000%
[trainer.py:263] type:year P:100.00%, R:100.00%, F1:100.000%
[trainer.py:263] type:date P:100.00%, R:100.00%, F1:100.000%
[trainer.py:289] test_slot P:100.00, R:100.00, F1:100.000

[trainer.py:302] f1_weighted_intent=91.784692, f1_macro_intent=93.101826
[trainer.py:303] f1_score_intent= 92.0481191049753 classification_report_intent:
              precision    recall  f1-score   support

           0      1.000     1.000     1.000         4
           1      1.000     0.571     0.727        14
           2      1.000     1.000     1.000        18
           3      0.962     0.962     0.962        26
           4      0.773     1.000     0.872        17
           5      1.000     1.000     1.000         1
           6      0.917     1.000     0.957        11

    accuracy                          0.923        91
   macro avg      0.950     0.933     0.931        91
weighted avg      0.936     0.923     0.918        91

# After augmenting the "times" corpus: the scores drop slightly; perhaps
# because these scenarios were not specifically targeted.
[trainer.py:182] f1_score_slot= 99.87074867480472 f1_score_intent= 98.7267364507435 f1_score= 99.68007997079451
[trainer.py:199] The best dev: epoch=8, score=99.92
[trainer.py:200] The corresponding test: epoch=8, score=100.00
[trainer.py:201] Finished archiving the models, start final testing.
[trainer.py:203] model_file: /home/cloudminds/Mywork/corpus/ner_corpus/times_model/lstm_crf.m
[trainer.py:253] cost_time=0.029509s, ave_time=0.324270ms
[trainer.py:263] type:number P:100.00%, R:100.00%, F1:100.000%
[trainer.py:263] type:year P:100.00%, R:100.00%, F1:100.000%
[trainer.py:263] type:date P:100.00%, R:100.00%, F1:100.000%
[trainer.py:289] test_slot P:100.00, R:100.00, F1:100.000

[trainer.py:302] f1_weighted_intent=89.032198, f1_macro_intent=90.676121
[trainer.py:303] f1_score_intent= 89.36098260170371 classification_report_intent:
              precision    recall  f1-score   support

           0      1.000     1.000     1.000         4
           1      1.000     0.429     0.600        14
           2      1.000     1.000     1.000        18
           3      0.962     0.962     0.962        26
           4      0.708     1.000     0.829        17
           5      1.000     1.000     1.000         1
           6      0.917     1.000     0.957        11

    accuracy                          0.901        91
   macro avg      0.941     0.913     0.907        91
weighted avg      0.924     0.901     0.890        91
--------------------------------------------------------------------------------
/tagging.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Classes representing a tag and a text editing task.

Tag corresponds to an edit operation, while EditingTask is a container for the
input that LaserTagger takes. EditingTask also has a method for realizing the
output text given the predicted tags.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from enum import Enum
from curLine_file import curLine

import utils


class TagType(Enum):
  """Base tag which indicates the type of an edit operation."""
  # Keep the tagged token.
  KEEP = 1
  # Delete the tagged token.
  DELETE = 2
  # Keep the tagged token but swap the order of sentences. This tag is only
  # applied if there are two source texts and the tag is applied to the last
  # token of the first source. In other contexts, it's treated as KEEP.
  SWAP = 3


class Tag(object):
  """Tag that corresponds to a token edit operation.

  Attributes:
    tag_type: TagType of the tag.
    added_phrase: A phrase that's inserted before the tagged token (can be
      empty).
  """

  def __init__(self, tag):
    """Constructs a Tag object by parsing tag to tag_type and added_phrase.

    Args:
      tag: String representation for the tag which should have the following
        format "<tag_type>|<added_phrase>" or simply "<tag_type>" if no phrase
        is added before the tagged token. Examples of valid tags include "KEEP",
        "DELETE|and", and "SWAP|.".

    Raises:
      ValueError: If <tag_type> is invalid.
66 | """ 67 | if '|' in tag: 68 | pos_pipe = tag.index('|') # 也可以直接split,再把[1:]拼接成一个字符串吧 69 | tag_type, added_phrase = tag[:pos_pipe], tag[pos_pipe + 1:] 70 | else: 71 | tag_type, added_phrase = tag, '' 72 | try: 73 | self.tag_type = TagType[tag_type] # for example: tag_type:KEEP self.tag_type:TagType.KEEP 74 | except KeyError: 75 | raise ValueError( 76 | 'TagType should be KEEP, DELETE or SWAP, not {}'.format(tag_type)) 77 | self.added_phrase = added_phrase 78 | 79 | def __str__(self): 80 | if not self.added_phrase: 81 | return self.tag_type.name 82 | else: 83 | return '{}|{}'.format(self.tag_type.name, self.added_phrase) 84 | 85 | 86 | class EditingTask(object): 87 | """Text-editing task. 88 | 89 | Attributes: 90 | sources: Source texts. 91 | source_tokens: Tokens of the source texts concatenated into a single list. 92 | first_tokens: The indices of the first tokens of each source text. 93 | """ 94 | 95 | def __init__(self, sources, location=None, tokenizer=None): 96 | """Initializes an instance of EditingTask. 97 | 98 | Args: 99 | sources: A list of source strings. Typically contains only one string but 100 | for sentence fusion it contains two strings to be fused (whose order may 101 | be swapped). 102 | location: None或字符串, 0表示能变,1表示不能变 103 | """ 104 | self.sep = '' # for Chinses 105 | self.sources = sources 106 | source_token_lists = [tokenizer.tokenize(text) for text in self.sources] 107 | # Tokens of the source texts concatenated into a single list. 108 | self.source_tokens = [] 109 | # The indices of the first tokens of each source text. 110 | self.first_tokens = [] 111 | for token_list in source_token_lists: 112 | self.first_tokens.append(len(self.source_tokens)) 113 | self.source_tokens.extend(token_list) 114 | self.location = location 115 | 116 | self.token_index_map = {} # 为处理生成的UNK 117 | previous_id = 0 118 | for tokenizer_id, t in enumerate(self.source_tokens): 119 | if tokenizer_id > 0 and "UNK" in self.source_tokens[tokenizer_id - 1]: 120 | if t in self.source_tokens[previous_id:]: 121 | previous_id = previous_id + self.source_tokens[previous_id:].index(t) 122 | else: # 出现连续的UNK情况,目前的做法是假设长度为1 123 | previous_id += 1 124 | self.token_index_map[tokenizer_id] = previous_id 125 | if "UNK" not in t: 126 | length_t = len(t) 127 | if t.startswith("##", 0, 2): 128 | length_t -= 2 129 | previous_id += length_t 130 | 131 | def _realize_sequence(self, tokens, tags): 132 | """Realizes output text corresponding to a single source text. 133 | 134 | Args: 135 | tokens: Tokens of the source text. 136 | tags: Tags indicating the edit operations. 137 | 138 | Returns: 139 | The realized text. 
140 | """ 141 | output_tokens = [] 142 | for index, (token, tag) in enumerate(zip(tokens, tags)): 143 | loc = "0" 144 | if self.location is not None: 145 | loc = self.location[index] 146 | if tag.added_phrase and ( 147 | loc == "0" or index == 0 or (index > 0 and self.location[index - 1] == "0")): # TODO 148 | if not tag.added_phrase.startswith("##", 0, 2): 149 | output_tokens.append(tag.added_phrase) 150 | else: # word piece 151 | if len(output_tokens) > 0: 152 | output_tokens[-1] += tag.added_phrase[2:] 153 | else: 154 | output_tokens.append(tag.added_phrase[2:]) 155 | if tag.tag_type in ( 156 | TagType.KEEP, TagType.SWAP) or loc == "1": # TODO 根据需要修改代码,location为"1"的位置不能被删除, 但目前是可以插入的 157 | token = token.upper() # TODO 因为当前语料中有不少都是大写的,所以把预测结果都转为大写 158 | if token.startswith("##", 0, 2): 159 | output_tokens.append(token[2:]) 160 | elif "UNK" in token: # 处理UNK的情况 161 | previoud_id = self.token_index_map[index] # unk对应word开始的位置 162 | next_previoud_id = previoud_id + 1 # unk对应word结束的位置 163 | if index + 1 in self.token_index_map: 164 | next_previoud_id = self.token_index_map[index + 1] 165 | token = self.sources[0][previoud_id:next_previoud_id] # TODO 166 | print(curLine(), "self.passage[%d,%d]=%s" % (previoud_id, next_previoud_id, token)) 167 | output_tokens.append(token) 168 | else: # word piece 169 | output_tokens.append(token) 170 | return self.sep.join(output_tokens) 171 | 172 | def _first_char_to_upper(self, text): 173 | """Upcases the first character of the text.""" 174 | try: 175 | return text[0].upper() + text[1:] 176 | except IndexError: 177 | return text 178 | 179 | def _first_char_to_lower(self, text): 180 | """Lowcases the first character of the text.""" 181 | try: 182 | return text[0].lower() + text[1:] 183 | except IndexError: 184 | return text 185 | 186 | def realize_output(self, tags): 187 | """Realize output text based on the source tokens and predicted tags. 188 | 189 | Args: 190 | tags: Predicted tags (one for each token in `self.source_tokens`). 191 | 192 | Returns: 193 | The realizer output text. 194 | 195 | Raises: 196 | ValueError: If the number of tags doesn't match the number of source 197 | tokens. 198 | """ 199 | if len(tags) != len(self.source_tokens): 200 | raise ValueError('The number of tags ({}) should match the number of ' 201 | 'source tokens ({})'.format( 202 | len(tags), len(self.source_tokens))) 203 | outputs = [] # Realized sources that are joined into the output text. 204 | if (len(self.first_tokens) == 2 and 205 | tags[self.first_tokens[1] - 1].tag_type == TagType.SWAP): 206 | order = [1, 0] 207 | else: 208 | order = range(len(self.first_tokens)) 209 | for source_idx in order: 210 | # Get the span of tokens for the source: [first_token, last_token). 211 | first_token = self.first_tokens[source_idx] 212 | if source_idx + 1 < len(self.first_tokens): 213 | last_token = self.first_tokens[source_idx + 1] # Not inclusive. 214 | else: 215 | last_token = len(self.source_tokens) 216 | # Realize the source and fix casing. 217 | realized_source = self._realize_sequence( 218 | self.source_tokens[first_token:last_token], 219 | tags[first_token:last_token]) 220 | if outputs: 221 | if len(outputs[0][-1:]) > 0 and outputs[0][-1:] in '.!?': 222 | realized_source = self._first_char_to_upper(realized_source) # 变大写 223 | else: 224 | # Note that ideally we should also test here whether the first word is 225 | # a proper noun or an abbreviation that should always be capitalized. 
          realized_source = self._first_char_to_lower(realized_source)  # lowercase the first character
      # print(curLine(), len(outputs[0][-1:]), "outputs[0][-1:]:", outputs[0][-1:], "source_idx=",source_idx, ",realized_source:", realized_source)
      outputs.append(realized_source)
    prediction = self.sep.join(outputs)
    return prediction
--------------------------------------------------------------------------------
/tagging_converter.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Conversion from training target text into target tags.

The conversion algorithm from (source, target) pairs to (source, target_tags)
pairs is described in Algorithm 1 of the LaserTagger paper
(https://arxiv.org/abs/1909.01187).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tagging
import utils

from typing import Iterable, Mapping, Sequence, Set, Text, Tuple


class TaggingConverter(object):
  """Converter from training target texts into tagging format."""

  def __init__(self, phrase_vocabulary, do_swap=True, tokenizer=None):
    """Initializes an instance of TaggingConverter.

    Args:
      phrase_vocabulary: Iterable of phrase vocabulary items (strings).
      do_swap: Whether to enable the SWAP tag.
    """
    self._phrase_vocabulary = set(
        phrase.lower() for phrase in phrase_vocabulary)
    self._do_swap = do_swap
    # Maximum number of tokens in an added phrase (inferred from the
    # vocabulary).
    self._max_added_phrase_length = 0
    # Set of tokens that are part of a phrase in self.phrase_vocabulary.
    self._token_vocabulary = set()  # the set of word pieces
    for phrase in self._phrase_vocabulary:
      tokens = tokenizer.tokenize(phrase)
      self._token_vocabulary |= set(tokens)  # set union
      if len(tokens) > self._max_added_phrase_length:
        self._max_added_phrase_length = len(tokens)

  def compute_tags(self, task, target, tokenizer):
    """Computes tags needed for converting the source into the target.

    Args:
      task: tagging.EditingTask that specifies the input.
      target: Target text.

    Returns:
      List of tagging.Tag objects. If the source couldn't be converted into the
      target via tagging, returns an empty list.
    """
    target_tokens = tokenizer.tokenize(target.lower())
    tags = self._compute_tags_fixed_order(task.source_tokens, target_tokens)
    # If conversion fails, try to obtain the target after swapping the source
    # order.
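    # For example, sources ["B.", "A."] with target "A. B." only convert after
    # swapping; the SWAP tag placed on the last token of the first source then
    # encodes the reordering that realize_output applies.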
    if not tags and len(task.sources) == 2 and self._do_swap:
      swapped_task = tagging.EditingTask(task.sources[::-1])
      tags = self._compute_tags_fixed_order(swapped_task.source_tokens,
                                            target_tokens)
      if tags:
        tags = (tags[swapped_task.first_tokens[1]:] +
                tags[:swapped_task.first_tokens[1]])
        # We assume that the last token (typically a period) is never deleted,
        # so we can overwrite the tag_type with SWAP (which keeps the token,
        # moving it and the sentence it's part of to the end).
        tags[task.first_tokens[1] - 1].tag_type = tagging.TagType.SWAP
    return tags

  def _compute_tags_fixed_order(self, source_tokens, target_tokens):
    """Computes tags when the order of sources is fixed.

    Args:
      source_tokens: List of source tokens.
      target_tokens: List of tokens to be obtained via edit operations.

    Returns:
      List of tagging.Tag objects. If the source couldn't be converted into the
      target via tagging, returns an empty list.
    """
    tags = [tagging.Tag('DELETE') for _ in source_tokens]
    # Indices of the tokens currently being processed.
    source_token_idx = 0
    target_token_idx = 0
    while target_token_idx < len(target_tokens):
      tags[source_token_idx], target_token_idx = self._compute_single_tag(
          source_tokens[source_token_idx], target_token_idx, target_tokens)
      # TODO there can be multiple taggings that convert the source into the
      # target; for now we commit to a single one.
      # If we're adding a phrase and the previous source token(s) were deleted,
      # we could add the phrase before a previously deleted token and still get
      # the same realized output. For example:
      #   [DELETE, DELETE, KEEP|"what is"]
      # and
      #   [DELETE|"what is", DELETE, KEEP]
      # would yield the same realized output. Experimentally, we noticed that
      # the model works better / the learning task becomes easier when phrases
      # are always added before the first deleted token. Also note that in the
      # current implementation, this way of moving the added phrase backward is
      # the only way a DELETE tag can have an added phrase, so sequences like
      # [DELETE|"What", DELETE|"is"] will never be created.
      if tags[source_token_idx].added_phrase:
        # The learning task becomes easier when phrases are always added before
        # the first deleted token.
        first_deletion_idx = self._find_first_deletion_idx(
            source_token_idx, tags)
        if first_deletion_idx != source_token_idx:
          tags[first_deletion_idx].added_phrase = (
              tags[source_token_idx].added_phrase)
          tags[source_token_idx].added_phrase = ''
      source_token_idx += 1
      if source_token_idx >= len(tags):
        break

    # If all target tokens have been consumed, we have found a conversion and
    # can return the tags. Note that if there are remaining source tokens, they
    # are already marked deleted when initializing the tag list.
    if target_token_idx >= len(target_tokens):  # all target tokens have been consumed
      return tags
    return []  # TODO the source cannot be converted into the target

  def _compute_single_tag(
      self, source_token, target_token_idx,
      target_tokens):
    """Computes a single tag.

    The tag may match multiple target tokens (via tag.added_phrase) so we return
    the next unmatched target token.

    Args:
      source_token: The token to be tagged.
      target_token_idx: Index of the current target tag.
      target_tokens: List of all target tokens.

    Returns:
      A tuple with (1) the computed tag and (2) the next target_token_idx.
    """
    source_token = source_token.lower()
    target_token = target_tokens[target_token_idx].lower()
    if source_token == target_token:
      return tagging.Tag('KEEP'), target_token_idx + 1
    # Handle the case where source_token != target_token.
    added_phrase = ''
    for num_added_tokens in range(1, self._max_added_phrase_length + 1):
      if target_token not in self._token_vocabulary:
        break
      added_phrase += (' ' if added_phrase else '') + target_token
      next_target_token_idx = target_token_idx + num_added_tokens
      if next_target_token_idx >= len(target_tokens):  # the conversion is complete
        break
      target_token = target_tokens[next_target_token_idx].lower()
      if (source_token == target_token and
          added_phrase in self._phrase_vocabulary):
        return tagging.Tag('KEEP|' + added_phrase), next_target_token_idx + 1
    return tagging.Tag('DELETE'), target_token_idx

  def _find_first_deletion_idx(self, source_token_idx, tags):
    """Finds the start index of a span of deleted tokens.

    If `source_token_idx` is preceded by a span of deleted tokens, finds the
    start index of the span. Otherwise, returns `source_token_idx`.

    Args:
      source_token_idx: Index of the current source token.
      tags: List of tags.

    Returns:
      The index of the first deleted token preceding `source_token_idx` or
      `source_token_idx` if there are no deleted tokens right before it.
    """
    # Backtrack until the beginning of the tag sequence.
    for idx in range(source_token_idx, 0, -1):
      if tags[idx - 1].tag_type != tagging.TagType.DELETE:
        return idx
    return 0


def get_phrase_vocabulary_from_label_map(
    label_map):
  """Extract the set of all phrases from label map.

  Args:
    label_map: Mapping from tags to tag IDs.

  Returns:
    Set of all phrases appearing in the label map.
  """
  phrase_vocabulary = set()
  for label in label_map.keys():
    tag = tagging.Tag(label)
    if tag.added_phrase:
      phrase_vocabulary.add(tag.added_phrase)
  return phrase_vocabulary
--------------------------------------------------------------------------------
/transformer_decoder.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Transformer decoder."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from typing import Any, Mapping, Text

import tensorflow as tf

from official_transformer import attention_layer
from official_transformer import embedding_layer
from official_transformer import ffn_layer
from official_transformer import model_utils
from official_transformer import transformer


class TransformerDecoder(transformer.Transformer):
  """Transformer decoder.

  Attributes:
    train: Whether the model is in training mode.
    params: Model hyperparameters.
  """

  def __init__(self, params, train):
    """Initializes layers to build Transformer model.

    Args:
      params: hyperparameter object defining layer sizes, dropout values, etc.
      train: boolean indicating whether the model is in training mode. Used to
        determine if dropout layers should be added.
    """
    self.train = train
    self.params = params
    self.embedding_softmax_layer = embedding_layer.EmbeddingSharedWeights(
        params["vocab_size"], params["hidden_size"],
        method="matmul" if params["use_tpu"] else "gather")
    # Override self.decoder_stack.
    if self.params["use_full_attention"]:
      self.decoder_stack = transformer.DecoderStack(params, train)
    else:
      self.decoder_stack = DecoderStack(params, train)

  def __call__(self, inputs, encoder_outputs, targets=None):
    """Calculates target logits or inferred target sequences.

    Args:
      inputs: int tensor with shape [batch_size, input_length].
      encoder_outputs: int tensor with shape
        [batch_size, input_length, hidden_size]
      targets: None or int tensor with shape [batch_size, target_length].

    Returns:
      If targets is defined, then return logits for each word in the target
      sequence. float tensor with shape [batch_size, target_length, vocab_size]
      If target is none, then generate output sequence one token at a time.
        returns a dictionary {
          output: [batch_size, decoded length]
          score: [batch_size, float]}
    """
    # Variance scaling is used here because it seems to work in many problems.
    # Other reasonable initializers may also work just as well.
    initializer = tf.variance_scaling_initializer(
        self.params["initializer_gain"], mode="fan_avg", distribution="uniform")
    with tf.variable_scope("Transformer", initializer=initializer):
      # Calculate attention bias for encoder self-attention and decoder
      # multi-headed attention layers.
      attention_bias = model_utils.get_padding_bias(inputs)

      # Generate output sequence if targets is None, or return logits if target
      # sequence is known.
      if targets is None:
        return self.predict(encoder_outputs, attention_bias)
      else:
        logits = self.decode(targets, encoder_outputs, attention_bias)
        return logits

  def _get_symbols_to_logits_fn(self, max_decode_length):
    """Returns a decoding function that calculates logits of the next tokens."""

    timing_signal = model_utils.get_position_encoding(
        max_decode_length + 1, self.params["hidden_size"])
    decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
        max_decode_length)

    def symbols_to_logits_fn(ids, i, cache):
      """Generate logits for next potential IDs.

      Args:
        ids: Current decoded sequences.
          int tensor with shape [batch_size * beam_size, i + 1]
        i: Loop index
        cache: dictionary of values storing the encoder output, encoder-decoder
          attention bias, and previous decoder attention values.

      Returns:
        Tuple of
          (logits with shape [batch_size * beam_size, vocab_size],
           updated cache values)
      """
      # Set decoder input to the last generated IDs
      decoder_input = ids[:, -1:]

      # Preprocess decoder input by getting embeddings and adding timing signal.
      decoder_input = self.embedding_softmax_layer(decoder_input)
      decoder_input += timing_signal[i:i + 1]

      self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
      if self.params["use_full_attention"]:
        encoder_outputs = cache.get("encoder_outputs")
      else:
        encoder_outputs = cache.get("encoder_outputs")[:, i:i + 1]
      decoder_outputs = self.decoder_stack(
          decoder_input, encoder_outputs, self_attention_bias,
          cache.get("encoder_decoder_attention_bias"), cache)
      logits = self.embedding_softmax_layer.linear(decoder_outputs)
      logits = tf.squeeze(logits, axis=[1])
      return logits, cache

    return symbols_to_logits_fn


class DecoderStack(tf.layers.Layer):
  """Modified Transformer decoder stack.

  Like the standard Transformer decoder stack but:
  1. Removes the encoder-decoder attention layer, and
  2. Adds a layer to project the concatenated [encoder activations, hidden
     state] to the hidden size.
  """

  def __init__(self, params, train):
    super(DecoderStack, self).__init__()
    self.layers = []
    for _ in range(params["num_hidden_layers"]):
      self_attention_layer = attention_layer.SelfAttention(
          params["hidden_size"], params["num_heads"],
          params["attention_dropout"], train)
      feed_forward_network = ffn_layer.FeedFowardNetwork(  # NOTYPO
          params["hidden_size"], params["filter_size"],
          params["relu_dropout"], train, params["allow_ffn_pad"])

      # TODO an extra dense layer that projects the concatenated
      # [encoder activations, hidden state] back to the hidden size.
      proj_layer = tf.layers.Dense(
          params["hidden_size"], use_bias=True, name="proj_layer")

      self.layers.append([
          transformer.PrePostProcessingWrapper(
              self_attention_layer, params, train),
          transformer.PrePostProcessingWrapper(
              feed_forward_network, params, train),
          proj_layer])

    self.output_normalization = transformer.LayerNormalization(
        params["hidden_size"])

  def call(self, decoder_inputs, encoder_outputs, decoder_self_attention_bias,
           attention_bias=None, cache=None):
    """Returns the output of the decoder layer stacks.

    Args:
      decoder_inputs: tensor with shape [batch_size, target_length, hidden_size]
      encoder_outputs: tensor with shape [batch_size, input_length, hidden_size]
      decoder_self_attention_bias: bias for decoder self-attention layer.
        [1, 1, target_len, target_length]
      attention_bias: bias for encoder-decoder attention layer.
        [batch_size, 1, 1, input_length]
      cache: (Used for fast decoding) A nested dictionary storing previous
        decoder self-attention values. The items are:
          {layer_n: {"k": tensor with shape [batch_size, i, key_channels],
                     "v": tensor with shape [batch_size, i, value_channels]},
           ...}

    Returns:
      Output of decoder layer stack.
      float32 tensor with shape [batch_size, target_length, hidden_size]
    """
    for n, layer in enumerate(self.layers):
      self_attention_layer = layer[0]
      feed_forward_network = layer[1]
      proj_layer = layer[2]

      decoder_inputs = tf.concat([decoder_inputs, encoder_outputs], axis=-1)
      decoder_inputs = proj_layer(decoder_inputs)

      # Run inputs through the sublayers.
      layer_name = "layer_%d" % n
      layer_cache = cache[layer_name] if cache is not None else None
      with tf.variable_scope(layer_name):
        with tf.variable_scope("self_attention"):
          decoder_inputs = self_attention_layer(
              decoder_inputs, decoder_self_attention_bias, cache=layer_cache)
        with tf.variable_scope("ffn"):
          decoder_inputs = feed_forward_network(decoder_inputs)

    return self.output_normalization(decoder_inputs)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Utility functions for LaserTagger."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
from bert import tokenization
import tensorflow as tf


### Word-piece tokenizer for Chinese that preserves spaces.
class my_tokenizer_class(object):
  def __init__(self, vocab_file, do_lower_case):
    self.full_tokenizer = tokenization.FullTokenizer(vocab_file, do_lower_case=do_lower_case)

  # A wrapper is needed because calling full_tokenizer.tokenize directly on
  # Chinese text would drop the spaces in the text.
  def tokenize(self, text):
    segments = text.split(" ")
    word_pieces = []
    for segId, segment in enumerate(segments):
      if segId > 0:
        word_pieces.append(" ")
      word_pieces.extend(self.full_tokenizer.tokenize(segment))
    return word_pieces

  def convert_tokens_to_ids(self, tokens):
    # Spaces are mapped to the reserved [unused20] vocabulary entry.
    id_list = [self.full_tokenizer.vocab[t]
               if t != " " else self.full_tokenizer.vocab["[unused20]"] for t in tokens]
    return id_list


def yield_sources_and_targets(
    input_file,
    input_format):
  """Reads and yields source lists and targets from the input file.

  Args:
    input_file: Path to the input file.
    input_format: Format of the input file.

  Yields:
    Tuple with (list of source texts, target text).
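
    For the 'wikisplit' format used in this repo, a TSV line of the form
    "source<TAB>target<TAB>lcs_rate" yields (['source'], 'target'); the
    lcs_rate column is read but ignored.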
60 | """ 61 | if input_format == 'wikisplit': 62 | yield_example_fn = _yield_wikisplit_examples 63 | elif input_format == 'discofuse': 64 | yield_example_fn = _yield_discofuse_examples 65 | else: 66 | raise ValueError('Unsupported input_format: {}'.format(input_format)) 67 | 68 | for sources, target in yield_example_fn(input_file): 69 | yield sources, target 70 | 71 | 72 | def _yield_wikisplit_examples( 73 | input_file): 74 | # The Wikisplit format expects a TSV file with the source on the first and the 75 | # target on the second column. 76 | with tf.gfile.GFile(input_file) as f: 77 | for line in f: 78 | source, target, lcs_rate = line.rstrip('\n').split('\t') 79 | yield [source], target 80 | 81 | 82 | def _yield_discofuse_examples( 83 | input_file): 84 | """Yields DiscoFuse examples. 85 | 86 | The documentation for this format: 87 | https://github.com/google-research-datasets/discofuse#data-format 88 | 89 | Args: 90 | input_file: Path to the input file. 91 | """ 92 | with tf.gfile.GFile(input_file) as f: 93 | for i, line in enumerate(f): 94 | if i == 0: # Skip the header line. 95 | continue 96 | coherent_1, coherent_2, incoherent_1, incoherent_2, _, _, _, _ = ( 97 | line.rstrip('\n').split('\t')) 98 | # Strip because the second coherent sentence might be empty. 99 | fusion = (coherent_1 + ' ' + coherent_2).strip() 100 | yield [incoherent_1, incoherent_2], fusion 101 | 102 | 103 | def read_label_map(path): 104 | """Returns label map read from the given path.""" 105 | with tf.gfile.GFile(path) as f: 106 | if path.endswith('.json'): 107 | return json.load(f) 108 | else: 109 | label_map = {} 110 | empty_line_encountered = False 111 | for tag in f: 112 | tag = tag.strip() 113 | if tag: 114 | label_map[tag] = len(label_map) 115 | else: 116 | if empty_line_encountered: 117 | raise ValueError( 118 | 'There should be no empty lines in the middle of the label map ' 119 | 'file.' 120 | ) 121 | empty_line_encountered = True 122 | return label_map 123 | --------------------------------------------------------------------------------