├── .gitignore
├── AR_architecture.png
├── LICENSE
├── README.md
├── __init__.py
├── bert_example.py
├── chat_rephrase
│   ├── __init__.py
│   ├── predict_for_chat.py
│   └── score_for_qa.txt
├── compute_lcs.py
├── configs
│   └── lasertagger_config.json
├── curLine_file.py
├── domain_rephrase
│   ├── __init__.py
│   ├── predict_for_domain.py
│   └── rephrase_for_domain.sh
├── get_pairs_chinese
│   ├── __init__.py
│   ├── curLine_file.py
│   ├── get_text_pair_lcqmc.py
│   ├── get_text_pair_shixi.py
│   ├── get_text_pair_sv.py
│   └── merge_split_corpus.py
├── official_transformer
│   ├── README
│   ├── __init__.py
│   ├── attention_layer.py
│   ├── beam_search.py
│   ├── embedding_layer.py
│   ├── ffn_layer.py
│   ├── model_params.py
│   ├── model_utils.py
│   ├── tpu.py
│   └── transformer.py
├── phrase_vocabulary_optimization.py
├── predict_main.py
├── predict_utils.py
├── prediction.txt
├── preprocess_main.py
├── qa_rephrase
│   ├── __init__.py
│   ├── predict_for_qa.py
│   └── score_for_qa.txt
├── rephrase.sh
├── rephrase_for_chat.sh
├── rephrase_for_qa.sh
├── rephrase_for_skill.sh
├── rephrase_server.sh
├── rephrase_server
│   ├── __init__.py
│   ├── rephrase_server_flask.py
│   └── test_server.py
├── requirements.txt
├── run_lasertagger.py
├── run_lasertagger_utils.py
├── sari_hook.py
├── score_lib.py
├── score_main.py
├── sentence_fusion_task.png
├── skill_rephrase
│   ├── __init__.py
│   ├── predict_for_skill.py
│   └── score_for_skill.txt
├── tagging.py
├── tagging_converter.py
├── transformer_decoder.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *-notice.txt
6 | score.txt
7 | .idea/
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
--------------------------------------------------------------------------------
/AR_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mleader2/text_scalpel/131922c5dcfc558a46a7590925e348c39cb24951/AR_architecture.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LaserTagger
2 | ## 1. Overview
3 | Text rephrasing is the task of rewriting a sentence or passage A into a text B that expresses nearly the same meaning with slightly different wording.
4 | This project adapts Google's LaserTagger model and trains a text-rephrasing model on Chinese corpora such as LCQMC, i.e., it modifies a piece of text while preserving the original semantics.
5 | The rephrased outputs can be used for data augmentation and text generalization, enlarging the corpus for a specific scenario and improving model generalization.
6 |
7 |
8 | ## 2. Model introduction
9 | In "Encode, Tag, Realize: High-Precision Text Editing", Google casts text editing as a sequence-tagging problem and achieves state-of-the-art results on sentence splitting and summarization tasks.
10 | With the same BERT encoder, this approach is more reliable than Seq2Seq methods, trains and infers faster, and its advantage is even larger when the corpus is small.
11 |
12 |

13 |
14 | Google released the code for this paper, but the original tasks differ from the current one, so parts of the code had to be modified. The main changes:
15 | A. Tokenization: the original code targets English and splits text on whitespace into words; for Chinese, text is split into characters.
16 | B. Inference efficiency: the original code rephrases one text at a time; it was changed to rephrase batch_size texts per call, roughly a 6x speedup.
17 |
18 | ## 3. File descriptions and experiment steps
19 | 1. Install the Python dependencies
20 | See "requirements.txt" and "rephrase.sh".
21 | 2. Download a pretrained model
22 | For inference efficiency, this project currently uses the RoBERTa-tiny-clue (Chinese) pretrained model.
23 | Since several versions circulate online, the exact pretrained model used by this project has been uploaded to Baidu Netdisk. Link: https://pan.baidu.com/s/1yho8ihR9C6rBbY-IJjSagA  Extraction code: 2a97
24 | To use another pretrained model, edit "configs/lasertagger_config.json".
25 | 3. Train and evaluate the model
26 | Adjust the two directory paths in the script "rephrase.sh" to your environment, then run: bash rephrase.sh HOST_NAME
27 | The HOST_NAME variable is just a convenience the author uses for setting paths; change it to suit your setup.
28 | If you only need to batch-generalize texts offline, you can comment out the other parts of the script; predict_main.py alone is enough.
29 | 4. Start the text-rephrasing service (optional, as needed)
30 | Adjust the directory paths in "rephrase_server.sh", then run "sh rephrase_server.sh" to start a text-rephrasing API service.
31 | The service accepts an HTTP POST request, parses it, and generalizes the text it contains; see "rephrase_server/rephrase_server_flask.py" for the exact interface.
32 |
33 | For the rephrasing corpus you need to collect pairs of texts with the same meaning yourself. Texts from your own business scenario work best (though they must not be too few); if you have none, or not enough, add the positive pairs from corpora such as LCQMC.
34 | Then filter the pairs by the length of the longest common subsequence, because this method requires that the surface forms of source and target not differ too much; see "get_text_pair_lcqmc.py" for reference.
35 | Currently my train.txt and tune.txt both have three tab("\t")-separated columns: text1, text2, lcs_score; two illustrative rows follow below.
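Two illustrative train.txt rows (the texts and lcs_score values below are made-up examples; columns are tab-separated):

```
我要去宝安机场	宝安机场怎么去	0.285714
我想听周杰伦的歌	播放周杰伦的歌	0.333333
```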
36 |
37 | A few scripts such as rephrase_for_qa.sh, rephrase_for_chat.sh, and rephrase_for_skill.sh exist for the author's own work and can be ignored.
38 |
39 | ## 4. Experimental results
40 | 1. Reproducing the model on the public WikiSplit dataset:
41 | WikiSplit is an English corpus; the model is trained to split one sentence into two while keeping the semantics consistent, the grammar correct, and the text coherent, as shown below.
42 |
43 | 
44 |
45 | Exact score=15, SARI score=61.5, KEEP score=93, ADDITION score=32, DELETION score=59,
46 | essentially matching the paper's Exact score=15.2 and SARI score=61.7 (for all of these scores, higher is better).
47 | 2. Training a text-rephrasing model on a self-built Chinese dataset:
48 | (1) Corpus sources
49 | (A) Part of the corpus comes from the positive pairs in LCQMC, i.e., pairs of texts with similar meanings;
50 | (B) the other part comes from questions that share the same answer in a business FAQ.
51 | Because of how the model works, texts A and B must share a certain number of characters, so pairs from these two sources with very different surface forms, such as “我要去厕所” ("I need the toilet") vs. “卫生间在哪里” ("where is the restroom"), were filtered out. The model was then trained and tested on the filtered corpus.
52 | (2) Test results:
53 | Rephrasing and automatically evaluating 25,918 text pairs yields the following scores (higher is better):
54 | Exact score=29, SARI score=64, KEEP score=84, ADDITION score=39, DELETION score=66.
55 | This took 0.5 hours on CPU, an average of 0.72 seconds per rephrased sentence.
56 | Perhaps because the language and task differ, the evaluation scores for Chinese rephrasing are somewhat higher than on the public dataset.
57 |
58 | ## 5. Some tricks
59 | 1. Certain characters or words can be protected from modification
60 | For example, when generalizing an NER corpus, the model must not be allowed to modify the entities;
61 | when generalizing business corpora, keywords such as dates and flight numbers can likewise be protected as needed;
62 | currently, regular expressions locate the positions that must not be modified, and the location flag of those positions is set to 1; see tagging.py for the implementation, and the sketch right below for the idea.
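A minimal sketch of the idea (the function name and patterns here are illustrative; in the repo the mask is passed as the location argument and handled in tagging.py):

```python
import re

def protect_mask(text, patterns=(r"\d{1,2}月\d{1,2}日", r"[A-Z]{2}\d{3,4}")):
    """Return a '0'/'1' string marking positions (e.g. dates, flight numbers) the model must not edit."""
    location = ["0"] * len(text)
    for pattern in patterns:
        for match in re.finditer(pattern, text):
            for i in range(match.start(), match.end()):
                location[i] = "1"
    return "".join(location)

print(protect_mask("帮我查CA1234的航班"))  # -> 000111111000 (the span CA1234 is protected)
```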
63 | 2. Increasing the difference between the rephrased text and the original
64 | You can first apply a random swap operation to text_a in the training corpus (setting enable_swap_tag to true in the script accordingly), then train the model to rewrite it into text_b;
65 | at application or test time, likewise apply the random swap to the original text_a first, then let the model rewrite it into text_b.
66 | Because text_a in the training corpus is disfluent while text_b is fluent, the rephrased result at application or test time is still fluent; a minimal sketch of the swap follows below.
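A minimal sketch of the swap operation, mirroring the augmentation in get_pairs_chinese/get_text_pair_lcqmc.py:

```python
import random

def random_swap(text_a, low=0.4, high=0.9):
    """Cut text_a at a random point and swap the two halves."""
    cut = int(len(text_a) * random.uniform(low, high))
    return text_a[cut:] + text_a[:cut]

# get_text_pair_lcqmc.py keeps the swapped text only if it still shares
# enough characters with the target (an LCS-ratio threshold).
```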
67 |
68 | ## How to Cite LaserTagger
69 |
70 | ```
71 | @inproceedings{malmi2019lasertagger,
72 | title={Encode, Tag, Realize: High-Precision Text Editing},
73 | author={Eric Malmi and Sebastian Krause and Sascha Rothe and Daniil Mirylenka and Aliaksei Severyn},
74 | booktitle={EMNLP-IJCNLP},
75 | year={2019}
76 | }
77 | ```
78 |
79 | ## License
80 |
81 | Apache 2.0; see [LICENSE](LICENSE) for details.
82 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | import compute_lcs
--------------------------------------------------------------------------------
/bert_example.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Build BERT Examples from text (source, target) pairs."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | import collections
22 | import tagging
23 | import tensorflow as tf
24 | from utils import my_tokenizer_class
25 | from curLine_file import curLine
26 |
27 | class BertExample(object):
28 | """Class for training and inference examples for BERT.
29 |
30 | Attributes:
31 | editing_task: The EditingTask from which this example was created. Needed
32 | when realizing labels predicted for this example.
33 | features: Feature dictionary.
34 | """
35 |
36 | def __init__(self, input_ids,
37 | input_mask,
38 | segment_ids, labels,
39 | labels_mask,
40 | token_start_indices,
41 | task, default_label):
42 | input_len = len(input_ids)
43 | if not (input_len == len(input_mask) and input_len == len(segment_ids) and
44 | input_len == len(labels) and input_len == len(labels_mask)):
45 | raise ValueError(
46 | 'All feature lists should have the same length ({})'.format(
47 | input_len))
48 |
49 | self.features = collections.OrderedDict([
50 | ('input_ids', input_ids),
51 | ('input_mask', input_mask),
52 | ('segment_ids', segment_ids),
53 | ('labels', labels),
54 | ('labels_mask', labels_mask),
55 | ])
56 | self._token_start_indices = token_start_indices
57 | self.editing_task = task
58 | self._default_label = default_label
59 |
60 | def pad_to_max_length(self, max_seq_length, pad_token_id):
61 | """Pad the feature vectors so that they all have max_seq_length.
62 |
63 | Args:
64 | max_seq_length: The length that features will have after padding.
65 | pad_token_id: input_ids feature is padded with this ID, other features
66 | with ID 0.
67 | """
68 | pad_len = max_seq_length - len(self.features['input_ids'])
69 | for key in self.features:
70 | pad_id = pad_token_id if key == 'input_ids' else 0
71 | self.features[key].extend([pad_id] * pad_len)
72 | if len(self.features[key]) != max_seq_length:
73 | raise ValueError('{} has length {} (should be {}).'.format(
74 | key, len(self.features[key]), max_seq_length))
75 |
76 | def to_tf_example(self):
77 | """Returns this object as a tf.Example."""
78 |
79 | def int_feature(values):
80 | return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
81 |
82 | tf_features = collections.OrderedDict([
83 | (key, int_feature(val)) for key, val in self.features.items()
84 | ])
85 | return tf.train.Example(features=tf.train.Features(feature=tf_features))
86 |
87 | def get_token_labels(self):
88 | """Returns labels/tags for the original tokens, not for wordpieces."""
89 | labels = []
90 | for idx in self._token_start_indices:
91 | # For unmasked and untruncated tokens, use the label in the features, and
92 | # for the truncated tokens, use the default label.
93 | if (idx < len(self.features['labels']) and
94 | self.features['labels_mask'][idx]):
95 | current_label = self.features['labels'][idx]
96 | # if current_label >= 0:
97 | labels.append(self.features['labels'][idx])
98 | # else: # stop
99 | # labels.append(self._default_label)
100 | else:
101 | labels.append(self._default_label)
102 | if labels[-1]<0:
103 | print(curLine(), idx, len(self.features['labels']), "mask=",self.features['labels_mask'][idx], self.features['labels'][idx], labels[-1] )
104 | return labels
105 |
106 |
107 | class BertExampleBuilder(object):
108 | """Builder class for BertExample objects."""
109 |
110 | def __init__(self, label_map, vocab_file,
111 | max_seq_length, do_lower_case,
112 | converter):
113 | """Initializes an instance of BertExampleBuilder.
114 |
115 | Args:
116 | label_map: Mapping from tags to tag IDs.
117 | vocab_file: Path to BERT vocabulary file.
118 | max_seq_length: Maximum sequence length.
119 | do_lower_case: Whether to lower case the input text. Should be True for
120 | uncased models and False for cased models.
121 | converter: Converter from text targets to tags.
122 | """
123 | self._label_map = label_map
124 | self._tokenizer = my_tokenizer_class(vocab_file, do_lower_case=do_lower_case)
125 | self._max_seq_length = max_seq_length
126 | self._converter = converter
127 | self._pad_id = self._get_pad_id()
128 | self._keep_tag_id = self._label_map['KEEP']
129 |
130 | def build_bert_example(
131 | self,
132 | sources,
133 | target=None,
134 | use_arbitrary_target_ids_for_infeasible_examples=False,
135 | location=None
136 | ):
137 | """Constructs a BERT Example.
138 |
139 | Args:
140 | sources: List of source texts.
141 | target: Target text or None when building an example during inference.
142 | use_arbitrary_target_ids_for_infeasible_examples: Whether to build an
143 | example with arbitrary target ids even if the target can't be obtained
144 | via tagging.
145 |
146 | Returns:
147 | BertExample, or None if the conversion from text to tags was infeasible
148 | and use_arbitrary_target_ids_for_infeasible_examples == False.
149 | """
150 | # Compute target labels.
151 | task = tagging.EditingTask(sources, location=location, tokenizer=self._tokenizer)
152 | if target is not None:
153 | tags = self._converter.compute_tags(task, target, tokenizer=self._tokenizer)
154 | if not tags:  # tagging infeasible; behavior depends on use_arbitrary_target_ids_for_infeasible_examples
155 | if use_arbitrary_target_ids_for_infeasible_examples:
156 | # Create a tag sequence [KEEP, DELETE, KEEP, DELETE, ...] which is
157 | # unlikely to be predicted by chance.
158 | tags = [tagging.Tag('KEEP') if i % 2 == 0 else tagging.Tag('DELETE')
159 | for i, _ in enumerate(task.source_tokens)]
160 | else:
161 | return None
162 | else:
163 | # If target is not provided, we set all target labels to KEEP.
164 | tags = [tagging.Tag('KEEP') for _ in task.source_tokens]
165 | labels = [self._label_map[str(tag)] for tag in tags]
166 | # tokens, labels, token_start_indices = self._split_to_wordpieces(  # wordpiece: tags are per word; each piece of a word gets the same tag as its word
167 | # task.source_tokens, labels)
168 | if len(task.source_tokens) > self._max_seq_length - 2:
169 | print(curLine(), "%d tokens is too long," % len(task.source_tokens), "truncate task.source_tokens:",
170 | task.source_tokens)
171 | token_start_indices = [indices+1 for indices in range(len(task.source_tokens))]
172 |
173 | # Truncate to self._max_seq_length - 2 (reserving slots for [CLS] and [SEP])
174 | tokens = self._truncate_list(task.source_tokens)
175 | labels = self._truncate_list(labels)
176 |
177 | input_tokens = ['[CLS]'] + tokens + ['[SEP]']
178 | labels_mask = [0] + [1] * len(labels) + [0]
179 | labels = [0] + labels + [0]
180 |
181 | input_ids = self._tokenizer.convert_tokens_to_ids(input_tokens)
182 | input_mask = [1] * len(input_ids)
183 | segment_ids = [0] * len(input_ids)
184 | example = BertExample(
185 | input_ids=input_ids,
186 | input_mask=input_mask,
187 | segment_ids=segment_ids,
188 | labels=labels,
189 | labels_mask=labels_mask,
190 | token_start_indices=token_start_indices,
191 | task=task,
192 | default_label=self._keep_tag_id)
193 | example.pad_to_max_length(self._max_seq_length, self._pad_id)
194 | return example
195 |
196 | # def _split_to_wordpieces(self, tokens, labels):
197 | # """Splits tokens (and the labels accordingly) to WordPieces.
198 | #
199 | # Args:
200 | # tokens: Tokens to be split.
201 | # labels: Labels (one per token) to be split.
202 | #
203 | # Returns:
204 | # 3-tuple with the split tokens, split labels, and the indices of the
205 | # WordPieces that start a token.
206 | # """
207 | # bert_tokens = [] # Original tokens split into wordpieces.
208 | # bert_labels = [] # Label for each wordpiece.
209 | # # Index of each wordpiece that starts a new token.
210 | # token_start_indices = []
211 | # for i, token in enumerate(tokens):
212 | # # '+ 1' is because bert_tokens will be prepended by [CLS] token later.
213 | # token_start_indices.append(len(bert_tokens) + 1)
214 | # pieces = self._tokenizer.tokenize(token)
215 | # bert_tokens.extend(pieces)
216 | # bert_labels.extend([labels[i]] * len(pieces))
217 | # return bert_tokens, bert_labels, token_start_indices
218 |
219 | def _truncate_list(self, x):
220 | """Returns truncated version of x according to the self._max_seq_length."""
221 | # Save two slots for the first [CLS] token and the last [SEP] token.
222 | return x[:self._max_seq_length - 2]
223 |
224 | def _get_pad_id(self):
225 | """Returns the ID of the [PAD] token (or 0 if it's not in the vocab)."""
226 | try:
227 | return self._tokenizer.convert_tokens_to_ids(['[PAD]'])[0]
228 | except KeyError:
229 | return 0
230 |
--------------------------------------------------------------------------------
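A minimal usage sketch for BertExampleBuilder (not part of the repo; the label-map and vocab paths are placeholders), wired up the same way as the predict scripts below:

```python
import bert_example
import tagging_converter
import utils

label_map = utils.read_label_map('output/label_map.txt')  # placeholder path
converter = tagging_converter.TaggingConverter(
    tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
    True)  # enable_swap_tag
builder = bert_example.BertExampleBuilder(
    label_map, 'RoBERTa-tiny-clue/vocab.txt',  # placeholder path
    max_seq_length=40, do_lower_case=True, converter=converter)

# At inference time target is None, so every label defaults to KEEP.
example = builder.build_bert_example(sources=['今天天气怎么样'])
print(example.features['input_ids'])
```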
/chat_rephrase/__init__.py:
--------------------------------------------------------------------------------
1 | import bert_example
2 | import predict_utils
3 | import tagging_converter
4 | import utils
--------------------------------------------------------------------------------
/chat_rephrase/predict_for_chat.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Generalize corpora for domain recognition
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 | import json
8 | from absl import app
9 | from absl import flags
10 | import os, sys, time
11 | from termcolor import colored
12 | import tensorflow as tf
13 |
14 | block_list = os.path.realpath(__file__).split("/")
15 | path = "/".join(block_list[:-2])
16 | sys.path.append(path)
17 | import bert_example
18 | import predict_utils
19 | import tagging_converter
20 | import utils
21 |
22 | from curLine_file import curLine
23 |
24 | FLAGS = flags.FLAGS
25 |
26 | flags.DEFINE_string(
27 | 'input_file', None,
28 | 'Path to the input file containing examples for which to compute '
29 | 'predictions.')
30 | flags.DEFINE_enum(
31 | 'input_format', None, ['wikisplit', 'discofuse'],
32 | 'Format which indicates how to parse the input_file.')
33 | flags.DEFINE_string(
34 | 'output_file', None,
35 | 'Path to the TSV file where the predictions are written to.')
36 | flags.DEFINE_string(
37 | 'label_map_file', None,
38 | 'Path to the label map file. Either a JSON file ending with ".json", that '
39 | 'maps each possible tag to an ID, or a text file that has one tag per '
40 | 'line.')
41 | flags.DEFINE_string('vocab_file', None, 'Path to the BERT vocabulary file.')
42 | flags.DEFINE_integer('max_seq_length', 128, 'Maximum sequence length.')
43 | flags.DEFINE_bool(
44 | 'do_lower_case', False,
45 | 'Whether to lower case the input text. Should be True for uncased '
46 | 'models and False for cased models.')
47 | flags.DEFINE_bool('enable_swap_tag', True, 'Whether to enable the SWAP tag.')
48 | flags.DEFINE_string('saved_model', None, 'Path to an exported TF model.')
49 |
50 | def predict_and_write(predictor, sources_batch, previous_line_list,context_list, writer, num_predicted, start_time, batch_num):
51 | prediction_batch = predictor.predict_batch(sources_batch=sources_batch)
52 | assert len(prediction_batch) == len(sources_batch)
53 | for id, [prediction, sources] in enumerate(zip(prediction_batch, sources_batch)):
54 | output = ""
55 | if len(prediction) > 1 and prediction != sources:  # TODO: ignore predictions that keep the source entirely or are too short
56 | output = "%s%s" % (context_list[id], prediction)  # should this be concatenated with the context?
57 | writer.write("%s\t%s\n" % (previous_line_list[id], output))
58 | batch_num = batch_num + 1
59 | num_predicted += len(prediction_batch)
60 | if batch_num % 200 == 0:
61 | cost_time = (time.time() - start_time) / 3600.0
62 | print("%s batch_id=%d, predict %d examples, cost %.3fh." %
63 | (curLine(), batch_num, num_predicted, cost_time))
64 | return num_predicted, batch_num
65 |
66 |
67 | def main(argv):
68 | if len(argv) > 1:
69 | raise app.UsageError('Too many command-line arguments.')
70 | flags.mark_flag_as_required('input_file')
71 | flags.mark_flag_as_required('input_format')
72 | flags.mark_flag_as_required('output_file')
73 | flags.mark_flag_as_required('label_map_file')
74 | flags.mark_flag_as_required('vocab_file')
75 | flags.mark_flag_as_required('saved_model')
76 |
77 | label_map = utils.read_label_map(FLAGS.label_map_file)
78 | converter = tagging_converter.TaggingConverter(
79 | tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
80 | FLAGS.enable_swap_tag)
81 | builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
82 | FLAGS.max_seq_length,
83 | FLAGS.do_lower_case, converter)
84 | predictor = predict_utils.LaserTaggerPredictor(
85 | tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
86 | label_map)
87 | print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
88 | sourcesA_list = []
89 | with open(FLAGS.input_file) as f:
90 | for line in f:
91 | json_map = json.loads(line.rstrip('\n'))
92 | sourcesA_list.append(json_map["questions"])
93 | print(curLine(), len(sourcesA_list), "sourcesA_list:", sourcesA_list[-1])
94 | start_time = time.time()
95 | num_predicted = 0
96 | with tf.gfile.Open(FLAGS.output_file, 'w') as writer:
97 | for batch_id,sources_batch in enumerate(sourcesA_list):
98 | # sources_batch = sourcesA_list[batch_id * predict_batch_size: (batch_id + 1) * predict_batch_size]
99 | location_batch = []
100 | for source in sources_batch:
101 | location = list()
102 | for char in source[0]:
103 | if (char>='0' and char<='9') or char in '.- ' or (char>='a' and char<='z') or (char>='A' and char<='Z'):
104 | location.append("1") # TODO TODO
105 | else:
106 | location.append("0")
107 | location_batch.append("".join(location))
108 | prediction_batch = predictor.predict_batch(sources_batch=sources_batch, location_batch=location_batch)
109 | expand_list = []
110 | for prediction in prediction_batch: # TODO
111 | if prediction in sources_batch:
112 | continue
113 | expand_list.append(prediction)
114 |
115 | json_map = {"questions":sources_batch, "expands":expand_list}
116 | json_str = json.dumps(json_map, ensure_ascii=False)
117 | writer.write("%s\n" % json_str)
118 | # input(curLine())
119 | num_predicted += len(expand_list)
120 | if batch_id % 20 == 0:
121 | cost_time = (time.time() - start_time) / 60.0
122 | print("%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin." %
123 | (curLine(), batch_id + 1, len(sourcesA_list), num_predicted, num_predicted, cost_time))
124 | cost_time = (time.time() - start_time) / 60.0
125 | # print(curLine(), id, prediction_A, prediction_B, "target:", target, "current_batch_size=", current_batch_size)
126 | # print(curLine(), id, "sourceA:", sourceA, "sourceB:", sourceB, "target:", target)
127 | # logging.info(
128 | # f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {cost_time / num_predicted*60000}ms.')
129 |
130 | if __name__ == '__main__':
131 | app.run(main)
132 |
--------------------------------------------------------------------------------
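For reference, predict_for_chat.py reads one JSON object per input line and writes one per output line. A made-up illustration (the field names "questions" and "expands" come from the code above; the expansion text is hypothetical):

```python
import json

input_line = json.dumps({"questions": ["怎么去宝安机场", "宝安机场怎么走"]}, ensure_ascii=False)
# The script would write something like:
output_line = json.dumps({"questions": ["怎么去宝安机场", "宝安机场怎么走"],
                          "expands": ["请问宝安机场怎么走"]}, ensure_ascii=False)
print(input_line)
print(output_line)
```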
/chat_rephrase/score_for_qa.txt:
--------------------------------------------------------------------------------
1 | 1200
2 | Exact score: 12.843
3 | SARI score: 51.313
4 | KEEP score: 76.046
5 | ADDITION score: 20.086
6 | DELETION score: 57.808
7 | cost time 0.337063 h
8 |
9 | Fangshan machine: Restoring parameters from /home/wzk/Mywork/corpus/rephrase_corpus/output/models/wikisplit_experiment_name/model.ckpt-3643
10 | Exact score: 13.672
11 | SARI score: 53.450
12 | KEEP score: 77.125
13 | ADDITION score: 21.678
14 | DELETION score: 61.546
15 | cost time 0.260398 h
16 | Continued training, with epochs changed from 7.0 to 30.0
17 | Saving checkpoints for 36436 into /home/wzk/Mywork/corpus/rephrase_corpus/output/models/wikisplit_experiment_name/model.ckpt.
18 | Exact score: 19.414
19 | SARI score: 56.749
20 | KEEP score: 78.622
21 | ADDITION score: 26.927
22 | DELETION score: 64.697
23 | cost time 0.260367 h
24 | model.ckpt-35643
25 | Exact score: 19.440
26 | SARI score: 56.725
27 | KEEP score: 78.620
28 | ADDITION score: 26.917
29 | DELETION score: 64.638
30 | cost time 0.259561 h
31 | Restoring parameters from /home/wzk/Mywork/corpus/rephrase_corpus/output/models/wikisplit_experiment_name/model.ckpt-62436
32 | Exact score: 20.347
33 | SARI score: 57.107
34 | KEEP score: 78.693
35 | ADDITION score: 27.379
36 | DELETION score: 65.250
37 | cost time 0.575784 h
38 |
39 |
40 | Old results:
41 | Refined the case-handling rules
42 | Exact score: 29.283
43 | SARI score: 64.002
44 | KEEP score: 84.942
45 | ADDITION score: 38.128
46 | DELETION score: 68.935
47 | cost time 0.41796 h
--------------------------------------------------------------------------------
/compute_lcs.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # _compute_lcs computes the Longest Common Subsequence (LCS) of two token lists
3 |
4 | def _compute_lcs(source, target):
5 | """Computes the Longest Common Subsequence (LCS).
6 |
7 | Description of the dynamic programming algorithm:
8 | https://www.algorithmist.com/index.php/Longest_Common_Subsequence
9 |
10 | Args:
11 | source: List of source tokens.
12 | target: List of target tokens.
13 |
14 | Returns:
15 | List of tokens in the LCS.
16 | """
17 | table = _lcs_table(source, target)
18 | return _backtrack(table, source, target, len(source), len(target))
19 |
20 |
21 | def _lcs_table(source, target):
22 | """Returns the Longest Common Subsequence dynamic programming table."""
23 | rows = len(source)
24 | cols = len(target)
25 | lcs_table = [[0] * (cols + 1) for _ in range(rows + 1)]
26 | for i in range(1, rows + 1):
27 | for j in range(1, cols + 1):
28 | if source[i - 1] == target[j - 1]:
29 | lcs_table[i][j] = lcs_table[i - 1][j - 1] + 1
30 | else:
31 | lcs_table[i][j] = max(lcs_table[i - 1][j], lcs_table[i][j - 1])
32 | return lcs_table
33 |
34 |
35 | def _backtrack(table, source, target, i, j):
36 | """Backtracks the Longest Common Subsequence table to reconstruct the LCS.
37 |
38 | Args:
39 | table: Precomputed LCS table.
40 | source: List of source tokens.
41 | target: List of target tokens.
42 | i: Current row index.
43 | j: Current column index.
44 |
45 | Returns:
46 | List of tokens corresponding to LCS.
47 | """
48 | if i == 0 or j == 0:
49 | return []
50 | if source[i - 1] == target[j - 1]:
51 | # Append the aligned token to output.
52 | return _backtrack(table, source, target, i - 1, j - 1) + [target[j - 1]]
53 | if table[i][j - 1] > table[i - 1][j]:
54 | return _backtrack(table, source, target, i, j - 1)
55 | else:
56 | return _backtrack(table, source, target, i - 1, j)
57 |
--------------------------------------------------------------------------------
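A quick usage sketch: _compute_lcs works on any two token sequences, and since the Chinese pipeline treats each character as a token, plain strings work directly:

```python
from compute_lcs import _compute_lcs

source = "我要去宝安机场"
target = "请问宝安机场怎么走"
lcs = _compute_lcs(source, target)  # ['宝', '安', '机', '场']
# The lcs_rate filter used in get_pairs_chinese/get_text_pair_lcqmc.py:
lcs_rate = len(lcs) / float(len(source) + len(target))  # 0.25
print(lcs, lcs_rate)
```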
/configs/lasertagger_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "attention_probs_dropout_prob": 0.1,
3 | "hidden_act": "gelu",
4 | "hidden_dropout_prob": 0.1,
5 | "hidden_size": 312,
6 | "initializer_range": 0.02,
7 | "intermediate_size": 1248,
8 | "max_position_embeddings": 512,
9 | "num_attention_heads": 12,
10 | "num_hidden_layers": 4,
11 | "type_vocab_size": 2,
12 | "vocab_size": 8021,
13 | "use_t2t_decoder": true,
14 | "decoder_num_hidden_layers": 1,
15 | "decoder_hidden_size": 768,
16 | "decoder_num_attention_heads": 4,
17 | "decoder_filter_size": 3072,
18 | "use_full_attention": false
19 | }
20 |
--------------------------------------------------------------------------------
/curLine_file.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import sys
3 | def curLine():
4 | file_path = sys._getframe().f_back.f_code.co_filename  # path of the caller's source file
5 | file_name = file_path[file_path.rfind("/") + 1:]  # file name of the caller
6 | lineno = sys._getframe().f_back.f_lineno  # line number at the call site
7 | line_str = "[%s:%s] " % (file_name, lineno)
8 | return line_str
--------------------------------------------------------------------------------
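A small usage sketch: curLine() returns a "[file:line] " prefix for print-style debugging, reporting the location of the caller:

```python
from curLine_file import curLine

def do_work():
    # Prints something like: [demo.py:5]  starting work
    print(curLine(), "starting work")

do_work()
```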
/domain_rephrase/__init__.py:
--------------------------------------------------------------------------------
1 | import bert_example
2 | import predict_utils
3 | import tagging_converter
4 | import utils
--------------------------------------------------------------------------------
/domain_rephrase/predict_for_domain.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Generalize corpora for domain recognition
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | from absl import app
9 | from absl import flags
10 | from absl import logging
11 | import os, sys, time
12 | from termcolor import colored
13 | import tensorflow as tf
14 |
15 | block_list = os.path.realpath(__file__).split("/")
16 | path = "/".join(block_list[:-2])
17 | sys.path.append(path)
18 |
19 | import bert_example
20 | import predict_utils
21 | import tagging_converter
22 | import utils
23 |
24 | from curLine_file import curLine
25 |
26 | FLAGS = flags.FLAGS
27 |
28 | flags.DEFINE_string(
29 | 'input_file', None,
30 | 'Path to the input file containing examples for which to compute '
31 | 'predictions.')
32 | flags.DEFINE_enum(
33 | 'input_format', None, ['wikisplit', 'discofuse'],
34 | 'Format which indicates how to parse the input_file.')
35 | flags.DEFINE_string(
36 | 'output_file', None,
37 | 'Path to the TSV file where the predictions are written to.')
38 | flags.DEFINE_string(
39 | 'label_map_file', None,
40 | 'Path to the label map file. Either a JSON file ending with ".json", that '
41 | 'maps each possible tag to an ID, or a text file that has one tag per '
42 | 'line.')
43 | flags.DEFINE_string('vocab_file', None, 'Path to the BERT vocabulary file.')
44 | flags.DEFINE_integer('max_seq_length', 128, 'Maximum sequence length.')
45 | flags.DEFINE_bool(
46 | 'do_lower_case', False,
47 | 'Whether to lower case the input text. Should be True for uncased '
48 | 'models and False for cased models.')
49 | flags.DEFINE_bool('enable_swap_tag', True, 'Whether to enable the SWAP tag.')
50 | flags.DEFINE_string('saved_model', None, 'Path to an exported TF model.')
51 |
52 | def predict_and_write(predictor, sources_batch, previous_line_list,context_list, writer, num_predicted, start_time, batch_num):
53 | prediction_batch = predictor.predict_batch(sources_batch=sources_batch)
54 | assert len(prediction_batch) == len(sources_batch)
55 | for id, [prediction, sources] in enumerate(zip(prediction_batch, sources_batch)):
56 | output = ""
57 | if len(prediction) > 1 and prediction != sources:  # TODO: ignore predictions that keep the source entirely or are too short
58 | output = "%s%s" % (context_list[id], prediction)  # should this be concatenated with the context?
59 | # print(curLine(), "prediction:", prediction, "sources:", sources, ",output:", output, prediction != sources)
60 | writer.write("%s\t%s\n" % (previous_line_list[id], output))
61 | batch_num = batch_num + 1
62 | num_predicted += len(prediction_batch)
63 | if batch_num % 200 == 0:
64 | cost_time = (time.time() - start_time) / 3600.0
65 | print("%s batch_id=%d, predict %d examples, cost %.3fh." %
66 | (curLine(), batch_num, num_predicted, cost_time))
67 | return num_predicted, batch_num
68 |
69 | def remove_p(text_pre):  # strip domain/intent label prefixes from the text
70 | text = text_pre[:9].replace('\n', '').replace('\r', '').replace('news', '').replace('food','').replace('GetPath', '') \
71 | .replace('GetYear','').replace('GetDate', '').replace('flight', '').replace('weather', '').replace('currency', '').replace('stock', '').replace('story', '') \
72 | .replace('drama', '').replace('jokes', '').replace('other','').replace('poetry', '').replace('GetLunar', '')+text_pre[9:]
73 | if len(text) == len(text_pre) and len(text)>9:
74 | text = text.replace('stoytelling', '').replace('storytypes', '').replace(
75 | 'SetPathPlace', '').replace('GetPoetryByTitle', '').replace('GetTitles', '') \
76 | .replace('GetSolarterm', '').replace('GetLastPhrases', '').replace('SelectPlaceIndex','').replace(
77 | 'photo.tag', '').replace('photo.rac', '').replace('GetSuitAndAvoid', '') \
78 | .replace('GetOnePoetry', '').replace('SetPathTrans', '').replace('navigation', '').replace(
79 | 'GetNextPhrases', '').replace('crosstalk', '')
80 | if len(text) == len(text_pre) and len(text) > 5:
81 | text = text.replace('currency,SetCurrentcy', '').replace('SetCurrentcy', '').replace('GetWeekDay', '').replace('GetAuthors', '') \
82 | .replace('trafficrestr', '').replace('GetTranslates', '').replace('GetAuthorNames','') \
83 | .replace('MatchTitle', '').replace('WeiboFirst', '').replace('music','').replace('times', '').replace('photo', '')
84 | return text
85 |
86 | def main(argv):
87 | if len(argv) > 1:
88 | raise app.UsageError('Too many command-line arguments.')
89 | flags.mark_flag_as_required('input_file')
90 | flags.mark_flag_as_required('input_format')
91 | flags.mark_flag_as_required('output_file')
92 | flags.mark_flag_as_required('label_map_file')
93 | flags.mark_flag_as_required('vocab_file')
94 | flags.mark_flag_as_required('saved_model')
95 |
96 | label_map = utils.read_label_map(FLAGS.label_map_file)
97 | converter = tagging_converter.TaggingConverter(
98 | tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
99 | FLAGS.enable_swap_tag)
100 | builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
101 | FLAGS.max_seq_length,
102 | FLAGS.do_lower_case, converter)
103 | predictor = predict_utils.LaserTaggerPredictor(
104 | tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
105 | label_map)
106 | print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
107 | predict_batch_size = 64
108 | batch_num = 0
109 | num_predicted = 0
110 | with tf.gfile.Open(FLAGS.output_file, 'w') as writer:
111 | with open(FLAGS.input_file, "r") as f:
112 | sources_batch = []
113 | previous_line_list = []
114 | context_list = []
115 | line_number = 0
116 | start_time = time.time()
117 | while True:
118 | line_number +=1
119 | line = f.readline().rstrip('\n').strip("\"").strip(" ")
120 | if len(line) == 0:
121 | break
122 |
123 | column_index = line.index(",")
124 | text = line[column_index+1:].strip("\"") # context and query
125 | # for charChinese_id, char in enumerate(line[column_index+1:]):
126 | # if (char>='a' and char<='z') or (char>='A' and char<='Z'):
127 | # continue
128 | # else:
129 | # break
130 | source = remove_p(text)
131 | if source not in text:  # TODO: ignored lines get an empty string, so the output is empty too
132 | print(curLine(), "line_number=%d, ignore:%s" % (line_number, text), ",source:", len(source), source)
133 | source = ""
134 | # continue
135 | context_list.append(text[:text.index(source)])
136 | previous_line_list.append(line)
137 | sources_batch.append(source)
138 | if len(sources_batch) == predict_batch_size:
139 | num_predicted, batch_num = predict_and_write(predictor, sources_batch,
140 | previous_line_list,context_list, writer, num_predicted, start_time, batch_num)
141 | sources_batch = []
142 | previous_line_list = []
143 | context_list = []
144 | # if num_predicted > 1000:
145 | # break
146 | if len(context_list)>0:
147 | num_predicted, batch_num = predict_and_write(predictor, sources_batch,
148 | previous_line_list, context_list, writer, num_predicted,
149 | start_time, batch_num)
150 | cost_time = (time.time() - start_time) / 60.0
151 | logging.info(
152 | f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {cost_time / num_predicted/60} hours.')
153 |
154 | if __name__ == '__main__':
155 | app.run(main)
156 |
--------------------------------------------------------------------------------
/domain_rephrase/rephrase_for_domain.sh:
--------------------------------------------------------------------------------
1 | # Expand the corpus for skills
2 | # rephrase_for_skill.sh: adapted from rephrase.sh
3 | # predict_for_skill.py: adapted from predict_main.py
4 | # score_for_skill.txt: comparison of results
5 |
6 |
7 | # Chengdu machine
8 | # pyenv activate python373tf115
9 | # pip install -i https://pypi.douban.com/simple/ bert-tensorflow==1.0.1
10 | #pip install -i https://pypi.douban.com/simple/ tensorflow==1.15.0
11 | #python -m pip install --upgrade pip -i https://pypi.douban.com/simple
12 |
13 | # set gpu id to use
14 | export CUDA_VISIBLE_DEVICES=""
15 |
16 | # Fangshan machine
17 | # pyenv activate python363tf111
18 | # pip install bert-tensorflow==1.0.1
19 |
20 | #scp -r /home/cloudminds/PycharmProjects/lasertagger-Chinese/predict_main.py cloudminds@10.13.33.128:/home/cloudminds/PycharmProjects/lasertagger-Chinese
21 | #scp -r cloudminds@10.13.33.128:/home/wzk/Mywork/corpus/文本复述/output/models/wikisplit_experiment_name /home/cloudminds/Mywork/corpus/文本复述/output/models/
22 | # watch -n 1 nvidia-smi
23 |
24 | start_tm=`date +%s%N`;
25 |
26 | export HOST_NAME="cloudminds" # "wzk" #
27 | ### Optional parameters ###
28 |
29 | # If you train multiple models on the same data, change this label.
30 | EXPERIMENT=wikisplit_experiment
31 | # To quickly test that model training works, set the number of epochs to a
32 | # smaller value (e.g. 0.01).
33 | NUM_EPOCHS=10.0
34 | export TRAIN_BATCH_SIZE=256 # 512 OOM 256 OK
35 | PHRASE_VOCAB_SIZE=500
36 | MAX_INPUT_EXAMPLES=1000000
37 | SAVE_CHECKPOINT_STEPS=200
38 | export enable_swap_tag=false
39 | export output_arbitrary_targets_for_infeasible_examples=false
40 | export WIKISPLIT_DIR="/home/${HOST_NAME}/Mywork/corpus/rephrase_corpus"
41 | export OUTPUT_DIR="${WIKISPLIT_DIR}/output"
42 |
43 | #python phrase_vocabulary_optimization.py \
44 | # --input_file=${WIKISPLIT_DIR}/train.txt \
45 | # --input_format=wikisplit \
46 | # --vocabulary_size=500 \
47 | # --max_input_examples=1000000 \
48 | # --enable_swap_tag=${enable_swap_tag} \
49 | # --output_file=${OUTPUT_DIR}/label_map.txt
50 |
51 |
52 | export max_seq_length=40 # TODO
53 | export BERT_BASE_DIR="/home/${HOST_NAME}/Mywork/model/RoBERTa-tiny-clue" # chinese_L-12_H-768_A-12"
54 |
55 |
56 |
57 |
58 | # Check these numbers from the "*.num_examples" files created in step 2.
59 | export CONFIG_FILE=configs/lasertagger_config.json
60 | export EXPERIMENT=wikisplit_experiment_name
61 |
62 |
63 |
64 | ### 4. Prediction
65 |
66 | # Export the model.
67 | #python run_lasertagger.py \
68 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \
69 | # --model_config_file=${CONFIG_FILE} \
70 | # --output_dir=${OUTPUT_DIR}/models/${EXPERIMENT} \
71 | # --do_export=true \
72 | # --export_path=${OUTPUT_DIR}/models/${EXPERIMENT}/export
73 |
74 | ## Get the most recently exported model directory.
75 | TIMESTAMP=$(ls "${OUTPUT_DIR}/models/${EXPERIMENT}/export/" | \
76 | grep -v "temp-" | sort -r | head -1)
77 | SAVED_MODEL_DIR=${OUTPUT_DIR}/models/${EXPERIMENT}/export/${TIMESTAMP}
78 | PREDICTION_FILE=${OUTPUT_DIR}/models/${EXPERIMENT}/pred.tsv
79 |
80 | python domain_rephrase/predict_for_domain.py \
81 | --input_file=/home/${HOST_NAME}/Mywork/corpus/ner_corpus/domain_corpus/train3.csv \
82 | --input_format=wikisplit \
83 | --output_file=/home/${HOST_NAME}/Mywork/corpus/ner_corpus/domain_corpus/train3_expand.csv \
84 | --label_map_file=${OUTPUT_DIR}/label_map.txt \
85 | --vocab_file=${BERT_BASE_DIR}/vocab.txt \
86 | --max_seq_length=${max_seq_length} \
87 | --saved_model=${SAVED_MODEL_DIR}
88 |
89 | #### 5. Evaluation
90 | #python score_main.py --prediction_file=${PREDICTION_FILE}
91 |
92 |
93 | end_tm=`date +%s%N`;
94 | use_tm=`echo $end_tm $start_tm | awk '{ print ($1 - $2) / 1000000000 /3600}'`
95 | echo "cost time" $use_tm "h"
--------------------------------------------------------------------------------
/get_pairs_chinese/__init__.py:
--------------------------------------------------------------------------------
1 | # Obtain corpora for the text rephrasing task
2 | import curLine_file
3 | import compute_lcs
--------------------------------------------------------------------------------
/get_pairs_chinese/curLine_file.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import sys
3 | def curLine():
4 | file_path = sys._getframe().f_back.f_code.co_filename  # path of the caller's source file
5 | file_name = file_path[file_path.rfind("/") + 1:]  # file name of the caller
6 | lineno = sys._getframe().f_back.f_lineno  # line number at the call site
7 | line_str = "[%s:%s] " % (file_name, lineno)
8 | return line_str
--------------------------------------------------------------------------------
/get_pairs_chinese/get_text_pair_lcqmc.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Use a text-matching corpus: sample sentence pairs (A, B) from the positive examples, then train the model to rewrite A into B
3 | # Currently targets LCQMC; the labels of a few examples were corrected
4 | import random
5 | import os
6 | import sys
7 | sys.path.append("..")
8 | from compute_lcs import _compute_lcs
9 | from curLine_file import curLine
10 |
11 | def process(corpus_folder, raw_file_name, save_folder):
12 | corpus_list = []
13 | for name in raw_file_name:
14 | raw_file = os.path.join(corpus_folder, name)
15 | with open(raw_file, "r") as fr:
16 | lines = fr.readlines()
17 |
18 | for i ,line in enumerate(lines):
19 | source, target, label = line.strip().split("\t")
20 | if label=="0" or source==target:
21 | continue
22 | if label != "1":
23 | input(curLine()+line.strip())
24 | length = float(len(source) + len(target))
25 |
26 | source_length = len(source)
27 | if source_length > 8 and source_length<38 and (i+1)%2>0:  # apply a swap operation to 50% of the long sentences
28 | rand = random.uniform(0.4, 0.9)
29 | source_pre = source
30 | swag_location = int(source_length*rand)
31 | source = "%s%s" % (source[swag_location:], source[:swag_location])
32 | lcs1 = _compute_lcs(source, target)
33 | lcs_rate= len(lcs1)/length
34 | if (lcs_rate<0.4):  # too different after the swap; revert
35 | source = source_pre
36 | else:
37 | print(curLine(), "source_pre:%s, source:%s, lcs_rate=%f" % (source_pre, source, lcs_rate))
38 |
39 | lcs1 = _compute_lcs(source, target)
40 | lcs_rate = len(lcs1) / length
41 | if (lcs_rate<0.2):
42 | continue  # change too large; ignore
43 |
44 | # if (lcs_rate<0.4):
45 | # continue  # change too large; ignore
46 | # if len(source)*1.15 < len(target):
47 | # new_t = source
48 | # source = target
49 | # target = new_t
50 | # print(curLine(), source, target, ",lcs1:",lcs1 , ",lcs_rate=", lcs_rate)
51 | corpus = "%s\t%s\t%f\n" % (source, target, lcs_rate)
52 | corpus_list.append(corpus)
53 | print(curLine(), len(corpus_list), "from %s" % raw_file)
54 | save_file = os.path.join(save_folder, "lcqmc.txt")
55 | with open(save_file, "w") as fw:
56 | fw.writelines(corpus_list)
57 | print(curLine(), "have save %d to %s" % (len(corpus_list), save_file))
58 |
59 | if __name__ == "__main__":
60 | corpus_folder = "/home/cloudminds/Mywork/corpus/Chinese_QA/LCQMC"
61 | raw_file_name = ["train.txt", "dev.txt", "test.txt"]
62 | save_folder = "/home/cloudminds/Mywork/corpus/rephrase_corpus"
63 | process(corpus_folder, raw_file_name, save_folder)
64 |
65 |
--------------------------------------------------------------------------------
/get_pairs_chinese/get_text_pair_shixi.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # QA pairs provided by Shixi (sharing at least the same answer) sit on one line; sample sentence pairs (A, B) from them, then train the model to rewrite A into B
3 | # Currently targets the Bao'an Airport dataset, which is fairly old and small
4 | import os
5 | from compute_lcs import _compute_lcs
6 | from curLine_file import curLine
7 |
8 | def process(corpus_folder, raw_file_name, save_folder):
9 | raw_file = os.path.join(corpus_folder, raw_file_name)
10 | with open(raw_file, "r") as fr:
11 | lines = fr.readlines()
12 | corpus_list = []
13 | for line in lines:
14 | sent_list = line.strip().split("&&")
15 | sent_num = len(sent_list)
16 | for i in range(1, sent_num, 2):
17 | source= sent_list[i-1]
18 | target = sent_list[i]
19 | length = float(len(source) + len(target))
20 | lcs1 = _compute_lcs(source, target)
21 | lcs_rate= len(lcs1)/length
22 | if (lcs_rate<0.3):
23 | continue  # change too large; ignore
24 | if len(source)*1.15 < len(target):
25 | new_t = source
26 | source = target
27 | target = new_t
28 | corpus = "%s\t%s\t%f\n" % (source, target, lcs_rate)
29 | corpus_list.append(corpus)
30 | save_file = os.path.join(save_folder, "baoan_airport.txt")
31 | with open(save_file, "w") as fw:
32 | fw.writelines(corpus_list)
33 | print(curLine(), "have save %d to %s" % (len(corpus_list), save_file))
34 |
35 |
36 |
37 |
38 | if __name__ == "__main__":
39 | corpus_folder = "/home/cloudminds/Mywork/corpus/Chinese_QA/baoanairport"
40 | raw_file_name = "baoan_airport_processed.txt"
41 | save_folder = "/home/cloudminds/Mywork/corpus/rephrase_corpus"
42 | process(corpus_folder, raw_file_name, save_folder)
43 |
44 |
--------------------------------------------------------------------------------
/get_pairs_chinese/get_text_pair_sv.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # From an xlsx file exported from SV, sentences with the same meaning (sharing at least the same answer) sit on one line; sample sentence pairs (A, B) from them, then train the model to rewrite A into B
3 | # Currently targets Bao'an Airport
4 | import sys
5 | sys.path.append("..")
6 | import os
7 | import xlrd  # library for reading Excel files
8 | import random
9 | from compute_lcs import _compute_lcs
10 | from curLine_file import curLine
11 |
12 | def process(corpus_folder, raw_file_name):
13 | raw_file = os.path.join(corpus_folder, raw_file_name)
14 | # Open the file and get the Excel workbook object
15 | workbook = xlrd.open_workbook(raw_file)  # file path
16 |
17 | # Get the sheet object by sheet index
18 | worksheet = workbook.sheet_by_index(0)
19 | nrows = worksheet.nrows # 获取该表总行数
20 | ncols = worksheet.ncols # 获取该表总列数
21 | print(curLine(), "raw_file_name:%s, worksheet:%s nrows=%d, ncols=%d" % (raw_file_name, worksheet.name,nrows, ncols))
22 | assert ncols == 3
23 | assert nrows > 0
24 | col_data = worksheet.col_values(0) # contents of the first column
25 | corpus_list = []
26 | for line in col_data:
27 | sent_list = line.strip().split("&&")
28 | sent_num = len(sent_list)
29 | for i in range(1, sent_num, 2):
30 | source = sent_list[i-1]
31 | target = sent_list[i]
32 | # source_length = len(source)
33 | # if source_length > 8 and (i+1)%4>0: # randomly delete one character from 50% of the long sentences
34 | # rand = random.uniform(0.1, 0.9)
35 | # source_pre = source
36 | # swag_location = int(source_length*rand)
37 | # source = "%s%s" % (source[:swag_location], source[swag_location+1:])
38 | # print(curLine(), "source_pre:%s, source:%s" % (source_pre, source))
39 |
40 | length = float(len(source) + len(target))
41 | lcs1 = _compute_lcs(source, target)
42 | lcs_rate = len(lcs1) / length
43 | if lcs_rate < 0.2:
44 | continue # change too large; ignore
45 |
46 | # if (lcs_rate<0.3):
47 | # continue # change too large; ignore
48 | # if len(source)*1.15 < len(target):
49 | # new_t = source
50 | # source = target
51 | # target = new_t
52 | corpus = "%s\t%s\t%f\n" % (source, target, lcs_rate)
53 | corpus_list.append(corpus)
54 | return corpus_list
55 |
56 | def main(corpus_folder, save_folder):
57 | fileList = os.listdir(corpus_folder)
58 | corpus_list_total = []
59 | for raw_file_name in fileList:
60 | corpus_list = process(corpus_folder, raw_file_name)
61 | print(curLine(), raw_file_name, len(corpus_list))
62 | corpus_list_total.extend(corpus_list)
63 | save_file = os.path.join(save_folder, "baoan_airport_from_xlsx.txt")
64 | with open(save_file, "w") as fw:
65 | fw.writelines(corpus_list_total)
66 | print(curLine(), "have save %d to %s" % (len(corpus_list_total), save_file))
67 |
68 |
69 |
70 |
71 | if __name__ == "__main__":
72 | corpus_folder = "/home/cloudminds/Mywork/corpus/Chinese_QA/baoanairport/agent842-3月2日"
73 |
74 | save_folder = "/home/cloudminds/Mywork/corpus/rephrase_corpus"
75 | raw_file_name = "专业知识导出记录.xlsx"
76 | main(corpus_folder, save_folder)
77 |
78 |
--------------------------------------------------------------------------------
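get_text_pair_sv.py reads .xlsx files through xlrd; note that xlrd 2.x dropped .xlsx support, so the script assumes xlrd 1.x. A minimal reading sketch under that assumption (the file path is hypothetical):

import xlrd  # assumes xlrd < 2.0, which can still open .xlsx workbooks

workbook = xlrd.open_workbook("export.xlsx")  # hypothetical file
worksheet = workbook.sheet_by_index(0)
print(worksheet.name, worksheet.nrows, worksheet.ncols)
first_column = worksheet.col_values(0)        # all cell values in column 0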
/get_pairs_chinese/merge_split_corpus.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Merge corpora from different sources, then split them into train, dev and test sets
3 | import os
4 | import numpy as np
5 | from curLine_file import curLine
6 |
7 | def merge(raw_file_name_list, save_folder):
8 | corpus_list = []
9 | for raw_file_name in raw_file_name_list:
10 | raw_file = os.path.join(save_folder, "%s.txt" % raw_file_name)
11 | with open(raw_file) as fr:
12 | lines = fr.readlines()
13 | corpus_list.extend(lines)
14 | if "baoan" in raw_file_name:
15 | corpus_list.extend(lines) # TODO
16 | return corpus_list
17 |
18 | def split(corpus_list, save_folder, trainRate=0.8):
19 | corpusNum = len(corpus_list)
20 | shuffle_indices = list(np.random.permutation(range(corpusNum)))
21 | indexTrain = int(trainRate * corpusNum)
22 | # indexDev= int((trainRate + devRate) * corpusNum)
23 | corpusTrain = []
24 | for i in shuffle_indices[:indexTrain]:
25 | corpusTrain.append(corpus_list[i])
26 | save_file = os.path.join(save_folder, "train.txt")
27 | with open(save_file, "w") as fw:
28 | fw.writelines(corpusTrain)
29 | print(curLine(), "have save %d to %s" % (len(corpusTrain), save_file))
30 |
31 | corpusDev = []
32 | for i in shuffle_indices[indexTrain:]: # TODO all corpus; the same held-out slice is reused for dev and test below
33 | corpusDev.append(corpus_list[i])
34 | save_file = os.path.join(save_folder, "tune.txt")
35 | with open(save_file, "w") as fw:
36 | fw.writelines(corpusDev)
37 | print(curLine(), "have save %d to %s" % (len(corpusDev), save_file))
38 |
39 |
40 | save_file = os.path.join(save_folder, "test.txt")
41 | with open(save_file, "w") as fw:
42 | fw.writelines(corpusDev)
43 | print(curLine(), "have save %d to %s" % (len(corpusDev), save_file))
44 |
45 |
46 |
47 |
48 | if __name__ == "__main__":
49 | raw_file_name = ["baoan_airport", "lcqmc", "baoan_airport_from_xlsx"]
50 | save_folder = "/home/cloudminds/Mywork/corpus/rephrase_corpus"
51 | corpus_list = merge(raw_file_name, save_folder)
52 | split(corpus_list, save_folder, trainRate=0.8)
53 |
54 |
--------------------------------------------------------------------------------
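split() above writes the same held-out slice to both tune.txt and test.txt (see the TODO), so dev and test are currently identical. The permutation-based split on its own, as a sketch with toy data:

import numpy as np

corpus_list = ["s1\tt1\t0.5\n", "s2\tt2\t0.6\n", "s3\tt3\t0.7\n", "s4\tt4\t0.8\n"]
shuffle_indices = np.random.permutation(len(corpus_list))
indexTrain = int(0.8 * len(corpus_list))
train = [corpus_list[i] for i in shuffle_indices[:indexTrain]]
heldout = [corpus_list[i] for i in shuffle_indices[indexTrain:]]  # reused for dev and test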
/official_transformer/README:
--------------------------------------------------------------------------------
1 | This directory contains Transformer related code files. These are copied from:
2 | https://github.com/tensorflow/models/tree/master/official/transformer
3 | to make it possible to install all LaserTagger dependencies with pip (the
4 | official Transformer implementation doesn't support pip installation).
--------------------------------------------------------------------------------
/official_transformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mleader2/text_scalpel/131922c5dcfc558a46a7590925e348c39cb24951/official_transformer/__init__.py
--------------------------------------------------------------------------------
/official_transformer/attention_layer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Implementation of multiheaded attention and self-attention layers."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import tensorflow as tf
23 |
24 |
25 | class Attention(tf.layers.Layer):
26 | """Multi-headed attention layer."""
27 |
28 | def __init__(self, hidden_size, num_heads, attention_dropout, train):
29 | if hidden_size % num_heads != 0:
30 | raise ValueError("Hidden size must be evenly divisible by the number of "
31 | "heads.")
32 |
33 | super(Attention, self).__init__()
34 | self.hidden_size = hidden_size
35 | self.num_heads = num_heads
36 | self.attention_dropout = attention_dropout
37 | self.train = train
38 |
39 | # Layers for linearly projecting the queries, keys, and values.
40 | self.q_dense_layer = tf.layers.Dense(hidden_size, use_bias=False, name="q")
41 | self.k_dense_layer = tf.layers.Dense(hidden_size, use_bias=False, name="k")
42 | self.v_dense_layer = tf.layers.Dense(hidden_size, use_bias=False, name="v")
43 |
44 | self.output_dense_layer = tf.layers.Dense(hidden_size, use_bias=False,
45 | name="output_transform")
46 |
47 | def split_heads(self, x):
48 | """Split x into different heads, and transpose the resulting value.
49 |
50 | The tensor is transposed to ensure the inner dimensions hold the correct
51 | values during the matrix multiplication.
52 |
53 | Args:
54 | x: A tensor with shape [batch_size, length, hidden_size]
55 |
56 | Returns:
57 | A tensor with shape [batch_size, num_heads, length, hidden_size/num_heads]
58 | """
59 | with tf.name_scope("split_heads"):
60 | batch_size = tf.shape(x)[0]
61 | length = tf.shape(x)[1]
62 |
63 | # Calculate depth of last dimension after it has been split.
64 | depth = (self.hidden_size // self.num_heads)
65 |
66 | # Split the last dimension
67 | x = tf.reshape(x, [batch_size, length, self.num_heads, depth])
68 |
69 | # Transpose the result
70 | return tf.transpose(x, [0, 2, 1, 3])
71 |
72 | def combine_heads(self, x):
73 | """Combine tensor that has been split.
74 |
75 | Args:
76 | x: A tensor [batch_size, num_heads, length, hidden_size/num_heads]
77 |
78 | Returns:
79 | A tensor with shape [batch_size, length, hidden_size]
80 | """
81 | with tf.name_scope("combine_heads"):
82 | batch_size = tf.shape(x)[0]
83 | length = tf.shape(x)[2]
84 | x = tf.transpose(x, [0, 2, 1, 3]) # --> [batch, length, num_heads, depth]
85 | return tf.reshape(x, [batch_size, length, self.hidden_size])
86 |
87 | def call(self, x, y, bias, cache=None):
88 | """Apply attention mechanism to x and y.
89 |
90 | Args:
91 | x: a tensor with shape [batch_size, length_x, hidden_size]
92 | y: a tensor with shape [batch_size, length_y, hidden_size]
93 | bias: attention bias that will be added to the result of the dot product.
94 | cache: (Used during prediction) dictionary with tensors containing results
95 | of previous attentions. The dictionary must have the items:
96 | {"k": tensor with shape [batch_size, i, key_channels],
97 | "v": tensor with shape [batch_size, i, value_channels]}
98 | where i is the current decoded length.
99 |
100 | Returns:
101 | Attention layer output with shape [batch_size, length_x, hidden_size]
102 | """
103 | # Linearly project the query (q), key (k) and value (v) using different
104 | # learned projections. This is in preparation of splitting them into
105 | # multiple heads. Multi-head attention uses multiple queries, keys, and
106 | # values rather than regular attention (which uses a single q, k, v).
107 | q = self.q_dense_layer(x)
108 | k = self.k_dense_layer(y)
109 | v = self.v_dense_layer(y)
110 |
111 | if cache is not None:
112 | # Combine cached keys and values with new keys and values.
113 | k = tf.concat([cache["k"], k], axis=1)
114 | v = tf.concat([cache["v"], v], axis=1)
115 |
116 | # Update cache
117 | cache["k"] = k
118 | cache["v"] = v
119 |
120 | # Split q, k, v into heads.
121 | q = self.split_heads(q)
122 | k = self.split_heads(k)
123 | v = self.split_heads(v)
124 |
125 | # Scale q to prevent the dot product between q and k from growing too large.
126 | depth = (self.hidden_size // self.num_heads)
127 | q *= depth ** -0.5
128 |
129 | # Calculate dot product attention
130 | logits = tf.matmul(q, k, transpose_b=True)
131 | logits += bias
132 | weights = tf.nn.softmax(logits, name="attention_weights")
133 | if self.train:
134 | weights = tf.nn.dropout(weights, 1.0 - self.attention_dropout)
135 | attention_output = tf.matmul(weights, v)
136 |
137 | # Recombine heads --> [batch_size, length, hidden_size]
138 | attention_output = self.combine_heads(attention_output)
139 |
140 | # Run the combined outputs through another linear projection layer.
141 | attention_output = self.output_dense_layer(attention_output)
142 | return attention_output
143 |
144 |
145 | class SelfAttention(Attention):
146 | """Multiheaded self-attention layer."""
147 |
148 | def call(self, x, bias, cache=None):
149 | return super(SelfAttention, self).call(x, x, bias, cache)
150 |
--------------------------------------------------------------------------------
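The reshape/transpose bookkeeping in split_heads and combine_heads is easiest to see in isolation. An equivalent NumPy round-trip (shapes only, no TF; a sketch rather than the layer itself):

import numpy as np

batch, length, hidden, heads = 2, 3, 8, 4
depth = hidden // heads
x = np.random.rand(batch, length, hidden)

split = x.reshape(batch, length, heads, depth).transpose(0, 2, 1, 3)
print(split.shape)                      # (2, 4, 3, 2) = [batch, heads, length, depth]

combined = split.transpose(0, 2, 1, 3).reshape(batch, length, hidden)
print(np.allclose(x, combined))         # True -- the round-trip is lossless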
/official_transformer/embedding_layer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Implementation of embedding layer with shared weights."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import tensorflow as tf # pylint: disable=g-bad-import-order
23 |
24 | from official_transformer import tpu as tpu_utils
25 |
26 |
27 | class EmbeddingSharedWeights(tf.layers.Layer):
28 | """Calculates input embeddings and pre-softmax linear with shared weights."""
29 |
30 | def __init__(self, vocab_size, hidden_size, method="gather"):
31 | """Specify characteristic parameters of embedding layer.
32 |
33 | Args:
34 | vocab_size: Number of tokens in the embedding. (Typically ~32,000)
35 | hidden_size: Dimensionality of the embedding. (Typically 512 or 1024)
36 | method: Strategy for performing embedding lookup. "gather" uses tf.gather
37 | which performs well on CPUs and GPUs, but very poorly on TPUs. "matmul"
38 | one-hot encodes the indices and formulates the embedding as a sparse
39 | matrix multiplication. The matmul formulation is wasteful as it does
40 | extra work, however matrix multiplication is very fast on TPUs which
41 | makes "matmul" considerably faster than "gather" on TPUs.
42 | """
43 | super(EmbeddingSharedWeights, self).__init__()
44 | self.vocab_size = vocab_size
45 | self.hidden_size = hidden_size
46 | if method not in ("gather", "matmul"):
47 | raise ValueError("method {} must be 'gather' or 'matmul'".format(method))
48 | self.method = method
49 |
50 | def build(self, _):
51 | with tf.variable_scope("embedding_and_softmax", reuse=tf.AUTO_REUSE):
52 | # Create and initialize weights. The random normal initializer was chosen
53 | # randomly, and works well.
54 | self.shared_weights = tf.get_variable(
55 | "weights", [self.vocab_size, self.hidden_size],
56 | initializer=tf.random_normal_initializer(
57 | 0., self.hidden_size ** -0.5))
58 |
59 | self.built = True
60 |
61 | def call(self, x):
62 | """Get token embeddings of x.
63 |
64 | Args:
65 | x: An int64 tensor with shape [batch_size, length]
66 | Returns:
67 | embeddings: float32 tensor with shape [batch_size, length, embedding_size]
68 | padding: float32 tensor with shape [batch_size, length] indicating the
69 | locations of the padding tokens in x.
70 | """
71 | with tf.name_scope("embedding"):
72 | # Create binary mask of size [batch_size, length]
73 | mask = tf.to_float(tf.not_equal(x, 0))
74 |
75 | if self.method == "gather":
76 | embeddings = tf.gather(self.shared_weights, x)
77 | embeddings *= tf.expand_dims(mask, -1)
78 | else: # matmul
79 | embeddings = tpu_utils.embedding_matmul(
80 | embedding_table=self.shared_weights,
81 | values=tf.cast(x, dtype=tf.int32),
82 | mask=mask
83 | )
84 | # embedding_matmul already zeros out masked positions, so
85 | # `embeddings *= tf.expand_dims(mask, -1)` is unnecessary.
86 |
87 |
88 | # Scale embedding by the sqrt of the hidden size
89 | embeddings *= self.hidden_size ** 0.5
90 |
91 | return embeddings
92 |
93 |
94 | def linear(self, x):
95 | """Computes logits by running x through a linear layer.
96 |
97 | Args:
98 | x: A float32 tensor with shape [batch_size, length, hidden_size]
99 | Returns:
100 | float32 tensor with shape [batch_size, length, vocab_size].
101 | """
102 | with tf.name_scope("presoftmax_linear"):
103 | batch_size = tf.shape(x)[0]
104 | length = tf.shape(x)[1]
105 |
106 | x = tf.reshape(x, [-1, self.hidden_size])
107 | logits = tf.matmul(x, self.shared_weights, transpose_b=True)
108 |
109 | return tf.reshape(logits, [batch_size, length, self.vocab_size])
110 |
--------------------------------------------------------------------------------
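EmbeddingSharedWeights reuses one matrix both for the lookup in call() and for the pre-softmax projection in linear() (weight tying). Both uses of the same matrix in plain NumPy, as a sketch:

import numpy as np

vocab_size, hidden_size = 10, 4
shared_weights = np.random.randn(vocab_size, hidden_size) * hidden_size ** -0.5

ids = np.array([[1, 5, 0]])                         # token id 0 is padding
mask = (ids != 0).astype(float)
embeddings = shared_weights[ids] * mask[..., None]  # gather, zero out padding
embeddings *= hidden_size ** 0.5                    # scale by sqrt(hidden_size)

logits = embeddings.reshape(-1, hidden_size) @ shared_weights.T  # same matrix, transposed
logits = logits.reshape(1, 3, vocab_size)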
/official_transformer/ffn_layer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Implementation of fully connected network."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import tensorflow as tf
23 |
24 |
25 | class FeedFowardNetwork(tf.layers.Layer):
26 | """Fully connected feedforward network."""
27 |
28 | def __init__(self, hidden_size, filter_size, relu_dropout, train, allow_pad):
29 | super(FeedFowardNetwork, self).__init__()
30 | self.hidden_size = hidden_size
31 | self.filter_size = filter_size
32 | self.relu_dropout = relu_dropout
33 | self.train = train
34 | self.allow_pad = allow_pad
35 |
36 | self.filter_dense_layer = tf.layers.Dense(
37 | filter_size, use_bias=True, activation=tf.nn.relu, name="filter_layer")
38 | self.output_dense_layer = tf.layers.Dense(
39 | hidden_size, use_bias=True, name="output_layer")
40 |
41 | def call(self, x, padding=None):
42 | """Return outputs of the feedforward network.
43 |
44 | Args:
45 | x: tensor with shape [batch_size, length, hidden_size]
46 | padding: (optional) If set, the padding values are temporarily removed
47 | from x (provided self.allow_pad is set). The padding values are placed
48 | back in the output tensor in the same locations.
49 | shape [batch_size, length]
50 |
51 | Returns:
52 | Output of the feedforward network.
53 | tensor with shape [batch_size, length, hidden_size]
54 | """
55 | padding = None if not self.allow_pad else padding
56 |
57 | # Retrieve dynamically known shapes
58 | batch_size = tf.shape(x)[0]
59 | length = tf.shape(x)[1]
60 |
61 | if padding is not None:
62 | with tf.name_scope("remove_padding"):
63 | # Flatten padding to [batch_size*length]
64 | pad_mask = tf.reshape(padding, [-1])
65 |
66 | nonpad_ids = tf.to_int32(tf.where(pad_mask < 1e-9))
67 |
68 | # Reshape x to [batch_size*length, hidden_size] to remove padding
69 | x = tf.reshape(x, [-1, self.hidden_size])
70 | x = tf.gather_nd(x, indices=nonpad_ids)
71 |
72 | # Reshape x from 2 dimensions to 3 dimensions.
73 | x.set_shape([None, self.hidden_size])
74 | x = tf.expand_dims(x, axis=0)
75 |
76 | output = self.filter_dense_layer(x)
77 | if self.train:
78 | output = tf.nn.dropout(output, 1.0 - self.relu_dropout)
79 | output = self.output_dense_layer(output)
80 |
81 | if padding is not None:
82 | with tf.name_scope("re_add_padding"):
83 | output = tf.squeeze(output, axis=0)
84 | output = tf.scatter_nd(
85 | indices=nonpad_ids,
86 | updates=output,
87 | shape=[batch_size * length, self.hidden_size]
88 | )
89 | output = tf.reshape(output, [batch_size, length, self.hidden_size])
90 | return output
91 |
--------------------------------------------------------------------------------
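The padding trick in FeedFowardNetwork.call gathers the non-pad positions into one flat batch, runs the dense layers only on those, and scatters the results back. The same gather/scatter in NumPy (a sketch; a doubling stands in for the two dense layers):

import numpy as np

batch, length, hidden = 2, 4, 3
x = np.random.rand(batch, length, hidden)
padding = np.array([[0, 0, 1, 1],
                    [0, 1, 1, 1]], dtype=float)   # 1 marks a padded position

keep = padding.reshape(-1) < 1e-9                 # True at real tokens
flat = x.reshape(-1, hidden)
compact = flat[keep]                              # only real tokens are processed
processed = compact * 2.0                         # stand-in for filter + output layers

out = np.zeros_like(flat)
out[keep] = processed                             # scatter back; pads stay zero
out = out.reshape(batch, length, hidden)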
/official_transformer/model_params.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Defines Transformer model parameters."""
17 |
18 | from collections import defaultdict
19 |
20 |
21 | BASE_PARAMS = defaultdict(
22 | lambda: None, # Set default value to None.
23 |
24 | # Input params
25 | default_batch_size=2048, # Maximum number of tokens per batch of examples.
26 | default_batch_size_tpu=32768,
27 | max_length=256, # Maximum number of tokens per example.
28 |
29 | # Model params
30 | initializer_gain=1.0, # Used in trainable variable initialization.
31 | vocab_size=33708, # Number of tokens defined in the vocabulary file.
32 | hidden_size=512, # Model dimension in the hidden layers.
33 | num_hidden_layers=6, # Number of layers in the encoder and decoder stacks.
34 | num_heads=8, # Number of heads to use in multi-headed attention.
35 | filter_size=2048, # Inner layer dimension in the feedforward network.
36 |
37 | # Dropout values (only used when training)
38 | layer_postprocess_dropout=0.1,
39 | attention_dropout=0.1,
40 | relu_dropout=0.1,
41 |
42 | # Training params
43 | label_smoothing=0.1,
44 | learning_rate=2.0,
45 | learning_rate_decay_rate=1.0,
46 | learning_rate_warmup_steps=16000,
47 |
48 | # Optimizer params
49 | optimizer_adam_beta1=0.9,
50 | optimizer_adam_beta2=0.997,
51 | optimizer_adam_epsilon=1e-09,
52 |
53 | # Default prediction params
54 | extra_decode_length=50,
55 | beam_size=4,
56 | alpha=0.6, # used to calculate length normalization in beam search
57 |
58 | # TPU specific parameters
59 | use_tpu=False,
60 | static_batch=False,
61 | allow_ffn_pad=True,
62 | )
63 |
64 | BIG_PARAMS = BASE_PARAMS.copy()
65 | BIG_PARAMS.update(
66 | default_batch_size=4096,
67 |
68 | # default batch size is smaller than for BASE_PARAMS due to memory limits.
69 | default_batch_size_tpu=16384,
70 |
71 | hidden_size=1024,
72 | filter_size=4096,
73 | num_heads=16,
74 | )
75 |
76 | # Parameters for running the model in multi gpu. These should not change the
77 | # params that modify the model shape (such as the hidden_size or num_heads).
78 | BASE_MULTI_GPU_PARAMS = BASE_PARAMS.copy()
79 | BASE_MULTI_GPU_PARAMS.update(
80 | learning_rate_warmup_steps=8000
81 | )
82 |
83 | BIG_MULTI_GPU_PARAMS = BIG_PARAMS.copy()
84 | BIG_MULTI_GPU_PARAMS.update(
85 | layer_postprocess_dropout=0.3,
86 | learning_rate_warmup_steps=8000
87 | )
88 |
89 | # Parameters for testing the model
90 | TINY_PARAMS = BASE_PARAMS.copy()
91 | TINY_PARAMS.update(
92 | default_batch_size=1024,
93 | default_batch_size_tpu=1024,
94 | hidden_size=32,
95 | num_heads=4,
96 | filter_size=256,
97 | )
98 |
--------------------------------------------------------------------------------
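BASE_PARAMS is a defaultdict whose default factory returns None, so reading a missing key never raises; the BIG/TINY variants are plain copy-and-update. A self-contained illustration of that pattern:

from collections import defaultdict

base = defaultdict(lambda: None, hidden_size=512, num_heads=8)
tiny = base.copy()                # copy() preserves the default factory
tiny.update(hidden_size=32, num_heads=4)
print(tiny["hidden_size"])        # 32
print(tiny["no_such_param"])      # None -- missing keys read as None, not KeyError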
/official_transformer/model_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Transformer model helper methods."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import math
23 |
24 | import numpy as np
25 | import tensorflow as tf
26 |
27 | # Very low numbers to represent -infinity. We do not actually use -Inf, since we
28 | # want to be able to multiply these values by zero to get zero. (-Inf * 0 = NaN)
29 | _NEG_INF_FP32 = -1e9
30 | _NEG_INF_FP16 = np.finfo(np.float16).min
31 |
32 |
33 | def get_position_encoding(
34 | length, hidden_size, min_timescale=1.0, max_timescale=1.0e2): # TODO upstream default was max_timescale=1.0e4
35 | """Return positional encoding.
36 |
37 | Calculates the position encoding as a mix of sine and cosine functions with
38 | geometrically increasing wavelengths.
39 | Defined and formulized in Attention is All You Need, section 3.5.
40 |
41 | Args:
42 | length: Sequence length.
43 | hidden_size: Size of the positional encoding (the model's hidden dimension).
44 | min_timescale: Minimum scale that will be applied at each position
45 | max_timescale: Maximum scale that will be applied at each position
46 |
47 | Returns:
48 | Tensor with shape [length, hidden_size]
49 | """
50 | # We compute the positional encoding in float32 even if the model uses
51 | # float16, as many of the ops used, like log and exp, are numerically unstable
52 | # in float16.
53 | position = tf.cast(tf.range(length), tf.float32)
54 | num_timescales = hidden_size // 2
55 | log_timescale_increment = (
56 | math.log(float(max_timescale) / float(min_timescale)) /
57 | (tf.cast(num_timescales, tf.float32) - 1))
58 | inv_timescales = min_timescale * tf.exp(
59 | tf.cast(tf.range(num_timescales), tf.float32) * -log_timescale_increment)
60 | scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0)
61 | signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
62 | return signal
63 |
64 |
65 | def get_decoder_self_attention_bias(length, dtype=tf.float32):
66 | """Calculate bias for decoder that maintains model's autoregressive property.
67 |
68 | Creates a tensor that masks out locations that correspond to illegal
69 | connections, so prediction at position i cannot draw information from future
70 | positions.
71 |
72 | Args:
73 | length: int length of sequences in batch.
74 | dtype: The dtype of the return value.
75 |
76 | Returns:
77 | float tensor of shape [1, 1, length, length]
78 | """
79 | neg_inf = _NEG_INF_FP16 if dtype == tf.float16 else _NEG_INF_FP32
80 | with tf.name_scope("decoder_self_attention_bias"):
81 | valid_locs = tf.linalg.band_part(input=tf.ones([length, length], dtype=dtype), # the two lengths are the maximum input and output lengths
82 | num_lower=-1, num_upper=0) # Lower triangular part is 1, the rest is 0
83 | valid_locs = tf.reshape(valid_locs, [1, 1, length, length])
84 | decoder_bias = neg_inf * (1.0 - valid_locs)
85 | return decoder_bias
86 |
87 |
88 | def get_padding(x, padding_value=0, dtype=tf.float32):
89 | """Return float tensor representing the padding values in x.
90 |
91 | Args:
92 | x: int tensor with any shape
93 | padding_value: int value that marks the padding positions in x.
94 | dtype: The dtype of the return value.
95 |
96 | Returns:
97 | float tensor with same shape as x containing values 0 or 1.
98 | 0 -> non-padding, 1 -> padding
99 | """
100 | with tf.name_scope("padding"):
101 | return tf.cast(tf.equal(x, padding_value), dtype)
102 |
103 |
104 | def get_padding_bias(x):
105 | """Calculate bias tensor from padding values in tensor.
106 |
107 | Bias tensor that is added to the pre-softmax multi-headed attention logits,
108 | which has shape [batch_size, num_heads, length, length]. The tensor is zero at
109 | non-padding locations, and -1e9 (negative infinity) at padding locations.
110 |
111 | Args:
112 | x: int tensor with shape [batch_size, length]
113 |
114 | Returns:
115 | Attention bias tensor of shape [batch_size, 1, 1, length].
116 | """
117 | with tf.name_scope("attention_bias"):
118 | padding = get_padding(x)
119 | attention_bias = padding * _NEG_INF_FP32
120 | attention_bias = tf.expand_dims(
121 | tf.expand_dims(attention_bias, axis=1), axis=1)
122 | return attention_bias
123 |
--------------------------------------------------------------------------------
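get_position_encoding above follows the sinusoidal scheme from "Attention is All You Need", but with max_timescale lowered to 1.0e2 (upstream uses 1.0e4). The same computation in NumPy, as a sketch:

import math
import numpy as np

length, hidden_size = 4, 8
min_timescale, max_timescale = 1.0, 1.0e2

position = np.arange(length, dtype=np.float32)
num_timescales = hidden_size // 2
log_timescale_increment = (math.log(max_timescale / min_timescale)
                           / (num_timescales - 1))
inv_timescales = min_timescale * np.exp(
    np.arange(num_timescales) * -log_timescale_increment)
scaled_time = position[:, None] * inv_timescales[None, :]
signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
print(signal.shape)  # (4, 8) = [length, hidden_size]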
/official_transformer/tpu.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | """Functions specific to running TensorFlow on TPUs."""
17 |
18 | import tensorflow as tf
19 |
20 |
21 | # "local" is a magic word in the TPU cluster resolver; it informs the resolver
22 | # to use the local CPU as the compute device. This is useful for testing and
23 | # debugging; the code flow is ostensibly identical, but without the need to
24 | # actually have a TPU on the other end.
25 | LOCAL = "local"
26 |
27 |
28 | def construct_scalar_host_call(metric_dict, model_dir, prefix=""):
29 | """Construct a host call to log scalars when training on TPU.
30 |
31 | Args:
32 | metric_dict: A dict of the tensors to be logged.
33 | model_dir: The location to write the summary.
34 | prefix: The prefix (if any) to prepend to the metric names.
35 |
36 | Returns:
37 | A tuple of (function, args_to_be_passed_to_said_function)
38 | """
39 | # type: (dict, str) -> (function, list)
40 | metric_names = list(metric_dict.keys())
41 |
42 | def host_call_fn(global_step, *args):
43 | """Training host call. Creates scalar summaries for training metrics.
44 |
45 | This function is executed on the CPU and should not directly reference
46 | any Tensors in the rest of the `model_fn`. To pass Tensors from the
47 | model to the `metric_fn`, provide as part of the `host_call`. See
48 | https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
49 | for more information.
50 |
51 | Arguments should match the list of `Tensor` objects passed as the second
52 | element in the tuple passed to `host_call`.
53 |
54 | Args:
55 | global_step: `Tensor` with shape `[batch]` for the global_step
56 | *args: Remaining tensors to log.
57 |
58 | Returns:
59 | List of summary ops to run on the CPU host.
60 | """
61 | step = global_step[0]
62 | with tf.contrib.summary.create_file_writer(
63 | logdir=model_dir, filename_suffix=".host_call").as_default():
64 | with tf.contrib.summary.always_record_summaries():
65 | for i, name in enumerate(metric_names):
66 | tf.contrib.summary.scalar(prefix + name, args[i][0], step=step)
67 |
68 | return tf.contrib.summary.all_summary_ops()
69 |
70 | # To log the current learning rate, and gradient norm for Tensorboard, the
71 | # summary op needs to be run on the host CPU via host_call. host_call
72 | # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
73 | # dimension. These Tensors are implicitly concatenated to
74 | # [params['batch_size']].
75 | global_step_tensor = tf.reshape(
76 | tf.compat.v1.train.get_or_create_global_step(), [1])
77 | other_tensors = [tf.reshape(metric_dict[key], [1]) for key in metric_names]
78 |
79 | return host_call_fn, [global_step_tensor] + other_tensors
80 |
81 |
82 | def embedding_matmul(embedding_table, values, mask, name="embedding_matmul"):
83 | """Performs embedding lookup via a matmul.
84 |
85 | The matrix to be multiplied by the embedding table Tensor is constructed
86 | via an implementation of scatter based on broadcasting embedding indices
87 | and performing an equality comparison against a broadcasted
88 | range(num_embedding_table_rows). All masked positions will produce an
89 | embedding vector of zeros.
90 |
91 | Args:
92 | embedding_table: Tensor of embedding table.
93 | Rank 2 (table_size x embedding dim)
94 | values: Tensor of embedding indices. Rank 2 (batch x n_indices)
95 | mask: Tensor of mask / weights. Rank 2 (batch x n_indices)
96 | name: Optional name scope for created ops
97 |
98 | Returns:
99 | Rank 3 tensor of embedding vectors.
100 | """
101 |
102 | with tf.name_scope(name):
103 | n_embeddings = embedding_table.get_shape().as_list()[0]
104 | batch_size, padded_size = values.shape.as_list()
105 |
106 | emb_idcs = tf.tile(
107 | tf.reshape(values, (batch_size, padded_size, 1)), (1, 1, n_embeddings))
108 | emb_weights = tf.tile(
109 | tf.reshape(mask, (batch_size, padded_size, 1)), (1, 1, n_embeddings))
110 | col_idcs = tf.tile(
111 | tf.reshape(tf.range(n_embeddings), (1, 1, n_embeddings)),
112 | (batch_size, padded_size, 1))
113 | one_hot = tf.where(
114 | tf.equal(emb_idcs, col_idcs), emb_weights,
115 | tf.zeros((batch_size, padded_size, n_embeddings)))
116 |
117 | return tf.tensordot(one_hot, embedding_table, 1)
118 |
--------------------------------------------------------------------------------
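embedding_matmul replaces tf.gather with a one-hot matrix multiplication, which TPUs execute efficiently; masked positions come out as all-zero vectors. The core idea in NumPy (a sketch, not the TF implementation):

import numpy as np

embedding_table = np.random.rand(5, 3)       # [vocab, hidden]
values = np.array([[1, 4, 0]])               # [batch, n_indices]
mask = np.array([[1.0, 1.0, 0.0]])           # last position is padding

one_hot = (values[..., None] == np.arange(5)) * mask[..., None]
embeddings = one_hot @ embedding_table       # [batch, n_indices, hidden]
print(np.allclose(embeddings[0, 0], embedding_table[1]))  # True
print(np.allclose(embeddings[0, 2], 0.0))                 # True -- masked row is zero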
/phrase_vocabulary_optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Lint as: python3
17 | """Optimizes the vocabulary of phrases that can be added by LaserTagger.
18 |
19 | The goal is to find a fixed-size set of phrases that cover as many training
20 | examples as possible. Based on the phrases, saves a file containing all possible
21 | tags to be predicted and another file reporting the percentage of covered
22 | training examples with different vocabulary sizes.
23 | """
24 |
25 | from __future__ import absolute_import
26 | from __future__ import division
27 |
28 | from __future__ import print_function
29 |
30 | import collections
31 | from typing import Sequence, Text
32 |
33 | from absl import app
34 | from absl import flags
35 | from absl import logging
36 |
37 | import utils
38 |
39 | import numpy as np
40 | import scipy.sparse
41 | import tensorflow as tf
42 | from compute_lcs import _compute_lcs
43 | from utils import my_tokenizer_class
44 | from curLine_file import curLine
45 |
46 | FLAGS = flags.FLAGS
47 |
48 | flags.DEFINE_string(
49 | 'input_file', None,
50 | 'Path to the input file containing source-target pairs from which the '
51 | 'vocabulary is optimized (see `input_format` flag and utils.py for '
52 | 'documentation).')
53 | flags.DEFINE_enum(
54 | 'input_format', None, ['wikisplit', 'discofuse'],
55 | 'Format which indicates how to parse the `input_file`. See utils.py for '
56 | 'documentation on the different formats.')
57 | flags.DEFINE_integer(
58 | 'max_input_examples', 50000,
59 | 'At most this many examples from the `input_file` are used for optimizing '
60 | 'the vocabulary.')
61 | flags.DEFINE_string(
62 | 'output_file', None,
63 | 'Path to the resulting file with all possible tags. Coverage numbers will '
64 | 'be written to a separate file which has the same path but ".log" appended '
65 | 'to it.')
66 | flags.DEFINE_bool('enable_swap_tag', True, 'Whether to enable the SWAP tag.')
67 | flags.DEFINE_integer('vocabulary_size', 500,
68 | 'Number of phrases to include in the vocabulary.')
69 | flags.DEFINE_integer(
70 | 'num_extra_statistics', 100,
71 | 'Number of extra phrases that are not included in the vocabulary but for '
72 | 'which we compute the coverage numbers. These numbers help determining '
73 | 'whether the vocabulary size should have been larger.')
74 |
75 | flags.DEFINE_string('vocab_file', None, 'Path to the BERT vocabulary file.')
76 | flags.DEFINE_bool(
77 | 'do_lower_case', True,
78 | 'Whether to lower case the input text. Should be True for uncased '
79 | 'models and False for cased models.')
80 |
81 | def _get_added_phrases(source: Text, target: Text, tokenizer) -> Sequence[Text]:
82 | """Computes the phrases that need to be added to the source to get the target.
83 |
84 | This is done by aligning each token in the LCS to the first match in the
85 | target and checking which phrases in the target remain unaligned.
86 |
87 | TODO(b/142853960): The LCS tokens should ideally be aligned to consecutive
88 | target tokens whenever possible, instead of aligning them always to the first
89 | match. This should result in a more meaningful phrase vocabulary with a higher
90 | coverage.
91 |
92 | Note that the algorithm is case-insensitive and the resulting phrases are
93 | always lowercase.
94 |
95 | Args:
96 | source: Source text.
97 | target: Target text.
98 |
99 | Returns:
100 | List of added phrases.
101 | """
102 | # 英文是分成word sep=' ',中文是分成字 sep=''
103 | sep = ''
104 | source_tokens = tokenizer.tokenize(source)
105 | target_tokens = tokenizer.tokenize(target)
106 |
107 | kept_tokens = _compute_lcs(source_tokens, target_tokens)
108 | added_phrases = []
109 | # Index of the `kept_tokens` element that we are currently looking for.
110 | kept_idx = 0
111 | phrase = []
112 | for token in target_tokens:
113 | if kept_idx < len(kept_tokens) and token == kept_tokens[kept_idx]:
114 | kept_idx += 1
115 | if phrase:
116 | added_phrases.append(sep.join(phrase))
117 | phrase = []
118 | else:
119 | phrase.append(token)
120 | if phrase:
121 | added_phrases.append(sep.join(phrase))
122 | return added_phrases
123 |
124 |
125 | def _added_token_counts(data_iterator, try_swapping, max_input_examples=10000, tokenizer=None):
126 | """Computes how many times different phrases have to be added.
127 |
128 | Args:
129 | data_iterator: Iterator to yield source lists and targets. See function
130 | yield_sources_and_targets in utils.py for the available iterators. The
131 | strings in the source list will be concatenated, possibly after swapping
132 | their order if swapping is enabled.
133 | try_swapping: Whether to try if swapping sources results in less added text.
134 | max_input_examples: Maximum number of examples to be read from the iterator.
135 |
136 | Returns:
137 | Tuple (collections.Counter for phrases, added phrases for each example, maximum source length seen).
138 | """
139 | phrase_counter = collections.Counter()
140 | num_examples = 0
141 | all_added_phrases = []
142 | max_seq_length = 0
143 | for sources, target in data_iterator:
144 | # sources may consist of several sentences; they are joined with spaces below
145 | if num_examples >= max_input_examples:
146 | break
147 | source_merge = ' '.join(sources)
148 | if len(source_merge) > max_seq_length:
149 | print(curLine(), "max_seq_length=%d, len(source_merge)=%d,source_merge:%s" %
150 | (max_seq_length, len(source_merge), source_merge))
151 | max_seq_length = len(source_merge)
152 | logging.log_every_n(logging.INFO, f'{num_examples} examples processed.', 10000)
153 | added_phrases = _get_added_phrases(source_merge, target, tokenizer)
154 | if try_swapping and len(sources) == 2:
155 | added_phrases_swap = _get_added_phrases(' '.join(sources[::-1]), target, tokenizer)
156 | # If we can align more and have to add less after swapping, we assume that
157 | # the sources would be swapped during conversion.
158 | if len(''.join(added_phrases_swap)) < len(''.join(added_phrases)):
159 | added_phrases = added_phrases_swap
160 | for phrase in added_phrases:
161 | phrase_counter[phrase] += 1
162 | all_added_phrases.append(added_phrases)
163 | num_examples += 1
164 | logging.info(f'{num_examples} examples processed.\n')
165 | return phrase_counter, all_added_phrases, max_seq_length
166 |
167 |
168 | def _construct_added_phrases_matrix(all_added_phrases, phrase_counter):
169 | """Constructs a sparse phrase occurrence matrix.
170 |
171 | Examples are on rows and phrases on columns.
172 |
173 | Args:
174 | all_added_phrases: List of lists of added phrases (one list per example).
175 | phrase_counter: Frequency of each unique added phrase.
176 |
177 | Returns:
178 | Sparse boolean matrix whose element (i, j) indicates whether example i
179 | contains the added phrase j. Columns start from the most frequent phrase.
180 | """
181 | phrase_2_idx = {
182 | tup[0]: i for i, tup in enumerate(phrase_counter.most_common())
183 | }
184 | matrix = scipy.sparse.dok_matrix((len(all_added_phrases), len(phrase_2_idx)),
185 | dtype=bool) # np.bool is deprecated in newer NumPy; the builtin bool is equivalent
186 | for i, added_phrases in enumerate(all_added_phrases):
187 | for phrase in added_phrases:
188 | phrase_idx = phrase_2_idx[phrase]
189 | matrix[i, phrase_idx] = True
190 | # Convert to CSC format to support more efficient column slicing.
191 | return matrix.tocsc()
192 |
193 |
194 | def _count_covered_examples(matrix, vocabulary_size):
195 | """Returns the number of examples whose added phrases are in the vocabulary.
196 |
197 | This assumes the vocabulary is created simply by selecting the
198 | `vocabulary_size` most frequent phrases.
199 |
200 | Args:
201 | matrix: Phrase occurrence matrix with the most frequent phrases on the
202 | left-most columns.
203 | vocabulary_size: Number of most frequent phrases to include in the
204 | vocabulary.
205 | """
206 | # Ignore the `vocabulary_size` most frequent (i.e. leftmost) phrases (i.e.
207 | # columns) and count the rows with zero added phrases.
208 | return (matrix[:, vocabulary_size:].sum(axis=1) == 0).sum()
209 |
210 |
211 | def main(argv):
212 | if len(argv) > 1:
213 | raise app.UsageError('Too many command-line arguments.')
214 | flags.mark_flag_as_required('input_file')
215 | flags.mark_flag_as_required('input_format')
216 | flags.mark_flag_as_required('output_file')
217 |
218 | data_iterator = utils.yield_sources_and_targets(FLAGS.input_file,
219 | FLAGS.input_format)
220 | tokenizer = my_tokenizer_class(FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
221 | phrase_counter, all_added_phrases, max_seq_length = _added_token_counts(
222 | data_iterator, FLAGS.enable_swap_tag, FLAGS.max_input_examples, tokenizer)
223 | matrix = _construct_added_phrases_matrix(all_added_phrases, phrase_counter)
224 | num_examples = len(all_added_phrases)
225 |
226 | statistics_file = FLAGS.output_file + '.log'
227 | with tf.io.gfile.GFile(FLAGS.output_file, 'w') as writer:
228 | with tf.io.gfile.GFile(statistics_file, 'w') as stats_writer:
229 | stats_writer.write('Idx\tFrequency\tCoverage (%)\tPhrase\n')
230 | writer.write('KEEP\n')
231 | writer.write('DELETE\n')
232 | if FLAGS.enable_swap_tag:
233 | writer.write('SWAP\n')
234 | for i, (phrase, count) in enumerate(
235 | phrase_counter.most_common(FLAGS.vocabulary_size +
236 | FLAGS.num_extra_statistics)):
237 | # Write tags.
238 | if i < FLAGS.vocabulary_size: # TODO why should a phrase be written both as KEEP|phrase and DELETE|phrase??
239 | writer.write(f'KEEP|{phrase}\n')
240 | writer.write(f'DELETE|{phrase}\n')
241 | # Write statistics.
242 | coverage = 100.0 * _count_covered_examples(matrix, i + 1) / num_examples # fraction of the corpus covered by the top i+1 most frequent phrases
243 | stats_writer.write(f'{i + 1}\t{count}\t{coverage:.2f}\t{phrase}\n')
244 | logging.info(f'Wrote tags to: {FLAGS.output_file}')
245 | logging.info(f'Wrote coverage numbers to: {statistics_file}')
246 | print(curLine(), "max_seq_length=", max_seq_length)
247 |
248 |
249 | if __name__ == '__main__':
250 | app.run(main)
251 |
--------------------------------------------------------------------------------
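The alignment loop in _get_added_phrases walks the target tokens, consuming LCS tokens in order and collecting every unaligned stretch as an "added phrase". A worked example with the LCS filled in by hand (a sketch with sep='' as for Chinese; the sentences are hypothetical):

source_tokens = list("我要去机场")
target_tokens = list("我想要去宝安机场")
kept_tokens = list("我要去机场")        # LCS of the two token lists

added_phrases, phrase, kept_idx = [], [], 0
for token in target_tokens:
    if kept_idx < len(kept_tokens) and token == kept_tokens[kept_idx]:
        kept_idx += 1
        if phrase:
            added_phrases.append("".join(phrase))
            phrase = []
    else:
        phrase.append(token)
if phrase:
    added_phrases.append("".join(phrase))
print(added_phrases)                    # ['想', '宝安']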
/predict_main.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Lint as: python3
17 | """Compute realized predictions for a dataset."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | from absl import app
24 | from absl import flags
25 | from absl import logging
26 | import math, time
27 | from termcolor import colored
28 | import tensorflow as tf
29 |
30 | import bert_example
31 | import predict_utils
32 | import tagging_converter
33 | import utils
34 | from utils import my_tokenizer_class
35 | from curLine_file import curLine
36 |
37 | FLAGS = flags.FLAGS
38 |
39 | flags.DEFINE_string(
40 | 'input_file', None,
41 | 'Path to the input file containing examples for which to compute '
42 | 'predictions.')
43 | flags.DEFINE_enum(
44 | 'input_format', None, ['wikisplit', 'discofuse'],
45 | 'Format which indicates how to parse the input_file.')
46 | flags.DEFINE_string(
47 | 'output_file', None,
48 | 'Path to the TSV file where the predictions are written to.')
49 | flags.DEFINE_string(
50 | 'label_map_file', None,
51 | 'Path to the label map file. Either a JSON file ending with ".json", that '
52 | 'maps each possible tag to an ID, or a text file that has one tag per '
53 | 'line.')
54 | flags.DEFINE_string('vocab_file', None, 'Path to the BERT vocabulary file.')
55 | flags.DEFINE_integer('max_seq_length', 128, 'Maximum sequence length.')
56 | flags.DEFINE_bool(
57 | 'do_lower_case', True,
58 | 'Whether to lower case the input text. Should be True for uncased '
59 | 'models and False for cased models.')
60 | flags.DEFINE_bool('enable_swap_tag', True, 'Whether to enable the SWAP tag.')
61 | flags.DEFINE_string('saved_model', None, 'Path to an exported TF model.')
62 |
63 |
64 | def main(argv):
65 | if len(argv) > 1:
66 | raise app.UsageError('Too many command-line arguments.')
67 | flags.mark_flag_as_required('input_file')
68 | flags.mark_flag_as_required('input_format')
69 | flags.mark_flag_as_required('output_file')
70 | flags.mark_flag_as_required('label_map_file')
71 | flags.mark_flag_as_required('vocab_file')
72 | flags.mark_flag_as_required('saved_model')
73 |
74 | label_map = utils.read_label_map(FLAGS.label_map_file)
75 | tokenizer = my_tokenizer_class(FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
76 | converter = tagging_converter.TaggingConverter(
77 | tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
78 | FLAGS.enable_swap_tag, tokenizer)
79 | builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
80 | FLAGS.max_seq_length,
81 | FLAGS.do_lower_case, converter)
82 | predictor = predict_utils.LaserTaggerPredictor(
83 | tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
84 | label_map)
85 | print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
86 | sources_list = []
87 | target_list = []
88 | with tf.io.gfile.GFile(FLAGS.input_file) as f:
89 | for line in f:
90 | sources, target, lcs_rate = line.rstrip('\n').split('\t')
91 | sources_list.append([sources])
92 | target_list.append(target)
93 | number = len(sources_list) # total number of examples
94 | predict_batch_size = min(96, number)
95 | batch_num = math.ceil(float(number) / predict_batch_size)
96 |
97 | start_time = time.time()
98 | num_predicted = 0
99 | with tf.gfile.Open(FLAGS.output_file, 'w') as writer:
100 | writer.write(f'source\tprediction\ttarget\n')
101 | for batch_id in range(batch_num):
102 | sources_batch = sources_list[batch_id * predict_batch_size: (batch_id + 1) * predict_batch_size]
103 | prediction_batch = predictor.predict_batch(sources_batch=sources_batch)
104 | assert len(prediction_batch) == len(sources_batch)
105 | num_predicted += len(prediction_batch)
106 | for id, [prediction, sources] in enumerate(zip(prediction_batch, sources_batch)):
107 | target = target_list[batch_id * predict_batch_size + id]
108 | writer.write(f'{"".join(sources)}\t{prediction}\t{target}\n')
109 | if batch_id % 20 == 0:
110 | cost_time = (time.time() - start_time) / 60.0
111 | print("%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin." %
112 | (curLine(), batch_id + 1, batch_num, num_predicted, number, cost_time))
113 | cost_time = (time.time() - start_time) / 60.0
114 | logging.info(
115 | f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {cost_time / num_predicted} min.')
116 |
117 |
118 | if __name__ == '__main__':
119 | app.run(main)
120 |
--------------------------------------------------------------------------------
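The prediction loop batches its inputs with a ceiling division so the final, possibly smaller batch is not dropped. The slicing pattern in isolation (a sketch with toy data):

import math

sources_list = list(range(10))
predict_batch_size = 3
batch_num = math.ceil(float(len(sources_list)) / predict_batch_size)
for batch_id in range(batch_num):
    batch = sources_list[batch_id * predict_batch_size:(batch_id + 1) * predict_batch_size]
    print(batch_id, batch)              # the last batch holds the single remaining item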
/predict_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Lint as: python3
17 | """Utility functions for running inference with a LaserTagger model."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 |
22 | from __future__ import print_function
23 | from collections import defaultdict
24 | import tagging
25 | from curLine_file import curLine
26 |
27 | class LaserTaggerPredictor(object):
28 | """Class for computing and realizing predictions with LaserTagger."""
29 |
30 | def __init__(self, tf_predictor,
31 | example_builder,
32 | label_map):
33 | """Initializes an instance of LaserTaggerPredictor.
34 |
35 | Args:
36 | tf_predictor: Loaded Tensorflow model.
37 | example_builder: BERT example builder.
38 | label_map: Mapping from tags to tag IDs.
39 | """
40 | self._predictor = tf_predictor
41 | self._example_builder = example_builder
42 | self._id_2_tag = {
43 | tag_id: tagging.Tag(tag) for tag, tag_id in label_map.items()
44 | }
45 |
46 | def predict_batch(self, sources_batch, location_batch=None): # adapted from the original predict method
47 | """Returns realized prediction for given sources."""
48 | # Predict tag IDs.
49 | keys = ['input_ids', 'input_mask', 'segment_ids']
50 | input_info = defaultdict(list)
51 | example_list = []
52 | location = None
53 | for id, sources in enumerate(sources_batch):
54 | if location_batch is not None:
55 | location = location_batch[id] # indicates whether each position may be modified
56 | example = self._example_builder.build_bert_example(sources, location=location)
57 | if example is None:
58 | raise ValueError("Example couldn't be built.")
59 | for key in keys:
60 | input_info[key].append(example.features[key])
61 | example_list.append(example)
62 |
63 | out = self._predictor(input_info)
64 | prediction_list = []
65 | for output, outputs_outputs, example in zip(out['pred'], out['outputs_outputs'], example_list):
66 | predicted_ids = output.tolist()
67 | # Realize output.
68 | example.features['labels'] = predicted_ids
69 | # Mask out the begin and the end token.
70 | example.features['labels_mask'] = [0] + [1] * (len(predicted_ids) - 2) + [0]
71 | labels = [
72 | self._id_2_tag[label_id] for label_id in example.get_token_labels()
73 | ]
74 | prediction = example.editing_task.realize_output(labels)
75 | prediction_list.append(prediction)
76 | return prediction_list
77 |
--------------------------------------------------------------------------------
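predict_batch masks the first and last positions (the [CLS]/[SEP] tokens) so their predicted tags are ignored during realization. The mask construction on its own, with hypothetical tag ids:

predicted_ids = [3, 0, 1, 0, 2]         # hypothetical tag ids, incl. begin/end tokens
labels_mask = [0] + [1] * (len(predicted_ids) - 2) + [0]
print(labels_mask)                      # [0, 1, 1, 1, 0]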
/preprocess_main.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Lint as: python3
17 | """Convert a dataset into the TFRecord format.
18 |
19 | The resulting TFRecord file will be used when training a LaserTagger model.
20 | """
21 |
22 | from __future__ import absolute_import
23 | from __future__ import division
24 |
25 | from __future__ import print_function
26 |
27 | from typing import Text
28 |
29 | from absl import app
30 | from absl import flags
31 | from absl import logging
32 |
33 | import bert_example
34 | import tagging_converter
35 | import utils
36 |
37 | import tensorflow as tf
38 | from curLine_file import curLine
39 |
40 | FLAGS = flags.FLAGS
41 |
42 | flags.DEFINE_string(
43 | 'input_file', None,
44 | 'Path to the input file containing examples to be converted to '
45 | 'tf.Examples.')
46 | flags.DEFINE_enum(
47 | 'input_format', None, ['wikisplit', 'discofuse'],
48 | 'Format which indicates how to parse the input_file.')
49 | flags.DEFINE_string('output_tfrecord', None,
50 | 'Path to the resulting TFRecord file.')
51 | flags.DEFINE_string(
52 | 'label_map_file', None,
53 | 'Path to the label map file. Either a JSON file ending with ".json", that '
54 | 'maps each possible tag to an ID, or a text file that has one tag per '
55 | 'line.')
56 | flags.DEFINE_string('vocab_file', None, 'Path to the BERT vocabulary file.')
57 | flags.DEFINE_integer('max_seq_length', 128, 'Maximum sequence length.')
58 | flags.DEFINE_bool(
59 | 'do_lower_case', True,
60 | 'Whether to lower case the input text. Should be True for uncased '
61 | 'models and False for cased models.')
62 | flags.DEFINE_bool('enable_swap_tag', True, 'Whether to enable the SWAP tag.')
63 | flags.DEFINE_bool(
64 | 'output_arbitrary_targets_for_infeasible_examples', False,
65 | 'Set this to True when preprocessing the development set. Determines '
66 | 'whether to output a TF example also for sources that can not be converted '
67 | 'to target via the available tagging operations. In these cases, the '
68 | 'target ids will correspond to the tag sequence KEEP-DELETE-KEEP-DELETE... '
69 | 'which should be very unlikely to be predicted by chance. This will be '
70 | 'useful for getting more accurate eval scores during training.')
71 |
72 |
73 | def _write_example_count(count: int) -> Text:
74 | """Saves the number of converted examples to a file.
75 |
76 | This count is used when determining the number of training steps.
77 |
78 | Args:
79 | count: The number of converted examples.
80 |
81 | Returns:
82 | The filename to which the count is saved.
83 | """
84 | count_fname = FLAGS.output_tfrecord + '.num_examples.txt'
85 | with tf.io.gfile.GFile(count_fname, 'w') as count_writer:
86 | count_writer.write(str(count))
87 | return count_fname
88 |
89 |
90 | def main(argv):
91 | if len(argv) > 1:
92 | raise app.UsageError('Too many command-line arguments.')
93 | flags.mark_flag_as_required('input_file')
94 | flags.mark_flag_as_required('input_format')
95 | flags.mark_flag_as_required('output_tfrecord')
96 | flags.mark_flag_as_required('label_map_file')
97 | flags.mark_flag_as_required('vocab_file')
98 |
99 | label_map = utils.read_label_map(FLAGS.label_map_file)
100 | tokenizer = utils.my_tokenizer_class(FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
101 | converter = tagging_converter.TaggingConverter(
102 | tagging_converter.get_phrase_vocabulary_from_label_map(label_map), # phrase_vocabulary set
103 | FLAGS.enable_swap_tag, tokenizer=tokenizer)
104 | # print(curLine(), len(label_map), "label_map:", label_map, converter._max_added_phrase_length)
105 | builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
106 | FLAGS.max_seq_length,
107 | FLAGS.do_lower_case, converter)
108 |
109 | num_converted = 0
110 | with tf.io.TFRecordWriter(FLAGS.output_tfrecord) as writer:
111 | for i, (sources, target) in enumerate(utils.yield_sources_and_targets(
112 | FLAGS.input_file, FLAGS.input_format)):
113 | logging.log_every_n(
114 | logging.INFO,
115 | f'{i} examples processed, {num_converted} converted to tf.Example.',
116 | 10000)
117 | example = builder.build_bert_example(
118 | sources, target,
119 | FLAGS.output_arbitrary_targets_for_infeasible_examples)
120 | if example is None:
121 | continue # Depending on output_arbitrary_targets_for_infeasible_examples, infeasible examples are either skipped (None) or given arbitrary targets; in the latter case they are still counted in num_converted.
122 | writer.write(example.to_tf_example().SerializeToString())
123 | num_converted += 1
124 | logging.info(f'Done. {num_converted} examples converted to tf.Example.')
125 | count_fname = _write_example_count(num_converted)
126 | logging.info(f'Wrote:\n{FLAGS.output_tfrecord}\n{count_fname}')
127 |
128 |
129 | if __name__ == '__main__':
130 | app.run(main)
131 |
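132 | # Illustration only (not used by the pipeline): the arbitrary-target encoding
133 | # described in the output_arbitrary_targets_for_infeasible_examples flag help
134 | # above. For an infeasible source of n tokens, the target tags follow the
135 | # alternating sequence KEEP-DELETE-KEEP-DELETE..., which should be very
136 | # unlikely to be predicted by chance. A hypothetical helper:
137 | #
138 | # def arbitrary_tag_sequence(num_tokens):
139 | #     return ["KEEP" if i % 2 == 0 else "DELETE" for i in range(num_tokens)]
140 | #
141 | # arbitrary_tag_sequence(4)  # -> ['KEEP', 'DELETE', 'KEEP', 'DELETE']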
--------------------------------------------------------------------------------
/qa_rephrase/__init__.py:
--------------------------------------------------------------------------------
1 | import bert_example
2 | import predict_utils
3 | import tagging_converter
4 | import utils
--------------------------------------------------------------------------------
/qa_rephrase/predict_for_qa.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Generate paraphrased corpus for data augmentation (adapted here for QA)
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | from absl import app
9 | from absl import flags
10 | from absl import logging
11 | import os, sys, time
12 | from termcolor import colored
13 | import math
14 | import tensorflow as tf
15 |
16 | block_list = os.path.realpath(__file__).split("/")
17 | path = "/".join(block_list[:-2])
18 | sys.path.append(path)
19 | from compute_lcs import _compute_lcs
20 | import bert_example
21 | import predict_utils
22 | import tagging_converter
23 | import utils
24 |
25 | from curLine_file import curLine
26 |
27 | FLAGS = flags.FLAGS
28 |
29 | flags.DEFINE_string(
30 | 'input_file', None,
31 | 'Path to the input file containing examples for which to compute '
32 | 'predictions.')
33 | flags.DEFINE_enum(
34 | 'input_format', None, ['wikisplit', 'discofuse'],
35 | 'Format which indicates how to parse the input_file.')
36 | flags.DEFINE_string(
37 | 'output_file', None,
38 | 'Path to the TSV file where the predictions are written to.')
39 | flags.DEFINE_string(
40 | 'label_map_file', None,
41 | 'Path to the label map file. Either a JSON file ending with ".json", that '
42 | 'maps each possible tag to an ID, or a text file that has one tag per '
43 | 'line.')
44 | flags.DEFINE_string('vocab_file', None, 'Path to the BERT vocabulary file.')
45 | flags.DEFINE_integer('max_seq_length', 128, 'Maximum sequence length.')
46 | flags.DEFINE_bool(
47 | 'do_lower_case', False,
48 | 'Whether to lower case the input text. Should be True for uncased '
49 | 'models and False for cased models.')
50 | flags.DEFINE_bool('enable_swap_tag', True, 'Whether to enable the SWAP tag.')
51 | flags.DEFINE_string('saved_model', None, 'Path to an exported TF model.')
52 |
53 | def predict_and_write(predictor, sources_batch, previous_line_list, context_list, writer, num_predicted, start_time, batch_num):
54 | prediction_batch = predictor.predict_batch(sources_batch=sources_batch)
55 | assert len(prediction_batch) == len(sources_batch)
56 | for id, [prediction, sources] in enumerate(zip(prediction_batch, sources_batch)):
57 | output = ""
58 | if len(prediction) > 1 and prediction != sources: # TODO: skip predictions that are totally unchanged or too short
59 | output = "%s%s" % (context_list[id], prediction) # TODO: is concatenating with the context necessary?
60 | # print(curLine(), "prediction:", prediction, "sources:", sources, ",output:", output, prediction != sources)
61 | writer.write("%s\t%s\n" % (previous_line_list[id], output))
62 | batch_num = batch_num + 1
63 | num_predicted += len(prediction_batch)
64 | if batch_num % 200 == 0:
65 | cost_time = (time.time() - start_time) / 3600.0
66 | print("%s batch_id=%d, predict %d examples, cost %.3fh." %
67 | (curLine(), batch_num, num_predicted, cost_time))
68 | return num_predicted, batch_num
69 |
70 |
71 | def main(argv):
72 | if len(argv) > 1:
73 | raise app.UsageError('Too many command-line arguments.')
74 | flags.mark_flag_as_required('input_file')
75 | flags.mark_flag_as_required('input_format')
76 | flags.mark_flag_as_required('output_file')
77 | flags.mark_flag_as_required('label_map_file')
78 | flags.mark_flag_as_required('vocab_file')
79 | flags.mark_flag_as_required('saved_model')
80 |
81 | label_map = utils.read_label_map(FLAGS.label_map_file)
82 | converter = tagging_converter.TaggingConverter(
83 | tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
84 | FLAGS.enable_swap_tag)
85 | builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
86 | FLAGS.max_seq_length,
87 | FLAGS.do_lower_case, converter)
88 | predictor = predict_utils.LaserTaggerPredictor(
89 | tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
90 | label_map)
91 | print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
92 | sourcesA_list = []
93 | sourcesB_list = []
94 | target_list = []
95 | with tf.io.gfile.GFile(FLAGS.input_file) as f:
96 | for line in f:
97 | sourceA, sourceB, label = line.rstrip('\n').split('\t')
98 | sourcesA_list.append([sourceA.strip(".")])
99 | sourcesB_list.append([sourceB.strip(".")])
100 | target_list.append(label)
101 |
102 |
103 |
104 | number = len(sourcesA_list) # total number of examples
105 | predict_batch_size = min(32, number)
106 | batch_num = math.ceil(float(number) / predict_batch_size)
107 |
108 | start_time = time.time()
109 | num_predicted = 0
110 | prediction_list = []
111 | with tf.gfile.Open(FLAGS.output_file, 'w') as writer:
112 | for batch_id in range(batch_num):
113 | sources_batch = sourcesA_list[batch_id * predict_batch_size: (batch_id + 1) * predict_batch_size]
114 | batch_b = sourcesB_list[batch_id * predict_batch_size: (batch_id + 1) * predict_batch_size]
115 | location_batch = []
116 | sources_batch.extend(batch_b)
117 | for source in sources_batch:
118 | location = list()
119 | for char in source[0]:
120 | if (char>='0' and char<='9') or char in '.- ' or (char>='a' and char<='z') or (char>='A' and char<='Z'):
121 | location.append("1") # ASCII letter/digit or one of ".- "
122 | else:
123 | location.append("0")
124 | location_batch.append("".join(location))
125 | prediction_batch = predictor.predict_batch(sources_batch=sources_batch, location_batch=location_batch)
126 | current_batch_size = int(len(sources_batch)/2)
127 | assert len(prediction_batch) == current_batch_size*2
128 |
129 | for id in range(0, current_batch_size):
130 | target = target_list[num_predicted+id]
131 | prediction_A = prediction_batch[id]
132 | prediction_B = prediction_batch[current_batch_size+id]
133 | sourceA = "".join(sources_batch[id])
134 | sourceB = "".join(sources_batch[current_batch_size + id])
135 | if prediction_A == prediction_B: # identical predictions: revert one side to its source
136 | lcsA = len(_compute_lcs(sourceA, prediction_A))
137 | lcsB = len(_compute_lcs(sourceB, prediction_B))
138 | if lcsA < 8: # A changed a lot
139 | prediction_B = sourceB
140 | elif lcsA <= lcsB: # A changed at least as much as B
141 | prediction_B = sourceB
142 | else:
143 | prediction_A = sourceA
144 | # lcsB is computed up front so the log line below never hits an undefined name.
145 | print(curLine(), batch_id, prediction_A, prediction_B, "target:", target, "current_batch_size=",
146 | current_batch_size, "lcsA=%d,lcsB=%d" % (lcsA, lcsB))
147 | writer.write(f'{prediction_A}\t{prediction_B}\t{target}\n')
148 |
149 | prediction_list.append("%s\t%s\n"% (sourceA, prediction_A))
150 | # print(curLine(), id,"sourceA:", sourceA, "sourceB:",sourceB, "target:", target)
151 | prediction_list.append("%s\t%s\n" % (sourceB, prediction_B))
152 | num_predicted += current_batch_size
153 | if batch_id % 20 == 0:
154 | cost_time = (time.time() - start_time) / 60.0
155 | print(curLine(), id, prediction_A, prediction_B, "target:", target, "current_batch_size=", current_batch_size)
156 | print(curLine(), id,"sourceA:", sourceA, "sourceB:",sourceB, "target:", target)
157 | print("%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin." %
158 | (curLine(), batch_id + 1, batch_num, num_predicted, number, cost_time))
159 | with open("prediction.txt", "w") as prediction_file:
160 | prediction_file.writelines(prediction_list)
161 | print(curLine(), "saved to prediction.txt.")
162 | cost_time = (time.time() - start_time) / 60.0
163 | print(curLine(), id, prediction_A, prediction_B, "target:", target, "current_batch_size=", current_batch_size)
164 | print(curLine(), id, "sourceA:", sourceA, "sourceB:", sourceB, "target:", target)
165 | logging.info(
166 | f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {cost_time / num_predicted*60000}ms.')
167 |
168 | if __name__ == '__main__':
169 | app.run(main)
170 |
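171 | # Illustration only: a self-contained sketch of the "location" mask built in
172 | # main() above. A character is marked "1" when it is an ASCII letter, digit,
173 | # or one of ".- " (presumably spans whose surface form should be preserved),
174 | # and "0" otherwise. A hypothetical helper mirroring that loop:
175 | #
176 | # def ascii_location_mask(text):
177 | #     return "".join(
178 | #         "1" if (("0" <= c <= "9") or c in ".- " or
179 | #                 ("a" <= c <= "z") or ("A" <= c <= "Z")) else "0"
180 | #         for c in text)
181 | #
182 | # ascii_location_mask("播放mp3")  # -> "00111"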
--------------------------------------------------------------------------------
/qa_rephrase/score_for_qa.txt:
--------------------------------------------------------------------------------
1 | 1200
2 | Exact score: 12.843
3 | SARI score: 51.313
4 | KEEP score: 76.046
5 | ADDITION score: 20.086
6 | DELETION score: 57.808
7 | cost time 0.337063 h
8 |
9 | 房山 Restoring parameters from /home/wzk/Mywork/corpus/rephrase_corpus/output/models/wikisplit_experiment_name/model.ckpt-3643
10 | Exact score: 13.672
11 | SARI score: 53.450
12 | KEEP score: 77.125
13 | ADDITION score: 21.678
14 | DELETION score: 61.546
15 | cost time 0.260398 h
16 | Continued training: num_train_epochs changed from 7.0 to 30.0
17 | Saving checkpoints for 36436 into /home/wzk/Mywork/corpus/rephrase_corpus/output/models/wikisplit_experiment_name/model.ckpt.
18 | Exact score: 19.414
19 | SARI score: 56.749
20 | KEEP score: 78.622
21 | ADDITION score: 26.927
22 | DELETION score: 64.697
23 | cost time 0.260367 h
24 | model.ckpt-35643
25 | Exact score: 19.440
26 | SARI score: 56.725
27 | KEEP score: 78.620
28 | ADDITION score: 26.917
29 | DELETION score: 64.638
30 | cost time 0.259561 h
31 | Restoring parameters from /home/wzk/Mywork/corpus/rephrase_corpus/output/models/wikisplit_experiment_name/model.ckpt-62436
32 | Exact score: 20.347
33 | SARI score: 57.107
34 | KEEP score: 78.693
35 | ADDITION score: 27.379
36 | DELETION score: 65.250
37 | cost time 0.575784 h
38 |
39 |
40 | Old results
41 | After improving the rules for handling upper/lower case
42 | Exact score: 29.283
43 | SARI score: 64.002
44 | KEEP score: 84.942
45 | ADDITION score: 38.128
46 | DELETION score: 68.935
47 | cost time 0.41796 h
--------------------------------------------------------------------------------
/rephrase.sh:
--------------------------------------------------------------------------------
1 | # Augment the text-matching corpus: text paraphrase (rephrase) task
2 | export HOST_NAME=$1
3 | # set gpu id to use
4 | if [[ "wzk" == "$HOST_NAME" ]]
5 | then
6 | # set gpu id to use
7 | export CUDA_VISIBLE_DEVICES=0
8 | else
9 | # not use gpu
10 | export CUDA_VISIBLE_DEVICES=""
11 | fi
12 |
13 | start_tm=`date +%s%N`;
14 |
15 |
16 | ### Optional parameters ###
17 | # If you train multiple models on the same data, change this label.
18 | export EXPERIMENT="scapal_model" #"wikisplit_experiment_name_BertTiny" #
19 |
20 | # To quickly test that model training works, set the number of epochs to a
21 | # smaller value (e.g. 0.01).
22 | export num_train_epochs=30
23 | export TRAIN_BATCH_SIZE=256
24 | export PHRASE_VOCAB_SIZE=500
25 | export MAX_INPUT_EXAMPLES=1000000
26 | export SAVE_CHECKPOINT_STEPS=2000
27 | export keep_checkpoint_max=8
28 | export enable_swap_tag=false
29 | export output_arbitrary_targets_for_infeasible_examples=false
30 | export REPHRASE_DIR="/home/${HOST_NAME}/Mywork/corpus/rephrase_corpus"
31 | export BERT_BASE_DIR="/home/${HOST_NAME}/Mywork/model/RoBERTa-tiny-clue"
32 | export OUTPUT_DIR="${REPHRASE_DIR}/output"
33 |
34 | export max_seq_length=40
35 | # Check these numbers from the "*.num_examples" files created in step 2.
36 | export NUM_TRAIN_EXAMPLES=310922
37 | export NUM_EVAL_EXAMPLES=5000
38 | export CONFIG_FILE=configs/lasertagger_config.json
39 |
40 | ### Get the most recently exported model directory.
41 | TIMESTAMP=$(ls "${OUTPUT_DIR}/models/${EXPERIMENT}/export/" | \
42 | grep -v "temp-" | sort -r | head -1)
43 | SAVED_MODEL_DIR=${OUTPUT_DIR}/models/${EXPERIMENT}/export/${TIMESTAMP}
44 | PREDICTION_FILE=${OUTPUT_DIR}/models/${EXPERIMENT}/pred.tsv
45 |
46 |
47 | #python phrase_vocabulary_optimization.py \
48 | # --input_file=${REPHRASE_DIR}/train.txt \
49 | # --input_format=wikisplit \
50 | # --vocabulary_size=${PHRASE_VOCAB_SIZE} \
51 | # --max_input_examples=1000000 \
52 | # --enable_swap_tag=${enable_swap_tag} \
53 | # --output_file=${OUTPUT_DIR}/label_map.txt \
54 | # --vocab_file=${BERT_BASE_DIR}/vocab.txt
55 | #
56 | #python preprocess_main.py \
57 | # --input_file=${REPHRASE_DIR}/tune.txt \
58 | # --input_format=wikisplit \
59 | # --output_tfrecord=${OUTPUT_DIR}/tune.tf_record \
60 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \
61 | # --vocab_file=${BERT_BASE_DIR}/vocab.txt \
62 | # --max_seq_length=${max_seq_length} \
63 | # --enable_swap_tag=${enable_swap_tag} \
64 | # --output_arbitrary_targets_for_infeasible_examples=${output_arbitrary_targets_for_infeasible_examples}
65 | #python preprocess_main.py \
66 | # --input_file=${REPHRASE_DIR}/train.txt \
67 | # --input_format=wikisplit \
68 | # --output_tfrecord=${OUTPUT_DIR}/train.tf_record \
69 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \
70 | # --vocab_file=${BERT_BASE_DIR}/vocab.txt \
71 | # --max_seq_length=${max_seq_length} \
72 | # --enable_swap_tag=${enable_swap_tag} \
73 | # --output_arbitrary_targets_for_infeasible_examples=${output_arbitrary_targets_for_infeasible_examples}
74 |
75 |
76 | #python run_lasertagger.py \
77 | # --training_file=${OUTPUT_DIR}/train.tf_record \
78 | # --eval_file=${OUTPUT_DIR}/tune.tf_record \
79 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \
80 | # --model_config_file=${CONFIG_FILE} \
81 | # --output_dir=${OUTPUT_DIR}/models/${EXPERIMENT} \
82 | # --init_checkpoint=${BERT_BASE_DIR}/bert_model.ckpt \
83 | # --do_train=true \
84 | # --do_eval=true \
85 | # --num_train_epochs=${num_train_epochs} \
86 | # --train_batch_size=${TRAIN_BATCH_SIZE} \
87 | # --save_checkpoints_steps=${SAVE_CHECKPOINT_STEPS} \
88 | # --keep_checkpoint_max=${keep_checkpoint_max} \
89 | # --max_seq_length=${max_seq_length} \
90 | # --num_train_examples=${NUM_TRAIN_EXAMPLES} \
91 | # --num_eval_examples=${NUM_EVAL_EXAMPLES}
92 |
93 |
94 | ## Export the model.
95 | echo "Export the model."
96 | python run_lasertagger.py \
97 | --label_map_file=${OUTPUT_DIR}/label_map.txt \
98 | --model_config_file=${CONFIG_FILE} \
99 | --output_dir=${OUTPUT_DIR}/models/${EXPERIMENT} \
100 | --do_export=true \
101 | --export_path=${OUTPUT_DIR}/models/${EXPERIMENT}/export
102 | #
103 | #
104 | #
105 | ### 4. Prediction
106 | echo "predict"
107 | python predict_main.py \
108 | --input_file=${REPHRASE_DIR}/test.txt \
109 | --input_format=wikisplit \
110 | --output_file=${PREDICTION_FILE} \
111 | --label_map_file=${OUTPUT_DIR}/label_map.txt \
112 | --vocab_file=${BERT_BASE_DIR}/vocab.txt \
113 | --max_seq_length=${max_seq_length} \
114 | --enable_swap_tag=${enable_swap_tag} \
115 | --saved_model=${SAVED_MODEL_DIR}
116 |
117 | ### 5. Evaluation
118 | python score_main.py --prediction_file=${PREDICTION_FILE}
119 |
120 |
121 | end_tm=`date +%s%N`;
122 | use_tm=`echo $end_tm $start_tm | awk '{ print ($1 - $2) / 1000000000 /3600}'`
123 | echo "cost time" $use_tm "h"
--------------------------------------------------------------------------------
/rephrase_for_chat.sh:
--------------------------------------------------------------------------------
1 | # Data augmentation for the chit-chat text-matching corpus
2 | # Augment the text-matching corpus: text paraphrase task
3 |
4 | # set gpu id to use
5 | export CUDA_VISIBLE_DEVICES=""
6 | start_tm=`date +%s%N`;
7 |
8 | export HOST_NAME="wzk"
9 | ### Optional parameters ###
10 |
11 | # If you train multiple models on the same data, change this label.
12 | EXPERIMENT=wikisplit_experiment
13 | # To quickly test that model training works, set the number of epochs to a
14 | # smaller value (e.g. 0.01).
15 | NUM_EPOCHS=60.0
16 | export TRAIN_BATCH_SIZE=256
17 | export PHRASE_VOCAB_SIZE=500
18 | export MAX_INPUT_EXAMPLES=1000000
19 | export SAVE_CHECKPOINT_STEPS=200
20 | export enable_swap_tag=true
21 | export output_arbitrary_targets_for_infeasible_examples=false
22 | export WIKISPLIT_DIR="/home/${HOST_NAME}/Mywork/corpus/rephrase_corpus"
23 | export OUTPUT_DIR="${WIKISPLIT_DIR}/output"
24 |
25 | #python phrase_vocabulary_optimization.py \
26 | # --input_file=${WIKISPLIT_DIR}/train.txt \
27 | # --input_format=wikisplit \
28 | # --vocabulary_size=500 \
29 | # --max_input_examples=1000000 \
30 | # --enable_swap_tag=${enable_swap_tag} \
31 | # --output_file=${OUTPUT_DIR}/label_map.txt
32 |
33 |
34 | export max_seq_length=40 # TODO
35 | export BERT_BASE_DIR="/home/${HOST_NAME}/Mywork/model/RoBERTa-tiny-clue" # chinese_L-12_H-768_A-12"
36 |
37 | #python preprocess_main.py \
38 | # --input_file=${WIKISPLIT_DIR}/tune.txt \
39 | # --input_format=wikisplit \
40 | # --output_tfrecord=${OUTPUT_DIR}/tune.tf_record \
41 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \
42 | # --vocab_file=${BERT_BASE_DIR}/vocab.txt \
43 | # --max_seq_length=${max_seq_length} \
44 | # --enable_swap_tag=${enable_swap_tag} \
45 | # --output_arbitrary_targets_for_infeasible_examples=${output_arbitrary_targets_for_infeasible_examples} # TODO true
46 | #
47 | #python preprocess_main.py \
48 | # --input_file=${WIKISPLIT_DIR}/train.txt \
49 | # --input_format=wikisplit \
50 | # --output_tfrecord=${OUTPUT_DIR}/train.tf_record \
51 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \
52 | # --vocab_file=${BERT_BASE_DIR}/vocab.txt \
53 | # --max_seq_length=${max_seq_length} \
54 | # --enable_swap_tag=${enable_swap_tag} \
55 | # --output_arbitrary_targets_for_infeasible_examples=${output_arbitrary_targets_for_infeasible_examples} # TODO false
56 |
57 |
58 |
59 | # Check these numbers from the "*.num_examples" files created in step 2.
60 | export NUM_TRAIN_EXAMPLES=310922
61 | export NUM_EVAL_EXAMPLES=5000
62 | export CONFIG_FILE=configs/lasertagger_config.json
63 | export EXPERIMENT=wikisplit_experiment_name
64 |
65 |
66 |
67 | #python run_lasertagger.py \
68 | # --training_file=${OUTPUT_DIR}/train.tf_record \
69 | # --eval_file=${OUTPUT_DIR}/tune.tf_record \
70 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \
71 | # --model_config_file=${CONFIG_FILE} \
72 | # --output_dir=${OUTPUT_DIR}/models/${EXPERIMENT} \
73 | # --init_checkpoint=${BERT_BASE_DIR}/bert_model.ckpt \
74 | # --do_train=true \
75 | # --do_eval=true \
76 | # --train_batch_size=${TRAIN_BATCH_SIZE} \
77 | # --save_checkpoints_steps=200 \
78 | # --max_seq_length=${max_seq_length} \
79 | # --num_train_examples=${NUM_TRAIN_EXAMPLES} \
80 | # --num_eval_examples=${NUM_EVAL_EXAMPLES}
81 |
82 | #CUDA_VISIBLE_DEVICES="" nohup python run_lasertagger.py \
83 | # --training_file=${OUTPUT_DIR}/train.tf_record \
84 | # --eval_file=${OUTPUT_DIR}/tune.tf_record \
85 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \
86 | # --model_config_file=${CONFIG_FILE} \
87 | # --output_dir=${OUTPUT_DIR}/models/${EXPERIMENT} \
88 | # --init_checkpoint=${BERT_BASE_DIR}/bert_model.ckpt \
89 | # --do_train=true \
90 | # --do_eval=true \
91 | # --train_batch_size=${TRAIN_BATCH_SIZE} \
92 | # --save_checkpoints_steps=${SAVE_CHECKPOINT_STEPS} \
93 | # --num_train_epochs=${NUM_EPOCHS} \
94 | # --max_seq_length=${max_seq_length} \
95 | # --num_train_examples=${NUM_TRAIN_EXAMPLES} \
96 | # --num_eval_examples=${NUM_EVAL_EXAMPLES} > log.txt 2>&1 &
97 |
98 |
99 | ### 4. Prediction
100 |
101 | # Export the model.
102 | python run_lasertagger.py \
103 | --label_map_file=${OUTPUT_DIR}/label_map.txt \
104 | --model_config_file=${CONFIG_FILE} \
105 | --output_dir=${OUTPUT_DIR}/models/${EXPERIMENT} \
106 | --do_export=true \
107 | --export_path=${OUTPUT_DIR}/models/${EXPERIMENT}/export
108 | #
109 | ### Get the most recently exported model directory.
110 | TIMESTAMP=$(ls "${OUTPUT_DIR}/models/${EXPERIMENT}/export/" | \
111 | grep -v "temp-" | sort -r | head -1)
112 | SAVED_MODEL_DIR=${OUTPUT_DIR}/models/${EXPERIMENT}/export/${TIMESTAMP}
113 | PREDICTION_FILE=${OUTPUT_DIR}/models/${EXPERIMENT}/expand_simq_v4.0.json
114 |
115 | python chat_rephrase/predict_for_chat.py \
116 | --input_file=/home/${HOST_NAME}/Mywork/corpus/闲聊/simq_v4.0.json \
117 | --input_format=wikisplit \
118 | --output_file=${PREDICTION_FILE} \
119 | --label_map_file=${OUTPUT_DIR}/label_map.txt \
120 | --vocab_file=${BERT_BASE_DIR}/vocab.txt \
121 | --max_seq_length=${max_seq_length} \
122 | --enable_swap_tag=${enable_swap_tag} \
123 | --saved_model=${SAVED_MODEL_DIR}
124 |
125 | # download the pred_qa.tsv file and hand it to shixi
126 |
127 | #[predict_for_qa.py:166] 238766 predictions saved to:/home/wzk/Mywork/corpus/rephrase_corpus/output/models/wikisplit_experiment_name/pred_qa.tsv, cost 332.37216806411743 min, ave 83.52248680233805ms.
128 | #cost time 5.54552 h
129 |
130 |
131 |
132 | #### 5. Evaluation
133 | #python score_main.py --prediction_file=${PREDICTION_FILE}
134 |
135 |
136 | end_tm=`date +%s%N`;
137 | use_tm=`echo $end_tm $start_tm | awk '{ print ($1 - $2) / 1000000000 /3600}'`
138 | echo "cost time" $use_tm "h"
--------------------------------------------------------------------------------
/rephrase_for_qa.sh:
--------------------------------------------------------------------------------
1 | # Augment the text-matching corpus: text paraphrase task
2 | # set gpu id to use
3 | export CUDA_VISIBLE_DEVICES=""
4 |
5 |
6 | start_tm=`date +%s%N`;
7 |
8 | export HOST_NAME="wzk" #"cloudminds" #
9 | ### Optional parameters ###
10 |
11 | # If you train multiple models on the same data, change this label.
12 | EXPERIMENT=wikisplit_experiment
13 | # To quickly test that model training works, set the number of epochs to a
14 | # smaller value (e.g. 0.01).
15 | NUM_EPOCHS=60.0
16 | export TRAIN_BATCH_SIZE=256 # 512 hits OOM; 256 works
17 | PHRASE_VOCAB_SIZE=500
18 | MAX_INPUT_EXAMPLES=1000000
19 | SAVE_CHECKPOINT_STEPS=200
20 | export enable_swap_tag=true
21 | export output_arbitrary_targets_for_infeasible_examples=false
22 | export WIKISPLIT_DIR="/home/${HOST_NAME}/Mywork/corpus/rephrase_corpus"
23 | export OUTPUT_DIR="${WIKISPLIT_DIR}/output"
24 |
25 | #python phrase_vocabulary_optimization.py \
26 | # --input_file=${WIKISPLIT_DIR}/train.txt \
27 | # --input_format=wikisplit \
28 | # --vocabulary_size=500 \
29 | # --max_input_examples=1000000 \
30 | # --enable_swap_tag=${enable_swap_tag} \
31 | # --output_file=${OUTPUT_DIR}/label_map.txt
32 |
33 |
34 | export max_seq_length=40 # TODO
35 | export BERT_BASE_DIR="/home/${HOST_NAME}/Mywork/model/RoBERTa-tiny-clue" # chinese_L-12_H-768_A-12"
36 |
37 | #python preprocess_main.py \
38 | # --input_file=${WIKISPLIT_DIR}/tune.txt \
39 | # --input_format=wikisplit \
40 | # --output_tfrecord=${OUTPUT_DIR}/tune.tf_record \
41 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \
42 | # --vocab_file=${BERT_BASE_DIR}/vocab.txt \
43 | # --max_seq_length=${max_seq_length} \
44 | # --enable_swap_tag=${enable_swap_tag} \
45 | # --output_arbitrary_targets_for_infeasible_examples=${output_arbitrary_targets_for_infeasible_examples} # TODO true
46 | #
47 | #python preprocess_main.py \
48 | # --input_file=${WIKISPLIT_DIR}/train.txt \
49 | # --input_format=wikisplit \
50 | # --output_tfrecord=${OUTPUT_DIR}/train.tf_record \
51 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \
52 | # --vocab_file=${BERT_BASE_DIR}/vocab.txt \
53 | # --max_seq_length=${max_seq_length} \
54 | # --enable_swap_tag=${enable_swap_tag} \
55 | # --output_arbitrary_targets_for_infeasible_examples=${output_arbitrary_targets_for_infeasible_examples} # TODO false
56 |
57 |
58 |
59 | # Check these numbers from the "*.num_examples" files created in step 2.
60 | export NUM_TRAIN_EXAMPLES=310922
61 | export NUM_EVAL_EXAMPLES=5000
62 | export CONFIG_FILE=configs/lasertagger_config.json
63 | export EXPERIMENT=wikisplit_experiment_name
64 |
65 |
66 |
67 | #python run_lasertagger.py \
68 | # --training_file=${OUTPUT_DIR}/train.tf_record \
69 | # --eval_file=${OUTPUT_DIR}/tune.tf_record \
70 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \
71 | # --model_config_file=${CONFIG_FILE} \
72 | # --output_dir=${OUTPUT_DIR}/models/${EXPERIMENT} \
73 | # --init_checkpoint=${BERT_BASE_DIR}/bert_model.ckpt \
74 | # --do_train=true \
75 | # --do_eval=true \
76 | # --train_batch_size=${TRAIN_BATCH_SIZE} \
77 | # --save_checkpoints_steps=200 \
78 | # --max_seq_length=${max_seq_length} \
79 | # --num_train_examples=${NUM_TRAIN_EXAMPLES} \
80 | # --num_eval_examples=${NUM_EVAL_EXAMPLES}
81 |
82 | #CUDA_VISIBLE_DEVICES="" nohup python run_lasertagger.py \
83 | # --training_file=${OUTPUT_DIR}/train.tf_record \
84 | # --eval_file=${OUTPUT_DIR}/tune.tf_record \
85 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \
86 | # --model_config_file=${CONFIG_FILE} \
87 | # --output_dir=${OUTPUT_DIR}/models/${EXPERIMENT} \
88 | # --init_checkpoint=${BERT_BASE_DIR}/bert_model.ckpt \
89 | # --do_train=true \
90 | # --do_eval=true \
91 | # --train_batch_size=${TRAIN_BATCH_SIZE} \
92 | # --save_checkpoints_steps=${SAVE_CHECKPOINT_STEPS} \
93 | # --num_train_epochs=${NUM_EPOCHS} \
94 | # --max_seq_length=${max_seq_length} \
95 | # --num_train_examples=${NUM_TRAIN_EXAMPLES} \
96 | # --num_eval_examples=${NUM_EVAL_EXAMPLES} > log.txt 2>&1 &
97 |
98 |
99 | ### 4. Prediction
100 |
101 | # Export the model.
102 | python run_lasertagger.py \
103 | --label_map_file=${OUTPUT_DIR}/label_map.txt \
104 | --model_config_file=${CONFIG_FILE} \
105 | --output_dir=${OUTPUT_DIR}/models/${EXPERIMENT} \
106 | --do_export=true \
107 | --export_path=${OUTPUT_DIR}/models/${EXPERIMENT}/export
108 | #
109 | ### Get the most recently exported model directory.
110 | TIMESTAMP=$(ls "${OUTPUT_DIR}/models/${EXPERIMENT}/export/" | \
111 | grep -v "temp-" | sort -r | head -1)
112 | SAVED_MODEL_DIR=${OUTPUT_DIR}/models/${EXPERIMENT}/export/${TIMESTAMP}
113 | PREDICTION_FILE=${OUTPUT_DIR}/models/${EXPERIMENT}/pred_qa.txt
114 |
115 | python qa_rephrase/predict_for_qa.py \
116 | --input_file=/home/${HOST_NAME}/Mywork/corpus/Chinese_QA/LCQMC/train.txt \
117 | --input_format=wikisplit \
118 | --output_file=${PREDICTION_FILE} \
119 | --label_map_file=${OUTPUT_DIR}/label_map.txt \
120 | --vocab_file=${BERT_BASE_DIR}/vocab.txt \
121 | --max_seq_length=${max_seq_length} \
122 | --enable_swap_tag=${enable_swap_tag} \
123 | --saved_model=${SAVED_MODEL_DIR}
124 |
125 | # download the pred_qa.tsv file and hand it to shixi
126 |
127 | #[predict_for_qa.py:166] 238766 predictions saved to:/home/wzk/Mywork/corpus/rephrase_corpus/output/models/wikisplit_experiment_name/pred_qa.tsv, cost 332.37216806411743 min, ave 83.52248680233805ms.
128 | #cost time 5.54552 h
129 |
130 |
131 |
132 | #### 5. Evaluation
133 | #python score_main.py --prediction_file=${PREDICTION_FILE}
134 |
135 |
136 | end_tm=`date +%s%N`;
137 | use_tm=`echo $end_tm $start_tm | awk '{ print ($1 - $2) / 1000000000 /3600}'`
138 | echo "cost time" $use_tm "h"
--------------------------------------------------------------------------------
/rephrase_for_skill.sh:
--------------------------------------------------------------------------------
1 | # Augment the skill corpus
2 | # rephrase_for_skill.sh: adapted from rephrase.sh
3 | # predict_for_skill.py: adapted from predict_main.py
4 | # score_for_skill.txt: comparison of results
5 |
6 | # set gpu id to use
7 | export CUDA_VISIBLE_DEVICES=""
8 |
9 | start_tm=`date +%s%N`;
10 |
11 | export HOST_NAME="wzk"
12 | ### Optional parameters ###
13 |
14 | # If you train multiple models on the same data, change this label.
15 | EXPERIMENT=wikisplit_experiment
16 | # To quickly test that model training works, set the number of epochs to a
17 | # smaller value (e.g. 0.01).
18 | NUM_EPOCHS=10.0
19 | export TRAIN_BATCH_SIZE=256
20 | export PHRASE_VOCAB_SIZE=500
21 | export MAX_INPUT_EXAMPLES=1000000
22 | export SAVE_CHECKPOINT_STEPS=200
23 | export enable_swap_tag=false
24 | export output_arbitrary_targets_for_infeasible_examples=false
25 | export WIKISPLIT_DIR="/home/${HOST_NAME}/Mywork/corpus/rephrase_corpus"
26 | export OUTPUT_DIR="${WIKISPLIT_DIR}/output"
27 |
28 |
29 |
30 | export max_seq_length=40 # TODO
31 | export BERT_BASE_DIR="/home/${HOST_NAME}/Mywork/model/RoBERTa-tiny-clue" # chinese_L-12_H-768_A-12"
32 |
33 |
34 |
35 |
36 | # Check these numbers from the "*.num_examples" files created in step 2.
37 | export CONFIG_FILE=configs/lasertagger_config.json
38 | export EXPERIMENT=wikisplit_experiment_name
39 |
40 |
41 |
42 | ### 4. Prediction
43 |
44 | # Export the model.
45 | python run_lasertagger.py \
46 | --label_map_file=${OUTPUT_DIR}/label_map.txt \
47 | --model_config_file=${CONFIG_FILE} \
48 | --output_dir=${OUTPUT_DIR}/models/${EXPERIMENT} \
49 | --do_export=true \
50 | --export_path=${OUTPUT_DIR}/models/${EXPERIMENT}/export
51 |
52 | ## Get the most recently exported model directory.
53 | TIMESTAMP=$(ls "${OUTPUT_DIR}/models/${EXPERIMENT}/export/" | \
54 | grep -v "temp-" | sort -r | head -1)
55 | SAVED_MODEL_DIR=${OUTPUT_DIR}/models/${EXPERIMENT}/export/${TIMESTAMP}
56 | PREDICTION_FILE=${OUTPUT_DIR}/models/${EXPERIMENT}/pred.tsv
57 | export domain_name=times
58 | python skill_rephrase/predict_for_skill.py \
59 | --input_file=/home/${HOST_NAME}/Mywork/corpus/ner_corpus/times_corpus/slot_times.txt \
60 | --input_format=wikisplit \
61 | --output_file=/home/${HOST_NAME}/Mywork/corpus/ner_corpus/times_corpus/slot_times_expandexpand.json \
62 | --label_map_file=${OUTPUT_DIR}/label_map.txt \
63 | --vocab_file=${BERT_BASE_DIR}/vocab.txt \
64 | --max_seq_length=${max_seq_length} \
65 | --saved_model=${SAVED_MODEL_DIR}
66 |
67 | #### 5. Evaluation
68 | #python score_main.py --prediction_file=${PREDICTION_FILE}
69 |
70 |
71 | end_tm=`date +%s%N`;
72 | use_tm=`echo $end_tm $start_tm | awk '{ print ($1 - $2) / 1000000000 /3600}'`
73 | echo "cost time" $use_tm "h"
--------------------------------------------------------------------------------
/rephrase_server.sh:
--------------------------------------------------------------------------------
1 | # Text paraphrase (rephrase) service
2 | start_tm=`date +%s%N`;
3 |
4 | export HOST_NAME="wzk"
5 | ### Optional parameters ###
6 |
7 | # If you train multiple models on the same data, change this label.
8 | EXPERIMENT=wikisplit_experiment
9 | # To quickly test that model training works, set the number of epochs to a
10 | # smaller value (e.g. 0.01).
11 |
12 | export TRAIN_BATCH_SIZE=256
13 | export PHRASE_VOCAB_SIZE=500
14 | export MAX_INPUT_EXAMPLES=1000000
15 | export SAVE_CHECKPOINT_STEPS=200
16 | export enable_swap_tag=false
17 | export output_arbitrary_targets_for_infeasible_examples=false
18 | export WIKISPLIT_DIR="/home/${HOST_NAME}/Mywork/corpus/rephrase_corpus"
19 | export BERT_BASE_DIR="/home/${HOST_NAME}/Mywork/model/RoBERTa-tiny-clue" # chinese_L-12_H-768_A-12"
20 | export OUTPUT_DIR="${WIKISPLIT_DIR}/output"
21 |
22 | #python phrase_vocabulary_optimization.py \
23 | # --input_file=${WIKISPLIT_DIR}/train.txt \
24 | # --input_format=wikisplit \
25 | # --vocabulary_size=500 \
26 | # --max_input_examples=1000000 \
27 | # --enable_swap_tag=${enable_swap_tag} \
28 | # --output_file=${OUTPUT_DIR}/label_map.txt
29 |
30 |
31 | export max_seq_length=40
32 |
33 | #python preprocess_main.py \
34 | # --input_file=${WIKISPLIT_DIR}/tune.txt \
35 | # --input_format=wikisplit \
36 | # --output_tfrecord=${OUTPUT_DIR}/tune.tf_record \
37 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \
38 | # --vocab_file=${BERT_BASE_DIR}/vocab.txt \
39 | # --max_seq_length=${max_seq_length} \
40 | # --output_arbitrary_targets_for_infeasible_examples=${output_arbitrary_targets_for_infeasible_examples} # TODO true
41 | #
42 | #python preprocess_main.py \
43 | # --input_file=${WIKISPLIT_DIR}/train.txt \
44 | # --input_format=wikisplit \
45 | # --output_tfrecord=${OUTPUT_DIR}/train.tf_record \
46 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \
47 | # --vocab_file=${BERT_BASE_DIR}/vocab.txt \
48 | # --max_seq_length=${max_seq_length} \
49 | # --output_arbitrary_targets_for_infeasible_examples=${output_arbitrary_targets_for_infeasible_examples} # TODO false
50 |
51 |
52 |
53 | # Check these numbers from the "*.num_examples" files created in step 2.
54 | export NUM_TRAIN_EXAMPLES=310922
55 | export NUM_EVAL_EXAMPLES=5000
56 | export CONFIG_FILE=configs/lasertagger_config.json
57 | export EXPERIMENT=wikisplit_experiment_name
58 |
59 |
60 | ### 4. Prediction
61 |
62 | # Export the model.
63 | #python run_lasertagger.py \
64 | # --label_map_file=${OUTPUT_DIR}/label_map.txt \
65 | # --model_config_file=${CONFIG_FILE} \
66 | # --output_dir=${OUTPUT_DIR}/models/${EXPERIMENT} \
67 | # --do_export=true \
68 | # --export_path=${OUTPUT_DIR}/models/${EXPERIMENT}/export
69 |
70 | ## Get the most recently exported model directory.
71 | TIMESTAMP=$(ls "${OUTPUT_DIR}/models/${EXPERIMENT}/export/" | \
72 | grep -v "temp-" | sort -r | head -1)
73 | SAVED_MODEL_DIR=${OUTPUT_DIR}/models/${EXPERIMENT}/export/${TIMESTAMP}
74 | label_map_file=${OUTPUT_DIR}/label_map.txt
75 | export host="0.0.0.0"
76 | export port=6000
77 |
78 | # start server
79 | python rephrase_server/rephrase_server_flask.py \
80 | --input_file=${WIKISPLIT_DIR}/test.txt \
81 | --input_format=wikisplit \
82 | --output_file=${OUTPUT_DIR}/models/${EXPERIMENT}/pred.tsv \
83 | --label_map_file=${OUTPUT_DIR}/label_map.txt \
84 | --vocab_file=${BERT_BASE_DIR}/vocab.txt \
85 | --max_seq_length=${max_seq_length} \
86 | --saved_model=${SAVED_MODEL_DIR} \
87 | --host=${host} \
88 | --port=${port}
89 |
90 | end_tm=`date +%s%N`;
91 | use_tm=`echo $end_tm $start_tm | awk '{ print ($1 - $2) / 1000000000 /3600}'`
92 | echo "cost time" $use_tm "h"
--------------------------------------------------------------------------------
/rephrase_server/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mleader2/text_scalpel/131922c5dcfc558a46a7590925e348c39cb24951/rephrase_server/__init__.py
--------------------------------------------------------------------------------
/rephrase_server/rephrase_server_flask.py:
--------------------------------------------------------------------------------
1 | # Text paraphrase service based on the TensorFlow framework
2 | import os, sys
3 | from absl import flags
4 | # (absl's logging is intentionally not imported: the stdlib "logging" module imported below would shadow it)
5 | import numpy as np
6 | import json
7 | import logging
8 | # HTTP interface
9 | from flask import Flask, jsonify, request
10 | app = Flask(__name__)
11 |
12 | logger = logging.getLogger('log')
13 | logger.setLevel(logging.DEBUG)
14 |
15 | while logger.hasHandlers():
16 | for i in logger.handlers:
17 | logger.removeHandler(i)
18 |
19 | user_name = "" # wzk/
20 | version="1.0.0.0"
21 |
22 |
23 | block_list = os.path.realpath(__file__).split("/")
24 | path = "/".join(block_list[:-2])
25 | sys.path.append(path)
26 |
27 |
28 | import bert_example
29 | import predict_utils
30 | import tagging_converter
31 | import utils
32 | import tensorflow as tf
33 | # FLAGS = flags.FLAGS
34 | FLAGS = tf.app.flags.FLAGS
35 | flags.DEFINE_string(
36 | 'input_file', None,
37 | 'Path to the input file containing examples for which to compute '
38 | 'predictions.')
39 | flags.DEFINE_enum(
40 | 'input_format', None, ['wikisplit', 'discofuse'],
41 | 'Format which indicates how to parse the input_file.')
42 | flags.DEFINE_string(
43 | 'output_file', None,
44 | 'Path to the TSV file where the predictions are written to.')
45 | flags.DEFINE_string(
46 | 'label_map_file', None,
47 | 'Path to the label map file. Either a JSON file ending with ".json", that '
48 | 'maps each possible tag to an ID, or a text file that has one tag per '
49 | 'line.')
50 | flags.DEFINE_string('vocab_file', None, 'Path to the BERT vocabulary file.')
51 | flags.DEFINE_integer('max_seq_length', 40, 'Maximum sequence length.')
52 | flags.DEFINE_bool(
53 | 'do_lower_case', True,
54 | 'Whether to lower case the input text. Should be True for uncased '
55 | 'models and False for cased models.')
56 | flags.DEFINE_bool('enable_swap_tag', True, 'Whether to enable the SWAP tag.')
57 | flags.DEFINE_string('saved_model', None, 'Path to an exported TF model.')
58 |
59 | flags.DEFINE_string('host', None, 'host address.')
60 | flags.DEFINE_integer('port', None, 'port.')
61 |
62 |
63 | class RequestHandler():
64 | def __init__(self):
65 | # if len(argv) > 1:
66 | # raise app.UsageError('Too many command-line arguments.')
67 | flags.mark_flag_as_required('input_file')
68 | flags.mark_flag_as_required('input_format')
69 | flags.mark_flag_as_required('output_file')
70 | flags.mark_flag_as_required('label_map_file')
71 | flags.mark_flag_as_required('vocab_file')
72 | flags.mark_flag_as_required('saved_model')
73 |
74 | label_map = utils.read_label_map(FLAGS.label_map_file)
75 | converter = tagging_converter.TaggingConverter(
76 | tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
77 | FLAGS.enable_swap_tag)
78 | builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
79 | FLAGS.max_seq_length,
80 | FLAGS.do_lower_case, converter)
81 | self.predictor = predict_utils.LaserTaggerPredictor(
82 | tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
83 | label_map)
84 |
85 | def infer(self, sources_batch):
86 | prediction_batch = self.predictor.predict_batch(sources_batch=sources_batch)
87 | assert len(prediction_batch) == len(sources_batch)
88 |
89 | return prediction_batch
90 |
91 | class MyEncoder(json.JSONEncoder):
92 | def default(self, obj):
93 | if isinstance(obj, np.integer):
94 | return int(obj)
95 | elif isinstance(obj, np.floating):
96 | return float(obj)
97 | elif isinstance(obj, np.ndarray):
98 | return obj.tolist()
99 | else:
100 | return super(MyEncoder, self).default(obj)
101 |
102 | # HTTP POST endpoint
103 | @app.route('/rephrase', methods=['POST']) # inference
104 | def returnPost():
105 | data_json = {'version': version}
106 | try:
107 | question_raw_batch = request.json['text_list']
108 | except TypeError:
109 | data_json['status'] = -1
110 | data_json['message'] = "Fail: input type error"
111 | data_json['data'] = []
112 | return json.dumps(data_json, cls=MyEncoder, ensure_ascii=False)
113 | question_batch = []
114 | for question_raw in question_raw_batch:
115 | question_batch.append([question_raw.strip()])
116 | decoded_output = rHandler.infer(question_batch)
117 | if len(decoded_output) == 0:
118 | data_json['status'] = -1
119 | data_json['message'] = "Fail: could not generate a paraphrase of the text."
120 | data_json['data'] = []
121 | return json.dumps(data_json, cls=MyEncoder, ensure_ascii=False)
122 | data_json['status'] = 0
123 | data_json['message'] = "Success"
124 | data_json['data'] = {}
125 | data_json['data']['output'] = decoded_output
126 | return json.dumps(data_json, cls=MyEncoder, ensure_ascii=False)
127 |
128 | rHandler = RequestHandler()
129 |
130 | if __name__ == '__main__':
131 |
132 |
133 | app.run(host=FLAGS.host, port=FLAGS.port)
134 |
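135 | # Illustration only: a minimal client for the /rephrase endpoint above. The
136 | # JSON contract, per returnPost(): POST {"text_list": [...]} and, when
137 | # status == 0, read data["output"]. Host/port below are the defaults set in
138 | # rephrase_server.sh; see rephrase_server/test_server.py for the full client.
139 | #
140 | # import requests
141 | # r = requests.post("http://0.0.0.0:6000/rephrase",
142 | #                   json={"text_list": ["今天天气怎么样"]}, timeout=10)
143 | # result = r.json()
144 | # if result["status"] == 0:
145 | #     print(result["data"]["output"])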
--------------------------------------------------------------------------------
/rephrase_server/test_server.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Client that calls the rephrase service
3 | import time, math
4 | import json
5 | import requests
6 | import tensorflow as tf
7 | from curLine_file import curLine
8 |
9 | user_name="wzk"
10 | test_num_max=300000
11 | language='Chinese'
12 | # The IP in this URL should be the participant's publicly reachable server address; the port defaults to 21628
13 | url = "http://0.0.0.0:6000/rephrase"
14 |
15 | # Call the API endpoint
16 | def http_post(sources_batch):
17 | parameter = {'text_list': sources_batch}
18 | headers = {'Content-type': 'application/json'}
19 | status = -1
20 | output = None
21 | try:
22 | r = requests.post(url, data=json.dumps(
23 | parameter), headers=headers, timeout=10.5)
24 | if r.status_code == 200:
25 | result = r.json()
26 | # print(curLine(),result)
27 | status = result['status']
28 | version = result['version']
29 | if status == 0:
30 | data = result["data"]
31 | output = data['output']
32 | else:
33 | print(curLine(), "version:%s, status=%d, message:%s" % (version, status, result['message']))
34 | else:
35 | print("%s request failed, status_code:" % (curLine()), r.status_code)
36 | except Exception as e:
37 | print(curLine(), Exception, ' : ', e)
38 | input(curLine())
39 | return status, output
40 |
41 | def test():
42 | """
43 | This is a test function. Run it from one network while the rephrase server
44 | (started via rephrase_server.sh) is running on a server in a different
45 | network, to exercise the HTTP interface end to end.
46 | """
47 | sources_list = []
48 | target_list = []
49 | output_file = "/home/cloudminds/Mywork/corpus/rephrase_corpus/pred.tsv"
50 | input_file= "/home/cloudminds/Mywork/corpus/rephrase_corpus/test.txt"
51 | with tf.io.gfile.GFile(input_file) as f:
52 | for line in f:
53 | sources, target, lcs_rate = line.rstrip('\n').split('\t')
54 | sources_list.append(sources) # [sources])
55 | target_list.append(target)
56 | number = len(target_list) # total number of examples
57 | predict_batch_size = min(64, number) # TODO
58 | batch_num = math.ceil(float(number)/predict_batch_size)
59 | num_predicted = 0
60 | with open(output_file, 'w') as writer:
61 | writer.write(f'source\tprediction\ttarget\n')
62 | start_time = time.time()
63 | for batch_id in range(batch_num):
64 | sources_batch = sources_list[batch_id * predict_batch_size: (batch_id + 1)*predict_batch_size]
65 | # prediction_batch = predictor.predict_batch(sources_batch=sources_batch)
66 | status, prediction_batch = http_post(sources_batch)
67 | assert len(prediction_batch) == len(sources_batch)
68 | num_predicted += len(prediction_batch)
69 | for id,[prediction,sources] in enumerate(zip(prediction_batch, sources_batch)):
70 | target = target_list[batch_id * predict_batch_size + id]
71 | writer.write(f'{"".join(sources)}\t{prediction}\t{target}\n')
72 | if batch_id % 20 == 0:
73 | cost_time = (time.time()-start_time)/60.0
74 | print("%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin." %
75 | (curLine(), batch_id+1, batch_num, num_predicted,number, cost_time))
76 | cost_time = (time.time() - start_time) / 60.0
77 | print(curLine(), "%d predictions saved to %s, cost %f min, ave %f min."
78 | % (num_predicted, output_file, cost_time, cost_time / num_predicted))
79 |
80 |
81 | if __name__ == '__main__':
82 | test()  # note: test() does not return a value
83 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow==1.15.0 # CPU Version of TensorFlow.
2 | # tensorflow-gpu==1.15.0 # GPU version of TensorFlow.
3 | absl-py==0.8.1
4 | astor==0.8.0
5 | bert-tensorflow==1.0.1
6 | gast==0.2.2
7 | google-pasta==0.1.7
8 | grpcio==1.24.3
9 | h5py==2.10.0
10 | Keras-Applications==1.0.8
11 | Keras-Preprocessing==1.1.0
12 | Markdown==3.1.1
13 | numpy==1.17.3
14 | pkg-resources==0.0.0
15 | protobuf==3.10.0
16 | scipy==1.3.1
17 | six==1.12.0
18 | tensorboard==1.15.0
19 | tensorflow-estimator==1.15.1
20 | termcolor==1.1.0
21 | Werkzeug==0.16.0
22 | wrapt==1.11.2
23 | flask
24 |
--------------------------------------------------------------------------------
/sari_hook.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """SARI score for evaluating paraphrasing and other text generation models.
17 |
18 | The score is introduced in the following paper:
19 |
20 | Optimizing Statistical Machine Translation for Text Simplification
21 | Wei Xu, Courtney Napoles, Ellie Pavlick, Quanze Chen and Chris Callison-Burch
22 | In Transactions of the Association for Computational Linguistics (TACL) 2015
23 | http://cs.jhu.edu/~napoles/res/tacl2016-optimizing.pdf
24 |
25 | This implementation has two differences from the GitHub [1] implementation:
26 | (1) Define 0/0=1 instead of 0 to give higher scores for predictions that match
27 | a target exactly.
28 | (2) Fix an alleged bug [2] in the deletion score computation.
29 |
30 | [1] https://github.com/cocoxu/simplification/blob/master/SARI.py
31 | (commit 0210f15)
32 | [2] https://github.com/cocoxu/simplification/issues/6
33 | """
34 |
35 | from __future__ import absolute_import
36 | from __future__ import division
37 | from __future__ import print_function
38 |
39 | import collections
40 |
41 | import numpy as np
42 | import tensorflow as tf
43 |
44 | # The paper that introduces the SARI score uses only the precision of the deleted
45 | # tokens (i.e. beta=0). To give more emphasis on recall, you may set, e.g.,
46 | # beta=1.
47 | BETA_FOR_SARI_DELETION_F_MEASURE = 0
48 |
49 |
50 | def _get_ngram_counter(ids, n):
51 | """Get a Counter with the ngrams of the given ID list.
52 |
53 | Args:
54 | ids: np.array or a list corresponding to a single sentence
55 | n: n-gram size
56 |
57 | Returns:
58 | collections.Counter with ID tuples as keys and 1s as values.
59 | """
60 | # Remove zero IDs used to pad the sequence.
61 | ids = [token_id for token_id in ids if token_id != 0]
62 | ngram_list = [tuple(ids[i:i + n]) for i in range(len(ids) + 1 - n)]
63 | ngrams = set(ngram_list)
64 | counts = collections.Counter()
65 | for ngram in ngrams:
66 | counts[ngram] = 1
67 | return counts
68 |
69 |
70 | def _get_fbeta_score(true_positives, selected, relevant, beta=1):
71 | """Compute Fbeta score.
72 |
73 | Args:
74 | true_positives: Number of true positive ngrams.
75 | selected: Number of selected ngrams.
76 | relevant: Number of relevant ngrams.
77 | beta: 0 gives precision only, 1 gives F1 score, and Inf gives recall only.
78 |
79 | Returns:
80 | Fbeta score.
81 | """
82 | precision = 1
83 | if selected > 0:
84 | precision = true_positives / selected
85 | if beta == 0:
86 | return precision
87 | recall = 1
88 | if relevant > 0:
89 | recall = true_positives / relevant
90 | if precision > 0 and recall > 0:
91 | beta2 = beta * beta
92 | return (1 + beta2) * precision * recall / (beta2 * precision + recall)
93 | else:
94 | return 0
95 |
96 |
97 | def get_addition_score(source_counts, prediction_counts, target_counts):
98 | """Compute the addition score (Equation 4 in the paper)."""
99 | added_to_prediction_counts = prediction_counts - source_counts
100 | true_positives = sum((added_to_prediction_counts & target_counts).values())
101 | selected = sum(added_to_prediction_counts.values())
102 | # Note that in the paper the summation is done over all the ngrams in the
103 | # output rather than the ngrams in the following set difference. Since the
104 | # former does not make as much sense we compute the latter, which is also done
105 | # in the GitHub implementation.
106 | relevant = sum((target_counts - source_counts).values())
107 | return _get_fbeta_score(true_positives, selected, relevant)
108 |
109 |
110 | def get_keep_score(source_counts, prediction_counts, target_counts):
111 | """Compute the keep score (Equation 5 in the paper)."""
112 | source_and_prediction_counts = source_counts & prediction_counts
113 | source_and_target_counts = source_counts & target_counts
114 | true_positives = sum((source_and_prediction_counts &
115 | source_and_target_counts).values())
116 | selected = sum(source_and_prediction_counts.values())
117 | relevant = sum(source_and_target_counts.values())
118 | return _get_fbeta_score(true_positives, selected, relevant)
119 |
120 |
121 | def get_deletion_score(source_counts, prediction_counts, target_counts, beta=0):
122 | """Compute the deletion score (Equation 6 in the paper)."""
123 | source_not_prediction_counts = source_counts - prediction_counts
124 | source_not_target_counts = source_counts - target_counts
125 | true_positives = sum((source_not_prediction_counts &
126 | source_not_target_counts).values())
127 | selected = sum(source_not_prediction_counts.values())
128 | relevant = sum(source_not_target_counts.values())
129 | return _get_fbeta_score(true_positives, selected, relevant, beta=beta)
130 |
131 |
132 | def get_sari_score(source_ids, prediction_ids, list_of_targets,
133 | max_gram_size=4, beta_for_deletion=0):
134 | """Compute the SARI score for a single prediction and one or more targets.
135 |
136 | Args:
137 | source_ids: a list / np.array of SentencePiece IDs
138 | prediction_ids: a list / np.array of SentencePiece IDs
139 | list_of_targets: a list of target ID lists / np.arrays
140 | max_gram_size: int. largest n-gram size we care about (e.g. 3 for unigrams,
141 | bigrams, and trigrams)
142 | beta_for_deletion: beta for deletion F score.
143 |
144 | Returns:
145 | the SARI score and its three components: add, keep, and deletion scores
146 | """
147 | addition_scores = []
148 | keep_scores = []
149 | deletion_scores = []
150 | for n in range(1, max_gram_size + 1):
151 | source_counts = _get_ngram_counter(source_ids, n)
152 | prediction_counts = _get_ngram_counter(prediction_ids, n)
153 | # All ngrams in the targets with count 1.
154 | target_counts = collections.Counter()
155 | # All ngrams in the targets with count r/num_targets, where r is the number
156 | # of targets where the ngram occurs.
157 | weighted_target_counts = collections.Counter()
158 | num_nonempty_targets = 0
159 | for target_ids_i in list_of_targets:
160 | target_counts_i = _get_ngram_counter(target_ids_i, n)
161 | if target_counts_i:
162 | weighted_target_counts += target_counts_i
163 | num_nonempty_targets += 1
164 | for gram in weighted_target_counts.keys():
165 | weighted_target_counts[gram] /= num_nonempty_targets
166 | target_counts[gram] = 1
167 | keep_scores.append(get_keep_score(source_counts, prediction_counts,
168 | weighted_target_counts))
169 | deletion_scores.append(get_deletion_score(source_counts, prediction_counts,
170 | weighted_target_counts,
171 | beta_for_deletion))
172 | addition_scores.append(get_addition_score(source_counts, prediction_counts,
173 | target_counts))
174 |
175 | avg_keep_score = sum(keep_scores) / max_gram_size
176 | avg_addition_score = sum(addition_scores) / max_gram_size
177 | avg_deletion_score = sum(deletion_scores) / max_gram_size
178 | sari = (avg_keep_score + avg_addition_score + avg_deletion_score) / 3.0
179 | return sari, avg_keep_score, avg_addition_score, avg_deletion_score
180 |
181 |
182 | def get_sari(source_ids, prediction_ids, target_ids, max_gram_size=4):
183 | """Computes the SARI scores from the given source, prediction and targets.
184 |
185 | Args:
186 | source_ids: A 2D tf.Tensor of size (batch_size , sequence_length)
187 | prediction_ids: A 2D tf.Tensor of size (batch_size, sequence_length)
188 | target_ids: A 3D tf.Tensor of size (batch_size, number_of_targets,
189 | sequence_length)
190 | max_gram_size: int. largest n-gram size we care about (e.g. 3 for unigrams,
191 | bigrams, and trigrams)
192 |
193 | Returns:
194 | A 4-tuple of 1D float Tensors of size (batch_size) for the SARI score and
195 | the keep, addition and deletion scores.
196 | """
197 |
198 | def get_sari_numpy(source_ids, prediction_ids, target_ids):
199 | """Iterate over elements in the batch and call the SARI function."""
200 | sari_scores = []
201 | keep_scores = []
202 | add_scores = []
203 | deletion_scores = []
204 | # Iterate over elements in the batch.
205 | for source_ids_i, prediction_ids_i, target_ids_i in zip(
206 | source_ids, prediction_ids, target_ids):
207 | sari, keep, add, deletion = get_sari_score(
208 | source_ids_i, prediction_ids_i, target_ids_i, max_gram_size,
209 | BETA_FOR_SARI_DELETION_F_MEASURE)
210 | sari_scores.append(sari)
211 | keep_scores.append(keep)
212 | add_scores.append(add)
213 | deletion_scores.append(deletion)
214 | return (np.asarray(sari_scores), np.asarray(keep_scores),
215 | np.asarray(add_scores), np.asarray(deletion_scores))
216 |
217 | sari, keep, add, deletion = tf.py_func(
218 | get_sari_numpy,
219 | [source_ids, prediction_ids, target_ids],
220 | [tf.float64, tf.float64, tf.float64, tf.float64])
221 | return sari, keep, add, deletion
222 |
223 |
224 | def sari_score(predictions, labels, features, **unused_kwargs):
225 | """Computes the SARI scores from the given source, prediction and targets.
226 |
227 | An approximate SARI scoring method since we do not glue word pieces or
228 | decode the ids and tokenize the output. By default, we use ngram order of 4.
229 | Also, this does not have beam search.
230 |
231 | Args:
232 | predictions: tensor, model predictions.
233 | labels: tensor, gold output.
234 | features: dict, containing inputs.
235 |
236 | Returns:
237 | sari: int, approx sari score
238 | """
239 | if "inputs" not in features:
240 | raise ValueError("sari_score requires inputs feature")
241 |
242 | # Convert the inputs and outputs to a [batch_size, sequence_length] tensor.
243 | inputs = tf.squeeze(features["inputs"], axis=[-1, -2])
244 | outputs = tf.to_int32(tf.argmax(predictions, axis=-1))
245 | outputs = tf.squeeze(outputs, axis=[-1, -2])
246 |
247 | # Convert the labels to a [batch_size, 1, sequence_length] tensor.
248 | labels = tf.squeeze(labels, axis=[-1, -2])
249 | labels = tf.expand_dims(labels, axis=1)
250 |
251 | score, _, _, _ = get_sari(inputs, outputs, labels)
252 | return score, tf.constant(1.0)
253 |
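254 | # Illustration only: a worked example of get_sari_score() on small ID lists
255 | # (IDs equal to 0 would be treated as padding and dropped by
256 | # _get_ngram_counter).
257 | #
258 | # source = [1, 2, 3, 4]
259 | # prediction = [1, 2, 5]
260 | # targets = [[1, 2, 5], [1, 2, 6]]
261 | # sari, keep, add, deletion = get_sari_score(
262 | #     source, prediction, targets, max_gram_size=2, beta_for_deletion=0)
263 | # # Each component is averaged over n-gram orders 1..max_gram_size, and sari
264 | # # is the mean of the keep, addition and deletion components.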
--------------------------------------------------------------------------------
/score_lib.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Lint as: python3
17 | """Utility functions for computing evaluation metrics."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 |
22 | from __future__ import print_function
23 |
24 | import re
25 | from nltk.translate import bleu_score
26 | import numpy as np
27 | import tensorflow as tf
28 | import sari_hook
29 | import utils
30 | from curLine_file import curLine
31 |
32 | def read_data(
33 | path,
34 | lowercase):
35 | """Reads data from prediction TSV file.
36 |
37 | The prediction file should contain 3 or more columns:
38 | 1: sources (concatenated)
39 | 2: prediction
40 | 3-n: targets (1 or more)
41 |
42 | Args:
43 | path: Path to the prediction file.
44 | lowercase: Whether to lowercase the data (to compute case insensitive
45 | scores).
46 |
47 | Returns:
48 | Tuple (list of sources, list of predictions, list of target lists)
49 | """
50 | sources = []
51 | predictions = []
52 | target_lists = []
53 | with tf.gfile.GFile(path) as f:
54 | for line_id, line in enumerate(f):
55 | if line_id == 0:
56 | continue
57 | source, pred, *targets = line.rstrip('\n').split('\t')
58 | if lowercase:
59 | source = source.lower()
60 | pred = pred.lower()
61 | targets = [t.lower() for t in targets]
62 | sources.append(source)
63 | predictions.append(pred)
64 | target_lists.append(targets)
65 | return sources, predictions, target_lists
66 |
67 |
68 | def compute_exact_score(predictions,
69 | target_lists):
70 | """Computes the Exact score (accuracy) of the predictions.
71 |
72 | Exact score is defined as the percentage of predictions that match at least
73 | one of the targets.
74 |
75 | Args:
76 | predictions: List of predictions.
77 | target_lists: List of targets (1 or more per prediction).
78 |
79 | Returns:
80 | Exact score between [0, 1].
81 | """
82 | num_matches = sum(any(pred == target for target in targets)
83 | for pred, targets in zip(predictions, target_lists))
84 | return num_matches / max(len(predictions), 0.1) # Avoids 0/0.
85 |
86 |
87 | def bleu(hyps, refs_list):
88 | """
89 | calculate bleu1, bleu2, bleu3
90 | """
91 | bleu_1 = []
92 | bleu_2 = []
93 |
94 | for hyp, refs in zip(hyps, refs_list):
95 | if len(hyp) <= 1:
96 | # print("ignore hyp:%s, refs:" % hyp, refs)
97 | bleu_1.append(0.0)
98 | bleu_2.append(0.0)
99 | continue
100 |
101 | score = bleu_score.sentence_bleu(
102 | refs, hyp,
103 | smoothing_function=None, # bleu_score.SmoothingFunction().method7,
104 | weights=[1, 0, 0, 0])
105 | # input(curLine())
106 | if score > 1.0:
107 | print(curLine(), refs, hyp)
108 | print(curLine(), "score=", score)
109 | input(curLine())
110 | bleu_1.append(score)
111 |
112 | score = bleu_score.sentence_bleu(
113 | refs, hyp,
114 | smoothing_function=None, # bleu_score.SmoothingFunction().method7,
115 | weights=[0.5, 0.5, 0, 0])
116 | bleu_2.append(score)
117 | bleu_1 = np.average(bleu_1)
118 | bleu_2 = np.average(bleu_2)
119 | bleu_average_score = (bleu_1 + bleu_2) * 0.5
120 | print("bleu_1=%f, bleu_2=%f, bleu_average_score=%f" % (bleu_1, bleu_2, bleu_average_score))
121 | return bleu_average_score
122 |
123 |
124 | def compute_sari_scores(
125 | sources,
126 | predictions,
127 | target_lists,
128 | ignore_wikisplit_separators=True
129 | ):
130 | """Computes SARI scores.
131 |
132 | Wraps the t2t implementation of SARI computation.
133 |
134 | Args:
135 | sources: List of sources.
136 | predictions: List of predictions.
137 | target_lists: List of targets (1 or more per prediction).
138 | ignore_wikisplit_separators: Whether to ignore "<::::>" tokens, used as
139 | sentence separators in Wikisplit, when evaluating. For the numbers
140 | reported in the paper, we accidentally ignored those tokens. Ignoring them
141 | does not affect the Exact score (since there's usually always a period
142 | before the separator to indicate sentence break), but it decreases the
143 | SARI score (since the Addition score goes down as the model doesn't get
144 | points for correctly adding <::::> anymore).
145 |
146 | Returns:
147 | Tuple (SARI score, keep score, addition score, deletion score).
148 | """
149 | sari_sum = 0
150 | keep_sum = 0
151 | add_sum = 0
152 | del_sum = 0
153 | for source, pred, targets in zip(sources, predictions, target_lists):
154 | if ignore_wikisplit_separators:
155 | source = re.sub(' <::::> ', ' ', source)
156 | pred = re.sub(' <::::> ', ' ', pred)
157 | targets = [re.sub(' <::::> ', ' ', t) for t in targets]
158 |     source_ids = list(source) # Character-level for Chinese; originally utils.get_token_list(source).
159 |     pred_ids = list(pred) # Character-level for Chinese; originally utils.get_token_list(pred).
160 | list_of_targets = [list(t) for t in targets]
161 | sari, keep, addition, deletion = sari_hook.get_sari_score(
162 | source_ids, pred_ids, list_of_targets, beta_for_deletion=1)
163 | sari_sum += sari
164 | keep_sum += keep
165 | add_sum += addition
166 | del_sum += deletion
167 | n = max(len(sources), 0.1) # Avoids 0/0.
168 | return (sari_sum / n, keep_sum / n, add_sum / n, del_sum / n)
169 |
--------------------------------------------------------------------------------
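
A note on the metrics above: for Chinese, both BLEU and SARI are effectively computed at the character level. bleu() passes raw strings to nltk's sentence_bleu, which iterates them as character sequences, and compute_sari_scores() tokenizes with list(source) rather than utils.get_token_list. A minimal sketch with made-up sentences (assumes nltk is installed and the repo root is on the path so sari_hook imports):

    from nltk.translate import bleu_score
    import sari_hook

    source, pred = "今天天气怎么样", "今天天气如何"
    targets = ["今天天气如何"]

    # Raw strings act as character sequences, so this is character-level BLEU-1.
    b1 = bleu_score.sentence_bleu(targets, pred, weights=[1, 0, 0, 0])

    # Character-level SARI, mirroring compute_sari_scores() above.
    sari, keep, add, dele = sari_hook.get_sari_score(
        list(source), list(pred), [list(t) for t in targets], beta_for_deletion=1)
    print("BLEU-1=%.3f, SARI=%.3f" % (b1, sari))
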
/score_main.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Lint as: python3
17 | """Calculates evaluation scores for a prediction TSV file.
18 |
19 | The prediction file is produced by predict_main.py and should contain 3 or more
20 | columns:
21 | 1: sources (concatenated)
22 | 2: prediction
23 | 3-n: targets (1 or more)
24 | """
25 |
26 | from __future__ import absolute_import
27 | from __future__ import division
28 |
29 | from __future__ import print_function
30 |
31 | from absl import app
32 | from absl import flags
33 | from absl import logging
34 |
35 | import score_lib
36 |
37 | FLAGS = flags.FLAGS
38 |
39 | flags.DEFINE_string(
40 | 'prediction_file', None,
41 | 'TSV file containing source, prediction, and target columns.')
42 | flags.DEFINE_bool(
43 | 'case_insensitive', True,
44 | 'Whether score computation should be case insensitive (in the LaserTagger '
45 | 'paper this was set to True).')
46 |
47 |
48 | def main(argv):
49 | if len(argv) > 1:
50 | raise app.UsageError('Too many command-line arguments.')
51 | flags.mark_flag_as_required('prediction_file')
52 |
53 | sources, predictions, target_lists = score_lib.read_data(
54 | FLAGS.prediction_file, FLAGS.case_insensitive)
55 | logging.info(f'Read file: {FLAGS.prediction_file}')
56 | exact = score_lib.compute_exact_score(predictions, target_lists)
57 | bleu = score_lib.bleu(predictions, target_lists)
58 | sari, keep, addition, deletion = score_lib.compute_sari_scores(
59 | sources, predictions, target_lists)
60 | print("num=", len(predictions))
61 | print(f'Exact score: {100 * exact:.3f}')
62 |   print(f'BLEU score: {100 * bleu:.3f}')
63 | print(f'SARI score: {100 * sari:.3f}')
64 | print(f' KEEP score: {100 * keep:.3f}')
65 | print(f' ADDITION score: {100 * addition:.3f}')
66 | print(f' DELETION score: {100 * deletion:.3f}')
67 |
68 |
69 | if __name__ == '__main__':
70 | app.run(main)
71 |
--------------------------------------------------------------------------------
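
score_main.py is driven entirely by the two flags above, with the prediction TSV produced by predict_main.py. For a quick check without going through files, the same pipeline can be driven directly from score_lib; a sketch with hypothetical in-memory data (assumes the repo's dependencies are installed):

    import score_lib

    sources = ["今天天气怎么样"]
    predictions = ["今天天气如何"]
    target_lists = [["今天天气如何", "今天的天气如何"]]  # one or more targets each

    exact = score_lib.compute_exact_score(predictions, target_lists)
    bleu = score_lib.bleu(predictions, target_lists)
    sari, keep, addition, deletion = score_lib.compute_sari_scores(
        sources, predictions, target_lists)
    print("Exact=%.3f, BLEU=%.3f, SARI=%.3f" % (exact, bleu, sari))
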
/sentence_fusion_task.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mleader2/text_scalpel/131922c5dcfc558a46a7590925e348c39cb24951/sentence_fusion_task.png
--------------------------------------------------------------------------------
/skill_rephrase/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mleader2/text_scalpel/131922c5dcfc558a46a7590925e348c39cb24951/skill_rephrase/__init__.py
--------------------------------------------------------------------------------
/skill_rephrase/predict_for_skill.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Augment task-oriented corpora via paraphrasing; with intent and slot recognition.
3 |
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | from absl import app
9 | from absl import flags
10 | from absl import logging
11 | import math, json
12 | import os, sys, time
13 | from termcolor import colored
14 | import tensorflow as tf
15 |
16 | block_list = os.path.realpath(__file__).split("/")
17 | path = "/".join(block_list[:-2])
18 | sys.path.append(path)
19 |
20 | import bert_example
21 | import predict_utils
22 | import tagging_converter
23 | import utils
24 |
25 | from curLine_file import curLine
26 |
27 | FLAGS = flags.FLAGS
28 |
29 | flags.DEFINE_string(
30 | 'input_file', None,
31 | 'Path to the input file containing examples for which to compute '
32 | 'predictions.')
33 | flags.DEFINE_enum(
34 | 'input_format', None, ['wikisplit', 'discofuse'],
35 | 'Format which indicates how to parse the input_file.')
36 | flags.DEFINE_string(
37 | 'output_file', None,
38 | 'Path to the TSV file where the predictions are written to.')
39 | flags.DEFINE_string(
40 | 'label_map_file', None,
41 | 'Path to the label map file. Either a JSON file ending with ".json", that '
42 | 'maps each possible tag to an ID, or a text file that has one tag per '
43 | 'line.')
44 | flags.DEFINE_string('vocab_file', None, 'Path to the BERT vocabulary file.')
45 | flags.DEFINE_integer('max_seq_length', 128, 'Maximum sequence length.')
46 | flags.DEFINE_bool(
47 | 'do_lower_case', False,
48 | 'Whether to lower case the input text. Should be True for uncased '
49 | 'models and False for cased models.')
50 | flags.DEFINE_bool('enable_swap_tag', True, 'Whether to enable the SWAP tag.')
51 | flags.DEFINE_string('saved_model', None, 'Path to an exported TF model.')
52 |
53 |
54 | def main(argv):
55 | if len(argv) > 1:
56 | raise app.UsageError('Too many command-line arguments.')
57 | flags.mark_flag_as_required('input_file')
58 | flags.mark_flag_as_required('input_format')
59 | flags.mark_flag_as_required('output_file')
60 | flags.mark_flag_as_required('label_map_file')
61 | flags.mark_flag_as_required('vocab_file')
62 | flags.mark_flag_as_required('saved_model')
63 |
64 | label_map = utils.read_label_map(FLAGS.label_map_file)
65 | converter = tagging_converter.TaggingConverter(
66 | tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
67 | FLAGS.enable_swap_tag)
68 | builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
69 | FLAGS.max_seq_length,
70 | FLAGS.do_lower_case, converter)
71 | predictor = predict_utils.LaserTaggerPredictor(
72 | tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
73 | label_map)
74 | print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
75 | num_predicted = 0
76 |
77 | sources_list = []
78 | location_list = []
79 | corpus_id_list = []
80 | entity_list = []
81 | domainname_list = []
82 | intentname_list = []
83 | context_list = []
84 | template_id_list = []
85 | with open(FLAGS.input_file, "r") as f:
86 | corpus_json_list = json.load(f)
87 |         # corpus_json_list = corpus_json_list[:100]  # Debug switch: limit to 100 examples.
88 | for corpus_json in corpus_json_list:
89 | sources_list.append([corpus_json["oriText"]])
90 | location_list.append(corpus_json["location"])
91 | corpus_id_list.append(corpus_json["corpus_id"])
92 | entity_list.append(corpus_json["entity"])
93 | domainname_list.append(corpus_json["domainname"])
94 | intentname_list.append(corpus_json["intentname"])
95 | context_list.append(corpus_json["context"])
96 | template_id_list.append(corpus_json["template_id"])
97 |     number = len(sources_list) # Total number of examples.
98 | predict_batch_size = min(64, number)
99 | batch_num = math.ceil(float(number) / predict_batch_size)
100 | start_time = time.time()
101 | index = 0
102 | for batch_id in range(batch_num):
103 | sources_batch = sources_list[batch_id * predict_batch_size: (batch_id + 1) * predict_batch_size]
104 | location_batch = location_list[batch_id * predict_batch_size: (batch_id + 1) * predict_batch_size]
105 | prediction_batch = predictor.predict_batch(sources_batch=sources_batch, location_batch=location_batch)
106 | assert len(prediction_batch) == len(sources_batch)
107 | num_predicted += len(prediction_batch)
108 |         for offset, (prediction, sources) in enumerate(zip(prediction_batch, sources_batch)):
109 |             index = batch_id * predict_batch_size + offset
110 | output_json = {"corpus_id": corpus_id_list[index], "oriText": prediction, "sources": sources[0],
111 | "entity": entity_list[index], "location": location_list[index],
112 | "domainname": domainname_list[index], "intentname": intentname_list[index],
113 | "context": context_list[index], "template_id": template_id_list[index]}
114 | corpus_json_list[index] = output_json
115 | if batch_id % 20 == 0:
116 | cost_time = (time.time() - start_time) / 60.0
117 | print("%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin." %
118 | (curLine(), batch_id + 1, batch_num, num_predicted, number, cost_time))
119 | assert len(corpus_json_list) == index + 1
120 | with open(FLAGS.output_file, 'w', encoding='utf-8') as writer:
121 | json.dump(corpus_json_list, writer, ensure_ascii=False, indent=4)
122 | cost_time = (time.time() - start_time) / 60.0
123 | logging.info(
124 | f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {cost_time / num_predicted} min.')
125 |
126 |
127 | if __name__ == '__main__':
128 | app.run(main)
129 |
--------------------------------------------------------------------------------
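
For reference, predict_for_skill.py expects --input_file to be a JSON array whose objects carry the fields read in main() above; the prediction is written back into each object ("oriText" is replaced, and the original text moves to "sources"). A minimal hypothetical input entry (field names are taken from the code, values are made up for illustration):

    corpus_json_list = [
        {
            "corpus_id": 0,
            "oriText": "帮我设置明天八点的闹钟",
            "location": "00000000000",  # one '0'/'1' flag per character
            "entity": {},
            "domainname": "alarm",
            "intentname": "set_alarm",
            "context": "",
            "template_id": 0,
        }
    ]
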
/skill_rephrase/score_for_skill.txt:
--------------------------------------------------------------------------------
1 | # Before augmenting the "times" corpus
2 | [trainer.py:182] f1_score_slot= 99.93185946253503 f1_score_intent= 99.05839393642326 f1_score= 99.78628187484973
3 | [trainer.py:199] The best dev: epoch=3, score=99.99
4 | [trainer.py:200] The corresponding test: epoch=3, score=100.00
5 | [trainer.py:201] Finished archiving the models, start final testing.
6 | [trainer.py:203] model_file: /home/cloudminds/Mywork/corpus/ner_corpus/times_model/lstm_crf.m
7 | [trainer.py:253] cost_time=0.031036s, ave_time=0.341051ms
8 | [trainer.py:263] type:number P:100.00%, R:100.00%, F1:100.000%
9 | [trainer.py:263] type:year P:100.00%, R:100.00%, F1:100.000%
10 | [trainer.py:263] type:date P:100.00%, R:100.00%, F1:100.000%
11 | [trainer.py:289] test_slot P:100.00, R:100.00, F1:100.000
12 |
13 | [trainer.py:302] f1_weighted_intent=91.784692, f1_macro_intent=93.101826
14 | [trainer.py:303] f1_score_intent= 92.0481191049753 classification_report_intent:
15 | precision recall f1-score support
16 |
17 | 0 1.000 1.000 1.000 4
18 | 1 1.000 0.571 0.727 14
19 | 2 1.000 1.000 1.000 18
20 | 3 0.962 0.962 0.962 26
21 | 4 0.773 1.000 0.872 17
22 | 5 1.000 1.000 1.000 1
23 | 6 0.917 1.000 0.957 11
24 |
25 | accuracy 0.923 91
26 | macro avg 0.950 0.933 0.931 91
27 | weighted avg 0.936 0.923 0.918 91
28 |
29 | # After augmenting the "times" corpus: intent scores drop slightly; perhaps because these scenarios were not specifically targeted?
30 | [trainer.py:182] f1_score_slot= 99.87074867480472 f1_score_intent= 98.7267364507435 f1_score= 99.68007997079451
31 | [trainer.py:199] The best dev: epoch=8, score=99.92
32 | [trainer.py:200] The corresponding test: epoch=8, score=100.00
33 | [trainer.py:201] Finished archiving the models, start final testing.
34 | [trainer.py:203] model_file: /home/cloudminds/Mywork/corpus/ner_corpus/times_model/lstm_crf.m
35 | [trainer.py:253] cost_time=0.029509s, ave_time=0.324270ms
36 | [trainer.py:263] type:number P:100.00%, R:100.00%, F1:100.000%
37 | [trainer.py:263] type:year P:100.00%, R:100.00%, F1:100.000%
38 | [trainer.py:263] type:date P:100.00%, R:100.00%, F1:100.000%
39 | [trainer.py:289] test_slot P:100.00, R:100.00, F1:100.000
40 |
41 | [trainer.py:302] f1_weighted_intent=89.032198, f1_macro_intent=90.676121
42 | [trainer.py:303] f1_score_intent= 89.36098260170371 classification_report_intent:
43 | precision recall f1-score support
44 |
45 | 0 1.000 1.000 1.000 4
46 | 1 1.000 0.429 0.600 14
47 | 2 1.000 1.000 1.000 18
48 | 3 0.962 0.962 0.962 26
49 | 4 0.708 1.000 0.829 17
50 | 5 1.000 1.000 1.000 1
51 | 6 0.917 1.000 0.957 11
52 |
53 | accuracy 0.901 91
54 | macro avg 0.941 0.913 0.907 91
55 | weighted avg 0.924 0.901 0.890 91
56 |
--------------------------------------------------------------------------------
/tagging.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Classes representing a tag and a text editing task.
17 |
18 | Tag corresponds to an edit operation, while EditingTask is a container for the
19 | input that LaserTagger takes. EditingTask also has a method for realizing the
20 | output text given the predicted tags.
21 | """
22 |
23 | from __future__ import absolute_import
24 | from __future__ import division
25 |
26 | from __future__ import print_function
27 |
28 | from enum import Enum
29 | from curLine_file import curLine
30 |
31 | import utils
32 |
33 |
34 | class TagType(Enum):
35 | """Base tag which indicates the type of an edit operation."""
36 | # Keep the tagged token.
37 | KEEP = 1
38 | # Delete the tagged token.
39 | DELETE = 2
40 | # Keep the tagged token but swap the order of sentences. This tag is only
41 | # applied if there are two source texts and the tag is applied to the last
42 | # token of the first source. In other contexts, it's treated as KEEP.
43 | SWAP = 3
44 |
45 |
46 | class Tag(object):
47 | """Tag that corresponds to a token edit operation.
48 |
49 | Attributes:
50 | tag_type: TagType of the tag.
51 | added_phrase: A phrase that's inserted before the tagged token (can be
52 | empty).
53 | """
54 |
55 | def __init__(self, tag):
56 | """Constructs a Tag object by parsing tag to tag_type and added_phrase.
57 |
58 | Args:
59 | tag: String representation for the tag which should have the following
60 |         format "<tag_type>|<added_phrase>" or simply "<tag_type>" if no phrase
61 | is added before the tagged token. Examples of valid tags include "KEEP",
62 | "DELETE|and", and "SWAP|.".
63 |
64 | Raises:
65 |       ValueError: If <tag_type> is invalid.
66 | """
67 | if '|' in tag:
68 |             pos_pipe = tag.index('|') # Alternatively: split on '|' and rejoin the pieces after the first into one string.
69 | tag_type, added_phrase = tag[:pos_pipe], tag[pos_pipe + 1:]
70 | else:
71 | tag_type, added_phrase = tag, ''
72 | try:
73 | self.tag_type = TagType[tag_type] # for example: tag_type:KEEP self.tag_type:TagType.KEEP
74 | except KeyError:
75 | raise ValueError(
76 | 'TagType should be KEEP, DELETE or SWAP, not {}'.format(tag_type))
77 | self.added_phrase = added_phrase
78 |
79 | def __str__(self):
80 | if not self.added_phrase:
81 | return self.tag_type.name
82 | else:
83 | return '{}|{}'.format(self.tag_type.name, self.added_phrase)
84 |
85 |
86 | class EditingTask(object):
87 | """Text-editing task.
88 |
89 | Attributes:
90 | sources: Source texts.
91 | source_tokens: Tokens of the source texts concatenated into a single list.
92 | first_tokens: The indices of the first tokens of each source text.
93 | """
94 |
95 | def __init__(self, sources, location=None, tokenizer=None):
96 | """Initializes an instance of EditingTask.
97 |
98 | Args:
99 | sources: A list of source strings. Typically contains only one string but
100 | for sentence fusion it contains two strings to be fused (whose order may
101 | be swapped).
102 |       location: None or a string of per-token flags; '0' means the token may be edited, '1' means it must stay unchanged.
103 | """
104 |         self.sep = '' # No separator between tokens, for Chinese text.
105 | self.sources = sources
106 | source_token_lists = [tokenizer.tokenize(text) for text in self.sources]
107 | # Tokens of the source texts concatenated into a single list.
108 | self.source_tokens = []
109 | # The indices of the first tokens of each source text.
110 | self.first_tokens = []
111 | for token_list in source_token_lists:
112 | self.first_tokens.append(len(self.source_tokens))
113 | self.source_tokens.extend(token_list)
114 | self.location = location
115 |
116 |         self.token_index_map = {} # Maps token index to character offset, for recovering UNK tokens.
117 | previous_id = 0
118 | for tokenizer_id, t in enumerate(self.source_tokens):
119 | if tokenizer_id > 0 and "UNK" in self.source_tokens[tokenizer_id - 1]:
120 | if t in self.source_tokens[previous_id:]:
121 | previous_id = previous_id + self.source_tokens[previous_id:].index(t)
122 |                 else: # Consecutive UNKs: currently assume each spans one character.
123 | previous_id += 1
124 | self.token_index_map[tokenizer_id] = previous_id
125 | if "UNK" not in t:
126 | length_t = len(t)
127 | if t.startswith("##", 0, 2):
128 | length_t -= 2
129 | previous_id += length_t
130 |
131 | def _realize_sequence(self, tokens, tags):
132 | """Realizes output text corresponding to a single source text.
133 |
134 | Args:
135 | tokens: Tokens of the source text.
136 | tags: Tags indicating the edit operations.
137 |
138 | Returns:
139 | The realized text.
140 | """
141 | output_tokens = []
142 | for index, (token, tag) in enumerate(zip(tokens, tags)):
143 | loc = "0"
144 | if self.location is not None:
145 | loc = self.location[index]
146 | if tag.added_phrase and (
147 | loc == "0" or index == 0 or (index > 0 and self.location[index - 1] == "0")): # TODO
148 | if not tag.added_phrase.startswith("##", 0, 2):
149 | output_tokens.append(tag.added_phrase)
150 | else: # word piece
151 | if len(output_tokens) > 0:
152 | output_tokens[-1] += tag.added_phrase[2:]
153 | else:
154 | output_tokens.append(tag.added_phrase[2:])
155 | if tag.tag_type in (
156 |                 TagType.KEEP, TagType.SWAP) or loc == "1": # TODO adjust as needed: positions where location is "1" cannot be deleted, but insertion before them is currently still allowed
157 |                 token = token.upper() # TODO much of the current corpus is uppercase, so uppercase all predicted tokens
158 | if token.startswith("##", 0, 2):
159 | output_tokens.append(token[2:])
160 |             elif "UNK" in token: # Recover the original characters behind an UNK token.
161 |                 previous_id = self.token_index_map[index] # Character offset where the UNK's word starts.
162 |                 next_previous_id = previous_id + 1 # Character offset where the UNK's word ends.
163 |                 if index + 1 in self.token_index_map:
164 |                     next_previous_id = self.token_index_map[index + 1]
165 |                 token = self.sources[0][previous_id:next_previous_id] # TODO
166 |                 print(curLine(), "self.sources[0][%d,%d]=%s" % (previous_id, next_previous_id, token))
167 | output_tokens.append(token)
168 | else: # word piece
169 | output_tokens.append(token)
170 | return self.sep.join(output_tokens)
171 |
172 | def _first_char_to_upper(self, text):
173 | """Upcases the first character of the text."""
174 | try:
175 | return text[0].upper() + text[1:]
176 | except IndexError:
177 | return text
178 |
179 | def _first_char_to_lower(self, text):
180 |     """Lowercases the first character of the text."""
181 | try:
182 | return text[0].lower() + text[1:]
183 | except IndexError:
184 | return text
185 |
186 | def realize_output(self, tags):
187 | """Realize output text based on the source tokens and predicted tags.
188 |
189 | Args:
190 | tags: Predicted tags (one for each token in `self.source_tokens`).
191 |
192 | Returns:
193 |       The realized output text.
194 |
195 | Raises:
196 | ValueError: If the number of tags doesn't match the number of source
197 | tokens.
198 | """
199 | if len(tags) != len(self.source_tokens):
200 | raise ValueError('The number of tags ({}) should match the number of '
201 | 'source tokens ({})'.format(
202 | len(tags), len(self.source_tokens)))
203 | outputs = [] # Realized sources that are joined into the output text.
204 | if (len(self.first_tokens) == 2 and
205 | tags[self.first_tokens[1] - 1].tag_type == TagType.SWAP):
206 | order = [1, 0]
207 | else:
208 | order = range(len(self.first_tokens))
209 | for source_idx in order:
210 | # Get the span of tokens for the source: [first_token, last_token).
211 | first_token = self.first_tokens[source_idx]
212 | if source_idx + 1 < len(self.first_tokens):
213 | last_token = self.first_tokens[source_idx + 1] # Not inclusive.
214 | else:
215 | last_token = len(self.source_tokens)
216 | # Realize the source and fix casing.
217 | realized_source = self._realize_sequence(
218 | self.source_tokens[first_token:last_token],
219 | tags[first_token:last_token])
220 | if outputs:
221 | if len(outputs[0][-1:]) > 0 and outputs[0][-1:] in '.!?':
222 |                     realized_source = self._first_char_to_upper(realized_source) # Uppercase the first character.
223 | else:
224 | # Note that ideally we should also test here whether the first word is
225 | # a proper noun or an abbreviation that should always be capitalized.
226 |                     realized_source = self._first_char_to_lower(realized_source) # Lowercase the first character.
227 | # print(curLine(), len(outputs[0][-1:]), "outputs[0][-1:]:", outputs[0][-1:], "source_idx=",source_idx, ",realized_source:", realized_source)
228 | outputs.append(realized_source)
229 | prediction = self.sep.join(outputs)
230 | return prediction
231 |
--------------------------------------------------------------------------------
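
To make the realization step concrete, here is a minimal sketch (assuming the repo root is on the path and its dependencies, e.g. the bert package, are importable) that builds an EditingTask with a toy character tokenizer and realizes a hand-written tag sequence. Note that realization itself has no vocabulary constraint; any added_phrase can be realized.

    import tagging

    class CharTokenizer(object):
        """Toy stand-in for utils.my_tokenizer_class: one token per character."""
        def tokenize(self, text):
            return list(text)

    task = tagging.EditingTask(["今天天气怎么样"], tokenizer=CharTokenizer())
    tags = [tagging.Tag(t) for t in
            ["KEEP", "KEEP", "KEEP", "KEEP", "DELETE|如何", "DELETE", "DELETE"]]
    # "如何" is inserted before the tagged token, which is itself deleted.
    print(task.realize_output(tags))  # -> 今天天气如何
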
/tagging_converter.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Conversion from training target text into target tags.
17 |
18 | The conversion algorithm from (source, target) pairs to (source, target_tags)
19 | pairs is described in Algorithm 1 of the LaserTagger paper
20 | (https://arxiv.org/abs/1909.01187).
21 | """
22 |
23 | from __future__ import absolute_import
24 | from __future__ import division
25 |
26 | from __future__ import print_function
27 |
28 | import tagging
29 | import utils
30 |
31 | from typing import Iterable, Mapping, Sequence, Set, Text, Tuple
32 |
33 |
34 | class TaggingConverter(object):
35 | """Converter from training target texts into tagging format."""
36 |
37 | def __init__(self, phrase_vocabulary, do_swap=True, tokenizer=None):
38 | """Initializes an instance of TaggingConverter.
39 |
40 | Args:
41 | phrase_vocabulary: Iterable of phrase vocabulary items (strings).
42 | do_swap: Whether to enable the SWAP tag.
43 | """
44 | self._phrase_vocabulary = set(
45 | phrase.lower() for phrase in phrase_vocabulary)
46 | self._do_swap = do_swap
47 | # Maximum number of tokens in an added phrase (inferred from the
48 | # vocabulary).
49 | self._max_added_phrase_length = 0
50 | # Set of tokens that are part of a phrase in self.phrase_vocabulary.
51 |     self._token_vocabulary = set() # The set of word pieces appearing in the phrase vocabulary.
52 | for phrase in self._phrase_vocabulary:
53 | tokens = tokenizer.tokenize(phrase)
54 |       self._token_vocabulary |= set(tokens) # Set union.
55 | if len(tokens) > self._max_added_phrase_length:
56 | self._max_added_phrase_length = len(tokens)
57 |
58 | def compute_tags(self, task, target, tokenizer):
59 | """Computes tags needed for converting the source into the target.
60 |
61 | Args:
62 | task: tagging.EditingTask that specifies the input.
63 |       target: Target text. tokenizer: Tokenizer used to split the target text.
64 |
65 | Returns:
66 | List of tagging.Tag objects. If the source couldn't be converted into the
67 | target via tagging, returns an empty list.
68 | """
69 | target_tokens = tokenizer.tokenize(target.lower())
70 | tags = self._compute_tags_fixed_order(task.source_tokens, target_tokens)
71 | # If conversion fails, try to obtain the target after swapping the source
72 | # order.
73 | if not tags and len(task.sources) == 2 and self._do_swap:
74 | swapped_task = tagging.EditingTask(task.sources[::-1])
75 | tags = self._compute_tags_fixed_order(swapped_task.source_tokens,
76 | target_tokens)
77 | if tags:
78 | tags = (tags[swapped_task.first_tokens[1]:] +
79 | tags[:swapped_task.first_tokens[1]])
80 | # We assume that the last token (typically a period) is never deleted,
81 | # so we can overwrite the tag_type with SWAP (which keeps the token,
82 | # moving it and the sentence it's part of to the end).
83 | tags[task.first_tokens[1] - 1].tag_type = tagging.TagType.SWAP
84 | return tags
85 |
86 | def _compute_tags_fixed_order(self, source_tokens, target_tokens):
87 | """Computes tags when the order of sources is fixed.
88 |
89 | Args:
90 | source_tokens: List of source tokens.
91 | target_tokens: List of tokens to be obtained via edit operations.
92 |
93 | Returns:
94 | List of tagging.Tag objects. If the source couldn't be converted into the
95 | target via tagging, returns an empty list.
96 | """
97 | tags = [tagging.Tag('DELETE') for _ in source_tokens]
98 | # Indices of the tokens currently being processed.
99 | source_token_idx = 0
100 | target_token_idx = 0
101 | while target_token_idx < len(target_tokens):
102 | tags[source_token_idx], target_token_idx = self._compute_single_tag(
103 | source_tokens[source_token_idx], target_token_idx, target_tokens)
104 |       # TODO: several taggings could convert the source into the target; we commit to a single one.
105 | # If we're adding a phrase and the previous source token(s) were deleted,
106 | # we could add the phrase before a previously deleted token and still get
107 | # the same realized output. For example:
108 | # [DELETE, DELETE, KEEP|"what is"]
109 | # and
110 | # [DELETE|"what is", DELETE, KEEP]
111 | # Would yield the same realized output. Experimentally, we noticed that
112 | # the model works better / the learning task becomes easier when phrases
113 | # are always added before the first deleted token. Also note that in the
114 | # current implementation, this way of moving the added phrase backward is
115 | # the only way a DELETE tag can have an added phrase, so sequences like
116 | # [DELETE|"What", DELETE|"is"] will never be created.
117 | if tags[source_token_idx].added_phrase:
118 |         # Move the added phrase backward to the start of the deleted span (see comment above).
119 | first_deletion_idx = self._find_first_deletion_idx(
120 | source_token_idx, tags)
121 | if first_deletion_idx != source_token_idx:
122 | tags[first_deletion_idx].added_phrase = (
123 | tags[source_token_idx].added_phrase)
124 | tags[source_token_idx].added_phrase = ''
125 | source_token_idx += 1
126 | if source_token_idx >= len(tags):
127 | break
128 |
129 | # If all target tokens have been consumed, we have found a conversion and
130 | # can return the tags. Note that if there are remaining source tokens, they
131 | # are already marked deleted when initializing the tag list.
132 | if target_token_idx >= len(target_tokens): # all target tokens have been consumed
133 | return tags
134 |     return [] # TODO: conversion not possible.
135 |
136 | def _compute_single_tag(
137 | self, source_token, target_token_idx,
138 | target_tokens):
139 | """Computes a single tag.
140 |
141 | The tag may match multiple target tokens (via tag.added_phrase) so we return
142 | the next unmatched target token.
143 |
144 | Args:
145 | source_token: The token to be tagged.
146 | target_token_idx: Index of the current target tag.
147 | target_tokens: List of all target tokens.
148 |
149 | Returns:
150 | A tuple with (1) the computed tag and (2) the next target_token_idx.
151 | """
152 | source_token = source_token.lower()
153 | target_token = target_tokens[target_token_idx].lower()
154 | if source_token == target_token:
155 | return tagging.Tag('KEEP'), target_token_idx + 1
156 |     # Case: source_token != target_token.
157 | added_phrase = ''
158 | for num_added_tokens in range(1, self._max_added_phrase_length + 1):
159 | if target_token not in self._token_vocabulary:
160 | break
161 | added_phrase += (' ' if added_phrase else '') + target_token
162 | next_target_token_idx = target_token_idx + num_added_tokens
163 |       if next_target_token_idx >= len(target_tokens): # All target tokens consumed.
164 | break
165 | target_token = target_tokens[next_target_token_idx].lower()
166 | if (source_token == target_token and
167 | added_phrase in self._phrase_vocabulary):
168 | return tagging.Tag('KEEP|' + added_phrase), next_target_token_idx + 1
169 | return tagging.Tag('DELETE'), target_token_idx
170 |
171 | def _find_first_deletion_idx(self, source_token_idx, tags):
172 | """Finds the start index of a span of deleted tokens.
173 |
174 | If `source_token_idx` is preceded by a span of deleted tokens, finds the
175 | start index of the span. Otherwise, returns `source_token_idx`.
176 |
177 | Args:
178 | source_token_idx: Index of the current source token.
179 | tags: List of tags.
180 |
181 | Returns:
182 | The index of the first deleted token preceding `source_token_idx` or
183 | `source_token_idx` if there are no deleted tokens right before it.
184 | """
185 | # Backtrack until the beginning of the tag sequence.
186 | for idx in range(source_token_idx, 0, -1):
187 | if tags[idx - 1].tag_type != tagging.TagType.DELETE:
188 | return idx
189 | return 0
190 |
191 |
192 | def get_phrase_vocabulary_from_label_map(
193 | label_map):
194 | """Extract the set of all phrases from label map.
195 |
196 | Args:
197 | label_map: Mapping from tags to tag IDs.
198 |
199 | Returns:
200 | Set of all phrases appearing in the label map.
201 | """
202 | phrase_vocabulary = set()
203 | for label in label_map.keys():
204 | tag = tagging.Tag(label)
205 | if tag.added_phrase:
206 | phrase_vocabulary.add(tag.added_phrase)
207 | return phrase_vocabulary
208 |
--------------------------------------------------------------------------------
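
A round-trip sketch of the conversion (hypothetical data, with the toy character tokenizer from the note after tagging.py): compute_tags succeeds only when the target is reachable via KEEP/DELETE plus vocabulary phrases inserted before a kept token, and returns [] otherwise.

    import tagging
    import tagging_converter

    class CharTokenizer(object):
        def tokenize(self, text):
            return list(text)

    tokenizer = CharTokenizer()
    converter = tagging_converter.TaggingConverter(
        ["的"], do_swap=False, tokenizer=tokenizer)
    task = tagging.EditingTask(["今天天气怎么样"], tokenizer=tokenizer)
    tags = converter.compute_tags(task, "今天的天气怎么样", tokenizer)
    print([str(t) for t in tags])     # ['KEEP', 'KEEP', 'KEEP|的', 'KEEP', ...]
    print(task.realize_output(tags))  # -> 今天的天气怎么样
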
/transformer_decoder.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Lint as: python3
17 | """Transformer decoder."""
18 | from __future__ import absolute_import
19 | from __future__ import division
20 |
21 | from __future__ import print_function
22 |
23 | from typing import Any, Mapping, Text
24 |
25 | import tensorflow as tf
26 |
27 | from official_transformer import attention_layer
28 | from official_transformer import embedding_layer
29 | from official_transformer import ffn_layer
30 | from official_transformer import model_utils
31 | from official_transformer import transformer
32 |
33 |
34 | class TransformerDecoder(transformer.Transformer):
35 | """Transformer decoder.
36 |
37 | Attributes:
38 | train: Whether the model is in training mode.
39 | params: Model hyperparameters.
40 | """
41 |
42 | def __init__(self, params, train):
43 | """Initializes layers to build Transformer model.
44 |
45 | Args:
46 | params: hyperparameter object defining layer sizes, dropout values, etc.
47 | train: boolean indicating whether the model is in training mode. Used to
48 | determine if dropout layers should be added.
49 | """
50 | self.train = train
51 | self.params = params
52 | self.embedding_softmax_layer = embedding_layer.EmbeddingSharedWeights(
53 | params["vocab_size"], params["hidden_size"],
54 | method="matmul" if params["use_tpu"] else "gather")
55 | # override self.decoder_stack
56 | if self.params["use_full_attention"]:
57 | self.decoder_stack = transformer.DecoderStack(params, train)
58 | else:
59 | self.decoder_stack = DecoderStack(params, train)
60 |
61 | def __call__(self, inputs, encoder_outputs, targets=None):
62 | """Calculates target logits or inferred target sequences.
63 |
64 | Args:
65 | inputs: int tensor with shape [batch_size, input_length].
66 | encoder_outputs: int tensor with shape
67 | [batch_size, input_length, hidden_size]
68 | targets: None or int tensor with shape [batch_size, target_length].
69 |
70 | Returns:
71 | If targets is defined, then return logits for each word in the target
72 | sequence. float tensor with shape [batch_size, target_length, vocab_size]
73 |       If targets is None, then generate the output sequence one token at a time.
74 | returns a dictionary {
75 | output: [batch_size, decoded length]
76 | score: [batch_size, float]}
77 | """
78 | # Variance scaling is used here because it seems to work in many problems.
79 | # Other reasonable initializers may also work just as well.
80 | initializer = tf.variance_scaling_initializer(
81 | self.params["initializer_gain"], mode="fan_avg", distribution="uniform")
82 | with tf.variable_scope("Transformer", initializer=initializer):
83 | # Calculate attention bias for encoder self-attention and decoder
84 | # multi-headed attention layers.
85 | attention_bias = model_utils.get_padding_bias(inputs)
86 |
87 | # Generate output sequence if targets is None, or return logits if target
88 | # sequence is known.
89 | if targets is None:
90 | return self.predict(encoder_outputs, attention_bias)
91 | else:
92 | logits = self.decode(targets, encoder_outputs, attention_bias)
93 | return logits
94 |
95 | def _get_symbols_to_logits_fn(self, max_decode_length):
96 | """Returns a decoding function that calculates logits of the next tokens."""
97 |
98 | timing_signal = model_utils.get_position_encoding(
99 | max_decode_length + 1, self.params["hidden_size"])
100 | decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
101 | max_decode_length)
102 |
103 | def symbols_to_logits_fn(ids, i, cache):
104 | """Generate logits for next potential IDs.
105 |
106 | Args:
107 | ids: Current decoded sequences.
108 | int tensor with shape [batch_size * beam_size, i + 1]
109 | i: Loop index
110 | cache: dictionary of values storing the encoder output, encoder-decoder
111 | attention bias, and previous decoder attention values.
112 |
113 | Returns:
114 | Tuple of
115 | (logits with shape [batch_size * beam_size, vocab_size],
116 | updated cache values)
117 | """
118 | # Set decoder input to the last generated IDs
119 | decoder_input = ids[:, -1:]
120 |
121 | # Preprocess decoder input by getting embeddings and adding timing signal.
122 | decoder_input = self.embedding_softmax_layer(decoder_input)
123 | decoder_input += timing_signal[i:i + 1]
124 |
125 | self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
126 | if self.params["use_full_attention"]:
127 | encoder_outputs = cache.get("encoder_outputs")
128 | else:
129 | encoder_outputs = cache.get("encoder_outputs")[:, i:i + 1]
130 | decoder_outputs = self.decoder_stack(
131 | decoder_input, encoder_outputs, self_attention_bias,
132 | cache.get("encoder_decoder_attention_bias"), cache)
133 | logits = self.embedding_softmax_layer.linear(decoder_outputs)
134 | logits = tf.squeeze(logits, axis=[1])
135 | return logits, cache
136 |
137 | return symbols_to_logits_fn
138 |
139 |
140 | class DecoderStack(tf.layers.Layer):
141 | """Modified Transformer decoder stack.
142 |
143 | Like the standard Transformer decoder stack but:
144 | 1. Removes the encoder-decoder attention layer, and
145 | 2. Adds a layer to project the concatenated [encoder activations, hidden
146 | state] to the hidden size.
147 | """
148 |
149 | def __init__(self, params, train):
150 | super(DecoderStack, self).__init__()
151 | self.layers = []
152 | for _ in range(params["num_hidden_layers"]):
153 | self_attention_layer = attention_layer.SelfAttention(
154 | params["hidden_size"], params["num_heads"],
155 | params["attention_dropout"], train)
156 | feed_forward_network = ffn_layer.FeedFowardNetwork( # NOTYPO
157 | params["hidden_size"], params["filter_size"],
158 | params["relu_dropout"], train, params["allow_ffn_pad"])
159 |
160 | proj_layer = tf.layers.Dense(
161 |               # TODO: extra dense layer that projects the concatenated [encoder activations, hidden state] back to the hidden size
162 | params["hidden_size"], use_bias=True, name="proj_layer")
163 |
164 | self.layers.append([
165 | transformer.PrePostProcessingWrapper(
166 | self_attention_layer, params, train),
167 | transformer.PrePostProcessingWrapper(
168 | feed_forward_network, params, train),
169 | proj_layer])
170 |
171 | self.output_normalization = transformer.LayerNormalization(
172 | params["hidden_size"])
173 |
174 | def call(self, decoder_inputs, encoder_outputs, decoder_self_attention_bias,
175 | attention_bias=None, cache=None):
176 | """Returns the output of the decoder layer stacks.
177 |
178 | Args:
179 | decoder_inputs: tensor with shape [batch_size, target_length, hidden_size]
180 | encoder_outputs: tensor with shape [batch_size, input_length, hidden_size]
181 | decoder_self_attention_bias: bias for decoder self-attention layer.
182 | [1, 1, target_len, target_length]
183 | attention_bias: bias for encoder-decoder attention layer.
184 | [batch_size, 1, 1, input_length]
185 | cache: (Used for fast decoding) A nested dictionary storing previous
186 | decoder self-attention values. The items are:
187 | {layer_n: {"k": tensor with shape [batch_size, i, key_channels],
188 | "v": tensor with shape [batch_size, i, value_channels]},
189 | ...}
190 |
191 | Returns:
192 | Output of decoder layer stack.
193 | float32 tensor with shape [batch_size, target_length, hidden_size]
194 | """
195 | for n, layer in enumerate(self.layers):
196 | self_attention_layer = layer[0]
197 | feed_forward_network = layer[1]
198 | proj_layer = layer[2]
199 |
200 | decoder_inputs = tf.concat([decoder_inputs, encoder_outputs], axis=-1)
201 | decoder_inputs = proj_layer(decoder_inputs)
202 |
203 | # Run inputs through the sublayers.
204 | layer_name = "layer_%d" % n
205 | layer_cache = cache[layer_name] if cache is not None else None
206 | with tf.variable_scope(layer_name):
207 | with tf.variable_scope("self_attention"):
208 | decoder_inputs = self_attention_layer(
209 | decoder_inputs, decoder_self_attention_bias, cache=layer_cache)
210 | with tf.variable_scope("ffn"):
211 | decoder_inputs = feed_forward_network(decoder_inputs)
212 |
213 | return self.output_normalization(decoder_inputs)
214 |
--------------------------------------------------------------------------------
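
The decoder is configured purely through the params mapping. The keys read by TransformerDecoder and DecoderStack above are collected below with illustrative values (a hypothetical sketch; the repo's actual settings live in configs/lasertagger_config.json):

    decoder_params = {
        "vocab_size": 21128,          # e.g. the Chinese BERT vocabulary size
        "hidden_size": 768,
        "use_tpu": False,             # False -> "gather" embedding lookup
        "use_full_attention": False,  # False -> the modified DecoderStack above
        "initializer_gain": 1.0,
        "num_hidden_layers": 1,
        "num_heads": 4,
        "attention_dropout": 0.1,
        "filter_size": 3072,
        "relu_dropout": 0.1,
        "allow_ffn_pad": True,
    }
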
/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Lint as: python3
17 | """Utility functions for LaserTagger."""
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 |
22 | from __future__ import print_function
23 |
24 | import json
25 | from bert import tokenization
26 | import tensorflow as tf
27 |
28 | ### For Chinese: use word pieces but preserve spaces in the text.
29 | class my_tokenizer_class(object):
30 | def __init__(self, vocab_file, do_lower_case):
31 | self.full_tokenizer = tokenization.FullTokenizer(vocab_file, do_lower_case=do_lower_case)
32 |
33 |     # Wrap the tokenizer: calling full_tokenizer.tokenize directly on Chinese text would drop the spaces.
34 | def tokenize(self, text):
35 | segments = text.split(" ")
36 | word_pieces = []
37 | for segId, segment in enumerate(segments):
38 | if segId > 0:
39 | word_pieces.append(" ")
40 | word_pieces.extend(self.full_tokenizer.tokenize(segment))
41 | return word_pieces
42 |
43 | def convert_tokens_to_ids(self, tokens):
44 | id_list = [self.full_tokenizer.vocab[t]
45 | if t != " " else self.full_tokenizer.vocab["[unused20]"] for t in tokens]
46 | return id_list
47 |
48 |
49 | def yield_sources_and_targets(
50 | input_file,
51 | input_format):
52 | """Reads and yields source lists and targets from the input file.
53 |
54 | Args:
55 | input_file: Path to the input file.
56 | input_format: Format of the input file.
57 |
58 | Yields:
59 | Tuple with (list of source texts, target text).
60 | """
61 | if input_format == 'wikisplit':
62 | yield_example_fn = _yield_wikisplit_examples
63 | elif input_format == 'discofuse':
64 | yield_example_fn = _yield_discofuse_examples
65 | else:
66 | raise ValueError('Unsupported input_format: {}'.format(input_format))
67 |
68 | for sources, target in yield_example_fn(input_file):
69 | yield sources, target
70 |
71 |
72 | def _yield_wikisplit_examples(
73 | input_file):
74 |   # This wikisplit-style TSV has the source in the first column and the target
75 |   # in the second; the third column (lcs_rate) is read but ignored.
76 | with tf.gfile.GFile(input_file) as f:
77 | for line in f:
78 | source, target, lcs_rate = line.rstrip('\n').split('\t')
79 | yield [source], target
80 |
81 |
82 | def _yield_discofuse_examples(
83 | input_file):
84 | """Yields DiscoFuse examples.
85 |
86 | The documentation for this format:
87 | https://github.com/google-research-datasets/discofuse#data-format
88 |
89 | Args:
90 | input_file: Path to the input file.
91 | """
92 | with tf.gfile.GFile(input_file) as f:
93 | for i, line in enumerate(f):
94 | if i == 0: # Skip the header line.
95 | continue
96 | coherent_1, coherent_2, incoherent_1, incoherent_2, _, _, _, _ = (
97 | line.rstrip('\n').split('\t'))
98 | # Strip because the second coherent sentence might be empty.
99 | fusion = (coherent_1 + ' ' + coherent_2).strip()
100 | yield [incoherent_1, incoherent_2], fusion
101 |
102 |
103 | def read_label_map(path):
104 | """Returns label map read from the given path."""
105 | with tf.gfile.GFile(path) as f:
106 | if path.endswith('.json'):
107 | return json.load(f)
108 | else:
109 | label_map = {}
110 | empty_line_encountered = False
111 | for tag in f:
112 | tag = tag.strip()
113 | if tag:
114 | label_map[tag] = len(label_map)
115 | else:
116 | if empty_line_encountered:
117 | raise ValueError(
118 | 'There should be no empty lines in the middle of the label map '
119 | 'file.'
120 | )
121 | empty_line_encountered = True
122 | return label_map
123 |
--------------------------------------------------------------------------------
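
A final sketch of the space-preserving tokenizer (the vocab path is a placeholder; requires the bert package and a BERT vocabulary file): a plain FullTokenizer would silently drop the space in Chinese text, whereas my_tokenizer_class keeps it as a " " token and maps it to the [unused20] vocabulary id.

    import utils

    tokenizer = utils.my_tokenizer_class("vocab.txt", do_lower_case=True)
    tokens = tokenizer.tokenize("今天天气 怎么样")
    print(tokens)                                   # the " " token survives
    print(tokenizer.convert_tokens_to_ids(tokens))  # " " -> vocab["[unused20]"]
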