├── .gitignore
├── LICENSE
├── README.md
├── convert.py
├── doccano.md
├── doccano.py
├── ernie.py
├── ernie_m.py
├── evaluate.py
├── export_model.py
├── finetune.py
├── labelstudio2doccano.py
├── model.py
├── requirements.txt
├── tokenizer.py
├── uie_predictor.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | build*
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 | *.doctree
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | target/
75 |
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 |
79 | # IPython
80 | profile_default/
81 | ipython_config.py
82 |
83 | # pyenv
84 | .python-version
85 |
86 | # celery beat schedule file
87 | celerybeat-schedule
88 |
89 | # SageMath parsed files
90 | *.sage.py
91 |
92 | # Environments
93 | .env
94 | .venv
95 | env/
96 | venv/
97 | ENV/
98 | env.bak/
99 | venv.bak/
100 |
101 | # Spyder project settings
102 | .spyderproject
103 | .spyproject
104 |
105 | # Rope project settings
106 | .ropeproject
107 |
108 | # mkdocs documentation
109 | /site
110 |
111 | # mypy
112 | .mypy_cache/
113 | .dmypy.json
114 | dmypy.json
115 |
116 | # Pyre type checker
117 | .pyre/
118 |
119 | # pycharm
120 | .DS_Store
121 | .idea/
122 | FETCH_HEAD
123 |
124 | # vscode
125 | .vscode
126 |
127 | checkpoint
128 | data/*
129 | export
130 | model_best
131 | uie-base
132 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
10 |
11 | "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
12 |
13 | "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
14 |
15 | "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
16 |
17 | "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
18 |
19 | "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
20 |
21 | "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
22 |
23 | "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
24 |
25 | "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
26 |
27 | "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
28 |
29 | 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
30 |
31 | 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
32 |
33 | 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
34 |
35 | You must give any other recipients of the Work or Derivative Works a copy of this License; and
36 | You must cause any modified files to carry prominent notices stating that You changed the files; and
37 | You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
38 | If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
39 |
40 | You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
41 |
42 | 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
43 |
44 | 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
45 |
46 | 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
47 |
48 | 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
49 |
50 | END OF TERMS AND CONDITIONS
51 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 通用信息抽取 UIE(Universal Information Extraction) PyTorch版
2 |
3 | **迁移[PaddleNLP](https://github.com/PaddlePaddle/PaddleNLP)中的UIE模型到PyTorch上**
4 |
5 | * 2022-10-3: 新增对UIE-M系列模型的支持,增加了ErnieM的Tokenizer。ErnieMTokenizer使用C++实现的高性能分词算子FasterTokenizer进行文本预处理加速。需要通过`pip install faster_tokenizer`安装FasterTokenizer库后方可使用。
6 |
7 | PyTorch版功能介绍
8 | - `convert.py`: 自动下载并转换模型,详见[开箱即用](#开箱即用)。
9 | - `doccano.py`: 转换标注数据,详见[数据标注](#数据标注)。
10 | - `evaluate.py`: 评估模型,详见[模型评估](#模型评估)。
11 | - `export_model.py`: 导出ONNX推理模型,详见[模型部署](#模型部署)。
12 | - `finetune.py`: 微调训练,详见[模型微调](#模型微调)。
13 | - `model.py`: 模型定义。
14 | - `uie_predictor.py`: 推理类。
15 |
16 |
17 | **目录**
18 |
19 | - [1. 模型简介](#模型简介)
20 | - [2. 应用示例](#应用示例)
21 | - [3. 开箱即用](#开箱即用)
22 | - [3.1 实体抽取](#实体抽取)
23 | - [3.2 关系抽取](#关系抽取)
24 | - [3.3 事件抽取](#事件抽取)
25 | - [3.4 评论观点抽取](#评论观点抽取)
26 | - [3.5 情感分类](#情感分类)
27 | - [3.6 跨任务抽取](#跨任务抽取)
28 | - [3.7 模型选择](#模型选择)
29 | - [3.8 更多配置](#更多配置)
30 | - [4. 训练定制](#训练定制)
31 | - [4.1 代码结构](#代码结构)
32 | - [4.2 数据标注](#数据标注)
33 | - [4.3 模型微调](#模型微调)
34 | - [4.4 模型评估](#模型评估)
35 | - [4.5 定制模型一键预测](#定制模型一键预测)
36 | - [4.6 实验指标](#实验指标)
37 | - [4.7 模型部署](#模型部署)
38 |
39 |
40 |
41 | ## 1. 模型简介
42 |
43 | [UIE(Universal Information Extraction)](https://arxiv.org/pdf/2203.12277.pdf):Yaojie Lu等人在ACL-2022中提出了通用信息抽取统一框架UIE。该框架实现了实体抽取、关系抽取、事件抽取、情感分析等任务的统一建模,并使得不同任务间具备良好的迁移和泛化能力。为了方便大家使用UIE的强大能力,PaddleNLP借鉴该论文的方法,基于ERNIE 3.0知识增强预训练模型,训练并开源了首个中文通用信息抽取模型UIE。该模型可以支持不限定行业领域和抽取目标的关键信息抽取,实现零样本快速冷启动,并具备优秀的小样本微调能力,快速适配特定的抽取目标。
44 |
45 |
46 |

47 |
48 |
49 | #### UIE的优势
50 |
51 | - **使用简单**:用户可以使用自然语言自定义抽取目标,无需训练即可统一抽取输入文本中的对应信息。**实现开箱即用,并满足各类信息抽取需求**。
52 |
53 | - **降本增效**:以往的信息抽取技术需要大量标注数据才能保证信息抽取的效果,为了提高开发过程中的开发效率,减少不必要的重复工作时间,开放域信息抽取可以实现零样本(zero-shot)或者少样本(few-shot)抽取,**大幅度降低标注数据依赖,在降低成本的同时,还提升了效果**。
54 |
55 | - **效果领先**:开放域信息抽取在多种场景,多种任务上,均有不俗的表现。
56 |
57 |
58 |
59 | ## 2. 应用示例
60 |
61 | UIE不限定行业领域和抽取目标,以下是一些零样本行业示例:
62 |
63 | - 医疗场景-专病结构化
64 |
65 | 
66 |
67 | - 法律场景-判决书抽取
68 |
69 | 
70 |
71 | - 金融场景-收入证明、招股书抽取
72 |
73 | 
74 |
75 | - 公安场景-事故报告抽取
76 |
77 | 
78 |
79 | - 旅游场景-宣传册、手册抽取
80 |
81 | 
82 |
83 |
84 |
85 | ## 3. 开箱即用
86 |
87 | ```uie_predictor```提供通用信息抽取、评价观点抽取等能力,可抽取多种类型的信息,包括但不限于命名实体识别(如人名、地名、机构名等)、关系(如电影的导演、歌曲的发行时间等)、事件(如某路口发生车祸、某地发生地震等)、以及评价维度、观点词、情感倾向等信息。用户可以使用自然语言自定义抽取目标,无需训练即可统一抽取输入文本中的对应信息。**实现开箱即用,并满足各类信息抽取需求**
88 |
89 | ```uie_predictor```现在可以自动下载模型了,**无需手动convert**,如果想手动转换模型,可以参照以下方法。
90 |
91 | **下载并转换模型**,将下载Paddle版的`uie-base`模型到当前目录中,并生成PyTorch版模型`uie_base_pytorch`。
92 |
93 | ```shell
94 | python convert.py
95 | ```
96 |
97 | 如果没有安装paddlenlp,则使用以下命令。这样将不会导入paddlenlp,也不会验证转换结果的正确性。
98 |
99 | ```shell
100 | python convert.py --no_validate_output
101 | ```
102 |
103 | 可配置参数说明:
104 |
105 | - `input_model`: 输入的模型所在的文件夹,例如存在模型`./model_path/model_state.pdparams`,则传入`./model_path`。如果传入`uie-base`或`uie-tiny`等在模型列表中的模型,且当前目录不存在此文件夹时,将自动下载模型。默认值为`uie-base`。
106 |
107 | 支持自动下载的模型
108 | - `uie-base`
109 | - `uie-medium`
110 | - `uie-mini`
111 | - `uie-micro`
112 | - `uie-nano`
113 | - `uie-medical-base`
114 | - `uie-tiny` (弃用,已改为`uie-medium`)
115 | - `uie-base-en`
116 | - `uie-m-base`
117 | - `uie-m-large`
118 | - `ernie-3.0-base-zh`*
119 |
120 | - `output_model`: 输出的模型的文件夹,默认为`uie_base_pytorch`。
121 | - `no_validate_output`:是否关闭对输出模型的验证,默认打开。
122 |
123 | \* : 使用`ernie-3.0-base-zh`时不会验证模型,需要微调后才能用于预测
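
结合上面的参数说明,下面给出一个手动指定输入/输出模型的转换示例(示意用法:假设 `input_model`、`output_model` 与 `no_validate_output` 一样以 `--参数名` 形式传入,实际以脚本的命令行帮助为准):

```shell
# 下载 uie-m-base 并转换为 PyTorch 版,输出到 uie_m_base_pytorch 目录
python convert.py --input_model uie-m-base --output_model uie_m_base_pytorch
```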
124 |
125 |
126 |
127 |
128 | #### 3.1 实体抽取
129 |
130 | 命名实体识别(Named Entity Recognition,简称NER),是指识别文本中具有特定意义的实体。在开放域信息抽取中,抽取的类别没有限制,用户可以自己定义。
131 |
132 | - 例如抽取的目标实体类型是"时间"、"选手"和"赛事名称", schema构造如下:
133 |
134 | ```text
135 | ['时间', '选手', '赛事名称']
136 | ```
137 |
138 | 调用示例:
139 |
140 | ```python
141 | >>> from uie_predictor import UIEPredictor
142 | >>> from pprint import pprint
143 |
144 | >>> schema = ['时间', '选手', '赛事名称'] # Define the schema for entity extraction
145 | >>> ie = UIEPredictor(model='uie-base', schema=schema)
146 | >>> pprint(ie("2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌!")) # Better print results using pprint
147 | [{'时间': [{'end': 6,
148 | 'probability': 0.9857378532924486,
149 | 'start': 0,
150 | 'text': '2月8日上午'}],
151 | '赛事名称': [{'end': 23,
152 | 'probability': 0.8503089953268272,
153 | 'start': 6,
154 | 'text': '北京冬奥会自由式滑雪女子大跳台决赛'}],
155 | '选手': [{'end': 31,
156 | 'probability': 0.8981548639781138,
157 | 'start': 28,
158 | 'text': '谷爱凌'}]}]
159 | ```
160 |
161 | - 例如抽取的目标实体类型是"肿瘤的大小"、"肿瘤的个数"、"肝癌级别"和"脉管内癌栓分级", schema构造如下:
162 |
163 | ```text
164 | ['肿瘤的大小', '肿瘤的个数', '肝癌级别', '脉管内癌栓分级']
165 | ```
166 |
167 | 在上例中我们已经实例化了一个`UIEPredictor`对象,这里可以通过`set_schema`方法重置抽取目标。
168 |
169 | 调用示例:
170 |
171 | ```python
172 | >>> schema = ['肿瘤的大小', '肿瘤的个数', '肝癌级别', '脉管内癌栓分级']
173 | >>> ie.set_schema(schema)
174 | >>> pprint(ie("(右肝肿瘤)肝细胞性肝癌(II-III级,梁索型和假腺管型),肿瘤包膜不完整,紧邻肝被膜,侵及周围肝组织,未见脉管内癌栓(MVI分级:M0级)及卫星子灶形成。(肿物1个,大小4.2×4.0×2.8cm)。"))
175 | [{'肝癌级别': [{'end': 20,
176 | 'probability': 0.9243267447402701,
177 | 'start': 13,
178 | 'text': 'II-III级'}],
179 | '肿瘤的个数': [{'end': 84,
180 | 'probability': 0.7538413804059623,
181 | 'start': 82,
182 | 'text': '1个'}],
183 | '肿瘤的大小': [{'end': 100,
184 | 'probability': 0.8341128043459491,
185 | 'start': 87,
186 | 'text': '4.2×4.0×2.8cm'}],
187 | '脉管内癌栓分级': [{'end': 70,
188 | 'probability': 0.9083292325934664,
189 | 'start': 67,
190 | 'text': 'M0级'}]}]
191 | ```
192 |
193 | - 例如抽取的目标实体类型是"Person"和"Organization",schema构造如下:
194 |
195 | ```text
196 | ['Person', 'Organization']
197 | ```
198 |
199 | 英文模型调用示例:
200 |
201 | ```python
202 | >>> from uie_predictor import UIEPredictor
203 | >>> from pprint import pprint
204 | >>> schema = ['Person', 'Organization']
205 | >>> ie_en = UIEPredictor(model='uie-base-en', schema=schema)
206 | >>> pprint(ie_en('In 1997, Steve was excited to become the CEO of Apple.'))
207 | [{'Organization': [{'end': 53,
208 | 'probability': 0.9985840259877357,
209 | 'start': 48,
210 | 'text': 'Apple'}],
211 | 'Person': [{'end': 14,
212 | 'probability': 0.999631971804547,
213 | 'start': 9,
214 | 'text': 'Steve'}]}]
215 | ```
216 |
217 |
218 |
219 | #### 3.2 关系抽取
220 |
221 | 关系抽取(Relation Extraction,简称RE),是指从文本中识别实体并抽取实体之间的语义关系,进而获取三元组信息,即<主体,谓语,客体>。
222 |
223 | - 例如以"竞赛名称"作为抽取主体,抽取关系类型为"主办方"、"承办方"和"已举办次数", schema构造如下:
224 |
225 | ```text
226 | {
227 | '竞赛名称': [
228 | '主办方',
229 | '承办方',
230 | '已举办次数'
231 | ]
232 | }
233 | ```
234 |
235 | 调用示例:
236 |
237 | ```python
238 | >>> schema = {'竞赛名称': ['主办方', '承办方', '已举办次数']} # Define the schema for relation extraction
239 | >>> ie.set_schema(schema) # Reset schema
240 | >>> pprint(ie('2022语言与智能技术竞赛由中国中文信息学会和中国计算机学会联合主办,百度公司、中国中文信息学会评测工作委员会和中国计算机学会自然语言处理专委会承办,已连续举办4届,成为全球最热门的中文NLP赛事之一。'))
241 | [{'竞赛名称': [{'end': 13,
242 | 'probability': 0.7825402622754041,
243 | 'relations': {'主办方': [{'end': 22,
244 | 'probability': 0.8421710521379353,
245 | 'start': 14,
246 | 'text': '中国中文信息学会'},
247 | {'end': 30,
248 | 'probability': 0.7580801847701935,
249 | 'start': 23,
250 | 'text': '中国计算机学会'}],
251 | '已举办次数': [{'end': 82,
252 | 'probability': 0.4671295049136148,
253 | 'start': 80,
254 | 'text': '4届'}],
255 | '承办方': [{'end': 39,
256 | 'probability': 0.8292706618236352,
257 | 'start': 35,
258 | 'text': '百度公司'},
259 | {'end': 72,
260 | 'probability': 0.6193477885474685,
261 | 'start': 56,
262 | 'text': '中国计算机学会自然语言处理专委会'},
263 | {'end': 55,
264 | 'probability': 0.7000497331473241,
265 | 'start': 40,
266 | 'text': '中国中文信息学会评测工作委员会'}]},
267 | 'start': 0,
268 | 'text': '2022语言与智能技术竞赛'}]}]
269 | ```
270 |
271 | - 例如以"Person"作为抽取主体,抽取关系类型为"Company"和"Position", schema构造如下:
272 |
273 | ```text
274 | {
275 | 'Person': [
276 | 'Company',
277 | 'Position'
278 | ]
279 | }
280 | ```
281 |
282 | 英文模型调用示例:
283 |
284 | ```python
285 | >>> schema = [{'Person': ['Company', 'Position']}]
286 | >>> ie_en.set_schema(schema)
287 | >>> pprint(ie_en('In 1997, Steve was excited to become the CEO of Apple.'))
288 | [{'Person': [{'end': 14,
289 | 'probability': 0.999631971804547,
290 | 'relations': {'Company': [{'end': 53,
291 | 'probability': 0.9960158209451642,
292 | 'start': 48,
293 | 'text': 'Apple'}],
294 | 'Position': [{'end': 44,
295 | 'probability': 0.8871063806420736,
296 | 'start': 41,
297 | 'text': 'CEO'}]},
298 | 'start': 9,
299 | 'text': 'Steve'}]}]
300 | ```
301 |
302 |
303 |
304 | #### 3.3 事件抽取
305 |
306 | 事件抽取 (Event Extraction, 简称EE),是指从自然语言文本中抽取预定义的事件触发词(Trigger)和事件论元(Argument),组合为相应的事件结构化信息。
307 |
308 | - 例如抽取的目标是"地震"事件的"地震强度"、"时间"、"震中位置"和"震源深度"这些信息,schema构造如下:
309 |
310 | ```text
311 | {
312 | '地震触发词': [
313 | '地震强度',
314 | '时间',
315 | '震中位置',
316 | '震源深度'
317 | ]
318 | }
319 | ```
320 |
321 | 触发词的格式统一为`触发词`或`XX触发词`,`XX`表示具体事件类型,上例中的事件类型是`地震`,则对应触发词为`地震触发词`。
322 |
323 | 调用示例:
324 |
325 | ```python
326 | >>> schema = {'地震触发词': ['地震强度', '时间', '震中位置', '震源深度']} # Define the schema for event extraction
327 | >>> ie.set_schema(schema) # Reset schema
328 | >>> ie('中国地震台网正式测定:5月16日06时08分在云南临沧市凤庆县(北纬24.34度,东经99.98度)发生3.5级地震,震源深度10千米。')
329 | [{'地震触发词': [{'text': '地震', 'start': 56, 'end': 58, 'probability': 0.9987181623528585, 'relations': {'地震强度': [{'text': '3.5级', 'start': 52, 'end': 56, 'probability': 0.9962985320905915}], '时间': [{'text': '5月16日06时08分', 'start': 11, 'end': 22, 'probability': 0.9882578028575182}], '震中位置': [{'text': '云南临沧市凤庆县(北纬24.34度,东经99.98度)', 'start': 23, 'end': 50, 'probability': 0.8551415716584501}], '震源深度': [{'text': '10千米', 'start': 63, 'end': 67, 'probability': 0.999158304648045}]}}]}]
330 | ```
331 |
332 | - 英文模型**暂不支持事件抽取**
333 |
334 |
335 |
336 | #### 3.4 评论观点抽取
337 |
338 | 评论观点抽取,是指抽取文本中包含的评价维度、观点词。
339 |
340 | - 例如抽取的目标是文本中包含的评价维度及其对应的观点词和情感倾向,schema构造如下:
341 |
342 | ```text
343 | {
344 | '评价维度': [
345 | '观点词',
346 | '情感倾向[正向,负向]'
347 | ]
348 | }
349 | ```
350 |
351 | 调用示例:
352 |
353 | ```python
354 | >>> schema = {'评价维度': ['观点词', '情感倾向[正向,负向]']} # Define the schema for opinion extraction
355 | >>> ie.set_schema(schema) # Reset schema
356 | >>> pprint(ie("店面干净,很清静,服务员服务热情,性价比很高,发现收银台有排队")) # Better print results using pprint
357 | [{'评价维度': [{'end': 20,
358 | 'probability': 0.9817040258681473,
359 | 'relations': {'情感倾向[正向,负向]': [{'probability': 0.9966142505350533,
360 | 'text': '正向'}],
361 | '观点词': [{'end': 22,
362 | 'probability': 0.957396472711558,
363 | 'start': 21,
364 | 'text': '高'}]},
365 | 'start': 17,
366 | 'text': '性价比'},
367 | {'end': 2,
368 | 'probability': 0.9696849569741168,
369 | 'relations': {'情感倾向[正向,负向]': [{'probability': 0.9982153274927796,
370 | 'text': '正向'}],
371 | '观点词': [{'end': 4,
372 | 'probability': 0.9945318044652538,
373 | 'start': 2,
374 | 'text': '干净'}]},
375 | 'start': 0,
376 | 'text': '店面'}]}]
377 | ```
378 |
379 | - 英文模型schema构造如下:
380 |
381 | ```text
382 | {
383 | 'Aspect': [
384 | 'Opinion',
385 | 'Sentiment classification [negative, positive]'
386 | ]
387 | }
388 | ```
389 |
390 | 调用示例:
391 |
392 | ```python
393 | >>> schema = [{'Aspect': ['Opinion', 'Sentiment classification [negative, positive]']}]
394 | >>> ie_en.set_schema(schema)
395 | >>> pprint(ie_en("The teacher is very nice."))
396 | [{'Aspect': [{'end': 11,
397 | 'probability': 0.4301476415932193,
398 | 'relations': {'Opinion': [{'end': 24,
399 | 'probability': 0.9072940447883724,
400 | 'start': 15,
401 | 'text': 'very nice'}],
402 | 'Sentiment classification [negative, positive]': [{'probability': 0.9998571920670685,
403 | 'text': 'positive'}]},
404 | 'start': 4,
405 | 'text': 'teacher'}]}]
406 | ```
407 |
408 |
409 |
410 | #### 3.5 情感分类
411 |
412 | - 句子级情感倾向分类,即判断句子的情感倾向是“正向”还是“负向”,schema构造如下:
413 |
414 | ```text
415 | '情感倾向[正向,负向]'
416 | ```
417 |
418 | 调用示例:
419 |
420 | ```python
421 | >>> schema = '情感倾向[正向,负向]' # Define the schema for sentence-level sentiment classification
422 | >>> ie.set_schema(schema) # Reset schema
423 | >>> ie('这个产品用起来真的很流畅,我非常喜欢')
424 | [{'情感倾向[正向,负向]': [{'text': '正向', 'probability': 0.9988661643929895}]}]
425 | ```
426 |
427 | 英文模型schema构造如下:
428 |
429 | ```text
430 | 'Sentiment classification [negative, positive]'
431 | ```
432 |
433 | 英文模型调用示例:
434 |
435 | ```python
436 | >>> schema = 'Sentiment classification [negative, positive]'
437 | >>> ie_en.set_schema(schema)
438 | >>> ie_en('I am sorry but this is the worst film I have ever seen in my life.')
439 | [{'Sentiment classification [negative, positive]': [{'text': 'negative', 'probability': 0.9998415771287057}]}]
440 | ```
441 |
442 |
443 |
444 | #### 3.6 跨任务抽取
445 |
446 | - 例如在法律场景同时对文本进行实体抽取和关系抽取,schema可按照如下方式进行构造:
447 |
448 | ```text
449 | [
450 | "法院",
451 | {
452 | "原告": "委托代理人"
453 | },
454 | {
455 | "被告": "委托代理人"
456 | }
457 | ]
458 | ```
459 |
460 | 调用示例:
461 |
462 | ```python
463 | >>> schema = ['法院', {'原告': '委托代理人'}, {'被告': '委托代理人'}]
464 | >>> ie.set_schema(schema)
465 | >>> pprint(ie("北京市海淀区人民法院\n民事判决书\n(199x)建初字第xxx号\n原告:张三。\n委托代理人李四,北京市 A律师事务所律师。\n被告:B公司,法定代表人王五,开发公司总经理。\n委托代理人赵六,北京市 C律师事务所律师。")) # Better print results using pprint
466 | [{'原告': [{'end': 37,
467 | 'probability': 0.9949814024296764,
468 | 'relations': {'委托代理人': [{'end': 46,
469 | 'probability': 0.7956844697990384,
470 | 'start': 44,
471 | 'text': '李四'}]},
472 | 'start': 35,
473 | 'text': '张三'}],
474 | '法院': [{'end': 10,
475 | 'probability': 0.9221074192336651,
476 | 'start': 0,
477 | 'text': '北京市海淀区人民法院'}],
478 | '被告': [{'end': 67,
479 | 'probability': 0.8437349536631089,
480 | 'relations': {'委托代理人': [{'end': 92,
481 | 'probability': 0.7267121388225029,
482 | 'start': 90,
483 | 'text': '赵六'}]},
484 | 'start': 64,
485 | 'text': 'B公司'}]}]
486 | ```
487 |
488 |
489 |
490 | #### 3.7 模型选择
491 |
492 | - 多模型选择,满足精度、速度要求
493 |
494 | | 模型 | 结构 | 语言 |
495 | | :---: | :--------: | :--------: |
496 | | `uie-base` (默认)| 12-layers, 768-hidden, 12-heads | 中文 |
497 | | `uie-base-en` | 12-layers, 768-hidden, 12-heads | 英文 |
498 | | `uie-medical-base` | 12-layers, 768-hidden, 12-heads | 中文 |
499 | | `uie-medium`| 6-layers, 768-hidden, 12-heads | 中文 |
500 | | `uie-mini`| 6-layers, 384-hidden, 12-heads | 中文 |
501 | | `uie-micro`| 4-layers, 384-hidden, 12-heads | 中文 |
502 | | `uie-nano`| 4-layers, 312-hidden, 12-heads | 中文 |
503 | | `uie-m-large`| 24-layers, 1024-hidden, 16-heads | 中、英文 |
504 | | `uie-m-base`| 12-layers, 768-hidden, 12-heads | 中、英文 |
505 |
506 |
507 | - `uie-nano`调用示例:
508 |
509 | ```python
510 | >>> from uie_predictor import UIEPredictor
511 |
512 | >>> schema = ['时间', '选手', '赛事名称']
513 | >>> ie = UIEPredictor('uie-nano', schema=schema)
514 | >>> ie("2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌!")
515 | [{'时间': [{'text': '2月8日上午', 'start': 0, 'end': 6, 'probability': 0.6513581678349247}], '选手': [{'text': '谷爱凌', 'start': 28, 'end': 31, 'probability': 0.9819330659468051}], '赛事名称': [{'text': '北京冬奥会自由式滑雪女子大跳台决赛', 'start': 6, 'end': 23, 'probability': 0.4908131110420939}]}]
516 | ```
517 |
518 | - `uie-m-base`和`uie-m-large`支持中英文混合抽取,调用示例:
519 |
520 | ```python
521 | >>> from pprint import pprint
522 | >>> from uie_predictor import UIEPredictor
523 |
524 | >>> schema = ['Time', 'Player', 'Competition', 'Score']
525 | >>> ie = UIEPredictor(schema=schema, model="uie-m-base", schema_lang="en")
526 | >>> pprint(ie(["2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌!", "Rafael Nadal wins French Open Final!"]))
527 | [{'Competition': [{'end': 23,
528 | 'probability': 0.9373889907291257,
529 | 'start': 6,
530 | 'text': '北京冬奥会自由式滑雪女子大跳台决赛'}],
531 | 'Player': [{'end': 31,
532 | 'probability': 0.6981119555336441,
533 | 'start': 28,
534 | 'text': '谷爱凌'}],
535 | 'Score': [{'end': 39,
536 | 'probability': 0.9888507878270296,
537 | 'start': 32,
538 | 'text': '188.25分'}],
539 | 'Time': [{'end': 6,
540 | 'probability': 0.9784080036931151,
541 | 'start': 0,
542 | 'text': '2月8日上午'}]},
543 | {'Competition': [{'end': 35,
544 | 'probability': 0.9851549932171295,
545 | 'start': 18,
546 | 'text': 'French Open Final'}],
547 | 'Player': [{'end': 12,
548 | 'probability': 0.9379371275888104,
549 | 'start': 0,
550 | 'text': 'Rafael Nadal'}]}]
551 | ```
552 |
553 |
554 |
555 | #### 3.8 更多配置
556 |
557 | ```python
558 | >>> from uie_predictor import UIEPredictor
559 |
560 | >>> ie = UIEPredictor('uie-nano',
561 | schema=schema)
562 | ```
563 |
564 | * `model`:选择任务使用的模型,默认为`uie-base`,可选有`uie-base`, `uie-medium`, `uie-mini`, `uie-micro`, `uie-nano`, `uie-medical-base`, `uie-base-en`, `uie-m-base`和`uie-m-large`。
565 | * `schema`:定义任务抽取目标,可参考开箱即用中不同任务的调用示例进行配置。
566 | * `schema_lang`:设置schema的语言,默认为`zh`, 可选有`zh`和`en`。因为中英schema的构造有所不同,因此需要指定schema的语言。该参数只对`uie-m-base`和`uie-m-large`模型有效。
567 | * `batch_size`:批处理大小,请结合机器情况进行调整,默认为1。
568 | * `task_path`:设定自定义的模型。
569 | * `position_prob`:模型对于span的起始位置/终止位置的结果概率在0~1之间,返回结果去掉小于这个阈值的结果,默认为0.5,span的最终概率输出为起始位置概率和终止位置概率的乘积。
570 | * `use_fp16`:是否使用`fp16`进行加速,默认关闭。`fp16`推理速度更快。如果选择`fp16`,请先确保机器正确安装NVIDIA相关驱动和基础软件,**确保CUDA>=11.2,cuDNN>=8.1.1**,初次使用需按照提示安装相关依赖。其次,需要确保GPU设备的CUDA计算能力(CUDA Compute Capability)大于7.0,典型的设备包括V100、T4、A10、A100、GTX 20系列和30系列显卡等。更多关于CUDA Compute Capability和精度支持情况请参考NVIDIA文档:[GPU硬件与支持精度对照表](https://docs.nvidia.com/deeplearning/tensorrt/archives/tensorrt-840-ea/support-matrix/index.html#hardware-precision-matrix)。
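
下面是一个组合上述参数的调用示意(`batch_size`、`position_prob`、`use_fp16` 按关键字参数传入仅为示意写法,实际以 `UIEPredictor` 的接口定义为准):

```python
>>> from uie_predictor import UIEPredictor

>>> schema = ['时间', '选手', '赛事名称']
>>> # position_prob=0.6: spans with start_prob * end_prob < 0.6 are filtered out
>>> ie = UIEPredictor(model='uie-base',
...                   schema=schema,
...                   batch_size=16,
...                   position_prob=0.6,
...                   use_fp16=False)
>>> ie("2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌!")
```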
571 |
572 |
573 |
574 | ## 4. 训练定制
575 |
576 | 对于简单的抽取目标可以直接使用```UIEPredictor```实现零样本(zero-shot)抽取,对于细分场景我们推荐使用轻定制功能(标注少量数据进行模型微调)以进一步提升效果。下面通过`报销工单信息抽取`的例子展示如何通过5条训练数据进行UIE模型微调。
577 |
578 |
579 | #### 4.1 代码结构
580 |
581 | ```shell
582 | .
583 | ├── utils.py # 数据处理工具
584 | ├── model.py # 模型组网脚本
585 | ├── doccano.py # 数据标注脚本
586 | ├── doccano.md # 数据标注文档
587 | ├── finetune.py # 模型微调脚本
588 | ├── evaluate.py # 模型评估脚本
589 | └── README.md
590 | ```
591 |
592 |
593 |
594 | #### 4.2 数据标注
595 |
596 | 我们推荐使用数据标注平台[doccano](https://github.com/doccano/doccano) 进行数据标注,本示例也打通了从标注到训练的通道,即doccano导出数据后可通过[doccano.py](./doccano.py)脚本轻松将数据转换为输入模型时需要的形式,实现无缝衔接。标注方法的详细介绍请参考[doccano数据标注指南](doccano.md)。
597 |
598 | 原始数据示例:
599 |
600 | ```text
601 | 深大到双龙28块钱4月24号交通费
602 | ```
603 |
604 | 抽取的目标(schema)为:
605 |
606 | ```python
607 | schema = ['出发地', '目的地', '费用', '时间']
608 | ```
609 |
610 | 标注步骤如下:
611 |
612 | - 在doccano平台上,创建一个类型为``序列标注``的标注项目。
613 | - 定义实体标签类别,上例中需要定义的实体标签有``出发地``、``目的地``、``费用``和``时间``。
614 | - 使用以上定义的标签开始标注数据,下面展示了一个doccano标注示例:
615 |
616 |
617 |

618 |
619 |
620 | - 标注完成后,在doccano平台上导出文件,并将其重命名为``doccano_ext.json``后,放入``./data``目录下。
621 |
622 | - 这里我们提供预先标注好的文件[doccano_ext.json](https://bj.bcebos.com/paddlenlp/datasets/uie/doccano_ext.json),可直接下载并放入`./data`目录。执行以下脚本进行数据转换,执行后会在`./data`目录下生成训练/验证/测试集文件。
623 |
624 | ```shell
625 | python doccano.py \
626 | --doccano_file ./data/doccano_ext.json \
627 | --task_type ext \
628 | --save_dir ./data \
629 | --splits 0.8 0.2 0
630 | ```
631 |
632 |
633 | 可配置参数说明:
634 |
635 | - ``doccano_file``: 从doccano导出的数据标注文件。
636 | - ``save_dir``: 训练数据的保存目录,默认存储在``data``目录下。
637 | - ``negative_ratio``: 最大负例比例,该参数只对抽取类型任务有效,适当构造负例可提升模型效果。负例数量和实际的标签数量有关,最大负例数量 = negative_ratio * 正例数量。该参数只对训练集有效,默认为5。为了保证评估指标的准确性,验证集和测试集默认构造全负例。
638 | - ``splits``: 划分数据集时训练集、验证集所占的比例。默认为[0.8, 0.1, 0.1]表示按照``8:1:1``的比例将数据划分为训练集、验证集和测试集。
639 | - ``task_type``: 选择任务类型,可选有抽取和分类两种类型的任务。
640 | - ``options``: 指定分类任务的类别标签,该参数只对分类类型任务有效。默认为["正向", "负向"]。
641 | - ``prompt_prefix``: 声明分类任务的prompt前缀信息,该参数只对分类类型任务有效。默认为"情感倾向"。
642 | - ``is_shuffle``: 是否对数据集进行随机打散,默认为True。
643 | - ``seed``: 随机种子,默认为1000。
644 | - ``separator``: 实体类别/评价维度与分类标签的分隔符,该参数只对实体/评价维度级分类任务有效。默认为"##"。
645 |
646 | 备注:
647 | - 默认情况下 [doccano.py](./doccano.py) 脚本会按照比例将数据划分为 train/dev/test 数据集
648 | - 每次执行 [doccano.py](./doccano.py) 脚本,将会覆盖已有的同名数据文件
649 | - 在模型训练阶段我们推荐构造一些负例以提升模型效果,在数据转换阶段我们内置了这一功能。可通过`negative_ratio`控制自动构造的负样本比例;负样本数量 = negative_ratio * 正样本数量。
650 | - 对于从doccano导出的文件,默认文件中的每条数据都是经过人工正确标注的。
651 |
652 | 更多**不同类型任务(关系抽取、事件抽取、评价观点抽取等)的标注规则及参数说明**,请参考[doccano数据标注指南](doccano.md)。
653 |
654 | 此外,也可以通过数据标注平台 [Label Studio](https://labelstud.io/) 进行数据标注。本示例提供了 [labelstudio2doccano.py](./labelstudio2doccano.py) 脚本,将 label studio 导出的 JSON 数据文件格式转换成 doccano 导出的数据文件格式,后续的数据转换与模型微调等操作不变。
655 |
656 | ```shell
657 | python labelstudio2doccano.py --labelstudio_file label-studio.json
658 | ```
659 |
660 | 可配置参数说明:
661 |
662 | - ``labelstudio_file``: label studio 的导出文件路径(仅支持 JSON 格式)。
663 | - ``doccano_file``: doccano 格式的数据文件保存路径,默认为 "doccano_ext.jsonl"。
664 | - ``task_type``: 任务类型,可选有抽取("ext")和分类("cls")两种类型的任务,默认为 "ext"。
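
例如,可以先将 Label Studio 导出文件转换为 doccano 格式,再按前文方式用 `doccano.py` 生成训练/验证/测试集(示例中的文件名仅为示意):

```shell
# Label Studio JSON -> doccano JSONL
python labelstudio2doccano.py \
    --labelstudio_file label-studio.json \
    --doccano_file ./data/doccano_ext.jsonl \
    --task_type ext

# doccano JSONL -> train/dev/test (same as the doccano.py step above)
python doccano.py \
    --doccano_file ./data/doccano_ext.jsonl \
    --task_type ext \
    --save_dir ./data \
    --splits 0.8 0.1 0.1
```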
665 |
666 |
667 |
668 | #### 4.3 模型微调
669 |
670 | 通过运行以下命令进行模型微调:
671 |
672 | ```shell
673 | python finetune.py \
674 | --train_path "./data/train.txt" \
675 | --dev_path "./data/dev.txt" \
676 | --save_dir "./checkpoint" \
677 | --learning_rate 1e-5 \
678 | --batch_size 16 \
679 | --max_seq_len 512 \
680 | --num_epochs 100 \
681 | --model "uie_base_pytorch" \
682 | --seed 1000 \
683 | --logging_steps 10 \
684 | --valid_steps 100 \
685 | --device "gpu"
686 | ```
687 |
688 | 可配置参数说明:
689 |
690 | - `train_path`: 训练集文件路径。
691 | - `dev_path`: 验证集文件路径。
692 | - `save_dir`: 模型存储路径,默认为`./checkpoint`。
693 | - `learning_rate`: 学习率,默认为1e-5。
694 | - `batch_size`: 批处理大小,请结合机器情况进行调整,默认为16。
695 | - `max_seq_len`: 文本最大切分长度,输入超过最大长度时会对输入文本进行自动切分,默认为512。
696 | - `num_epochs`: 训练轮数,默认为100。
697 | - `model`: 选择模型,程序会基于选择的模型进行模型微调,默认为`uie_base_pytorch`。
698 | - `seed`: 随机种子,默认为1000。
699 | - `logging_steps`: 日志打印的间隔steps数,默认10。
700 | - `valid_steps`: evaluate的间隔steps数,默认100。
701 | - `device`: 选用什么设备进行训练,可选cpu或gpu。
702 | - `max_model_num`: 保存的模型的个数,不包含`model_best`和`early_stopping`保存的模型,默认为5。
703 | - `early_stopping`: 是否采用提前停止(Early Stopping),默认不使用。
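
例如,在CPU上用较小的 batch size 训练、限制保存模型数量并启用提前停止,可参考如下命令(这里假设 `--early_stopping` 为开关型参数,无需取值):

```shell
python finetune.py \
    --train_path "./data/train.txt" \
    --dev_path "./data/dev.txt" \
    --save_dir "./checkpoint" \
    --batch_size 4 \
    --max_model_num 3 \
    --early_stopping \
    --device "cpu"
```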
704 |
705 |
706 |
707 | #### 4.4 模型评估
708 |
709 | 通过运行以下命令进行模型评估:
710 |
711 | ```shell
712 | python evaluate.py \
713 | --model_path ./checkpoint/model_best \
714 | --test_path ./data/dev.txt \
715 | --batch_size 16 \
716 | --max_seq_len 512
717 | ```
718 |
719 | 评估方式说明:采用单阶段评价的方式,即关系抽取、事件抽取等需要分阶段预测的任务对每一阶段的预测结果进行分别评价。验证/测试集默认会利用同一层级的所有标签来构造出全部负例。
720 |
721 | 可开启`debug`模式对每个正例类别分别进行评估,该模式仅用于模型调试:
722 |
723 | ```shell
724 | python evaluate.py \
725 | --model_path ./checkpoint/model_best \
726 | --test_path ./data/dev.txt \
727 | --debug
728 | ```
729 |
730 | 输出打印示例:
731 |
732 | ```text
733 | [2022-09-14 03:13:58,877] [ INFO] - -----------------------------
734 | [2022-09-14 03:13:58,877] [ INFO] - Class Name: 疾病
735 | [2022-09-14 03:13:58,877] [ INFO] - Evaluation Precision: 0.89744 | Recall: 0.83333 | F1: 0.86420
736 | [2022-09-14 03:13:59,145] [ INFO] - -----------------------------
737 | [2022-09-14 03:13:59,145] [ INFO] - Class Name: 手术治疗
738 | [2022-09-14 03:13:59,145] [ INFO] - Evaluation Precision: 0.90000 | Recall: 0.85714 | F1: 0.87805
739 | [2022-09-14 03:13:59,439] [ INFO] - -----------------------------
740 | [2022-09-14 03:13:59,440] [ INFO] - Class Name: 检查
741 | [2022-09-14 03:13:59,440] [ INFO] - Evaluation Precision: 0.77778 | Recall: 0.56757 | F1: 0.65625
742 | [2022-09-14 03:13:59,708] [ INFO] - -----------------------------
743 | [2022-09-14 03:13:59,709] [ INFO] - Class Name: X的手术治疗
744 | [2022-09-14 03:13:59,709] [ INFO] - Evaluation Precision: 0.90000 | Recall: 0.85714 | F1: 0.87805
745 | [2022-09-14 03:13:59,893] [ INFO] - -----------------------------
746 | [2022-09-14 03:13:59,893] [ INFO] - Class Name: X的实验室检查
747 | [2022-09-14 03:13:59,894] [ INFO] - Evaluation Precision: 0.71429 | Recall: 0.55556 | F1: 0.62500
748 | [2022-09-14 03:14:00,057] [ INFO] - -----------------------------
749 | [2022-09-14 03:14:00,058] [ INFO] - Class Name: X的影像学检查
750 | [2022-09-14 03:14:00,058] [ INFO] - Evaluation Precision: 0.69231 | Recall: 0.45000 | F1: 0.54545
751 | ```
752 |
753 | 可配置参数说明:
754 |
755 | - `model_path`: 进行评估的模型文件夹路径,路径下需包含模型权重文件`pytorch_model.bin`及配置文件`config.json`。
756 | - `test_path`: 进行评估的测试集文件。
757 | - `batch_size`: 批处理大小,请结合机器情况进行调整,默认为16。
758 | - `max_seq_len`: 文本最大切分长度,输入超过最大长度时会对输入文本进行自动切分,默认为512。
759 | - `device`: 选用进行训练的设备,可选`cpu`或`gpu`。
760 |
761 |
762 |
763 | #### 4.5 定制模型一键预测
764 |
765 | `UIEPredictor`装载定制模型,通过`task_path`指定模型权重文件的路径,路径下需要包含训练好的模型权重文件`pytorch_model.bin`。
766 |
767 | ```python
768 | >>> from pprint import pprint
769 | >>> from uie_predictor import UIEPredictor
770 |
771 | >>> schema = ['出发地', '目的地', '费用', '时间']
772 | # 设定抽取目标和定制化模型权重路径
773 | >>> my_ie = UIEPredictor(model='uie-base',task_path='./checkpoint/model_best', schema=schema)
774 | >>> pprint(my_ie("城市内交通费7月5日金额114广州至佛山"))
775 | [{'出发地': [{'end': 17,
776 | 'probability': 0.9975287467835301,
777 | 'start': 15,
778 | 'text': '广州'}],
779 | '时间': [{'end': 10,
780 | 'probability': 0.9999476678061399,
781 | 'start': 6,
782 | 'text': '7月5日'}],
783 | '目的地': [{'end': 20,
784 | 'probability': 0.9998511131226735,
785 | 'start': 18,
786 | 'text': '佛山'}],
787 | '费用': [{'end': 15,
788 | 'probability': 0.9994474579292856,
789 | 'start': 12,
790 | 'text': '114'}]}]
791 | ```
792 |
793 |
794 |
795 | #### 4.6 实验指标
796 |
797 | 我们在互联网、医疗、金融三大垂类自建测试集上进行了实验:
798 |
799 |
800 | | 模型 | 金融<br>0-shot | 金融<br>5-shot | 医疗<br>0-shot | 医疗<br>5-shot | 互联网<br>0-shot | 互联网<br>5-shot |
801 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
802 | | uie-base (12L768H) | 46.43 | 70.92 | 71.83 | 85.72 | 78.33 | 81.86 |
803 | | uie-medium (6L768H) | 41.11 | 64.53 | 65.40 | 75.72 | 78.32 | 79.68 |
804 | | uie-mini (6L384H) | 37.04 | 64.65 | 60.50 | 78.36 | 72.09 | 76.38 |
805 | | uie-micro (4L384H) | 37.53 | 62.11 | 57.04 | 75.92 | 66.00 | 70.22 |
806 | | uie-nano (4L312H) | 38.94 | 66.83 | 48.29 | 76.74 | 62.86 | 72.35 |
807 | | uie-m-large (24L1024H) | 49.35 | 74.55 | 70.50 | 92.66 | 78.49 | 83.02 |
808 | | uie-m-base (12L768H) | 38.46 | 74.31 | 63.37 | 87.32 | 76.27 | 80.13 |
809 |
810 |
811 | 0-shot表示无训练数据直接通过```UIEPredictor```进行预测,5-shot表示每个类别包含5条标注数据进行模型微调。**实验表明UIE在垂类场景可以通过少量数据(few-shot)进一步提升效果**。
812 |
813 |
814 |
815 | #### 4.7 模型部署
816 |
817 | 以下是UIE Python端的部署流程,包括环境准备、模型导出和使用示例。
818 |
819 | - 环境准备
820 | UIE的部署分为CPU和GPU两种情况,请根据你的部署环境安装对应的依赖。
821 |
822 | - CPU端
823 |
824 | CPU端的部署请使用如下命令安装所需依赖
825 |
826 | ```shell
827 | pip install onnx onnxruntime
828 | ```
829 | - GPU端
830 |
831 | 为了在GPU上获得最佳的推理性能和稳定性,请先确保机器已正确安装NVIDIA相关驱动和基础软件,确保**CUDA >= 11.2,cuDNN >= 8.1.1**,并使用以下命令安装所需依赖
832 |
833 | ```shell
834 | pip install onnx onnxconverter_common onnxruntime-gpu
835 | ```
836 |
837 | 如需使用半精度(FP16)部署,请确保GPU设备的CUDA计算能力 (CUDA Compute Capability) 大于7.0,典型的设备包括V100、T4、A10、A100、GTX 20系列和30系列显卡等。
838 | 更多关于CUDA Compute Capability和精度支持情况请参考NVIDIA文档:[GPU硬件与支持精度对照表](https://docs.nvidia.com/deeplearning/tensorrt/archives/tensorrt-840-ea/support-matrix/index.html#hardware-precision-matrix)
839 |
840 |
841 | - 模型导出
842 |
843 | 将训练后的动态图参数导出为静态图参数:
844 |
845 | ```shell
846 | python export_model.py --model_path ./checkpoint/model_best --output_path ./export
847 | ```
848 |
849 | 可配置参数说明:
850 |
851 | - `model_path`: 动态图训练保存的参数路径,路径下包含模型参数文件`pytorch_model.bin`和模型配置文件`config.json`。
852 | - `output_path`: 静态图参数导出路径,默认导出路径为`model_path`,即保存到输入模型同目录下。
853 |
854 | - 推理
855 |
856 | - CPU端推理样例
857 |
858 | 在CPU端,请使用如下命令进行部署
859 |
860 | ```shell
861 | python uie_predictor.py --task_path ./export --engine onnx --device cpu
862 | ```
863 |
864 | 可配置参数说明:
865 | - `model`:选择任务使用的模型,默认为`uie-base`,可选有`uie-base`, `uie-medium`, `uie-mini`, `uie-micro`, `uie-nano`和`uie-medical-base`, `uie-base-en`。
866 | - `task_path`: 用于推理的ONNX模型文件所在文件夹。例如模型文件路径为`./export/inference.onnx`,则传入`./export`。如果不设置,将自动下载`model`对应的模型。
867 | - `position_prob`:模型对于span的起始位置/终止位置的结果概率在0~1之间,返回结果去掉小于这个阈值的结果,默认为0.5,span的最终概率输出为起始位置概率和终止位置概率的乘积。
868 | - `max_seq_len`: 文本最大切分长度,输入超过最大长度时会对输入文本进行自动切分,默认为512。
869 | - `engine`: 可选值为`pytorch`和`onnx`。推理使用的推理引擎。
870 |
871 | - GPU端推理样例
872 |
873 | 在GPU端,请使用如下命令进行部署
874 |
875 | ```shell
876 | python uie_predictor.py --task_path ./export --engine onnx --device gpu --use_fp16
877 | ```
878 |
879 | 可配置参数说明:
880 | - `model`:选择任务使用的模型,默认为`uie-base`,可选有`uie-base`, `uie-medium`, `uie-mini`, `uie-micro`, `uie-nano`和`uie-medical-base`, `uie-base-en`。
881 | - `task_path`: 用于推理的ONNX模型文件所在文件夹。例如模型文件路径为`./export/inference.onnx`,则传入`./export`。如果不设置,将自动下载`model`对应的模型。
882 | - `use_fp16`: 是否使用FP16进行加速,默认关闭。
883 | - `position_prob`:模型对于span的起始位置/终止位置的结果概率在0~1之间,返回结果去掉小于这个阈值的结果,默认为0.5,span的最终概率输出为起始位置概率和终止位置概率的乘积。
884 | - `max_seq_len`: 文本最大切分长度,输入超过最大长度时会对输入文本进行自动切分,默认为512。
885 | - `engine`: 可选值为`pytorch`和`onnx`。推理使用的推理引擎。
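
除了通过 `uie_predictor.py` 推理外,也可以直接用 onnxruntime 加载导出的模型,检查其输入/输出定义(仅为示意,具体的输入/输出名称以导出模型的实际定义为准):

```python
import onnxruntime as ort

# Load the exported ONNX model; fall back to CPU if no GPU provider is available
sess = ort.InferenceSession(
    "./export/inference.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"])

# Print input/output signatures to see what the predictor has to feed in
for inp in sess.get_inputs():
    print("input:", inp.name, inp.shape)
for out in sess.get_outputs():
    print("output:", out.name, out.shape)
```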
886 |
--------------------------------------------------------------------------------
/convert.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | import collections
17 | import json
18 | import os
19 | import pickle
20 | import shutil
21 | import numpy as np
22 | from base64 import b64decode
23 |
24 | import torch
25 | try:
26 | import paddle
27 | from paddle.utils.download import get_path_from_url
28 | paddle_installed = True
29 | except (ImportError, ModuleNotFoundError):
30 | from utils import get_path_from_url
31 | paddle_installed = False
32 |
33 | from model import UIE, UIEM
34 | from utils import logger
35 |
36 | MODEL_MAP = {
37 | # vocab.txt/special_tokens_map.json/tokenizer_config.json are common to the default model.
38 | "uie-base": {
39 | "resource_file_urls": {
40 | "model_state.pdparams":
41 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base_v1.0/model_state.pdparams",
42 | "model_config.json":
43 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/model_config.json",
44 | "vocab.txt":
45 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
46 | "special_tokens_map.json":
47 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
48 | "tokenizer_config.json":
49 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json"
50 | }
51 | },
52 | "uie-medium": {
53 | "resource_file_urls": {
54 | "model_state.pdparams":
55 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medium_v1.0/model_state.pdparams",
56 | "model_config.json":
57 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medium/model_config.json",
58 | "vocab.txt":
59 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
60 | "special_tokens_map.json":
61 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
62 | "tokenizer_config.json":
63 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
64 | }
65 | },
66 | "uie-mini": {
67 | "resource_file_urls": {
68 | "model_state.pdparams":
69 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_mini_v1.0/model_state.pdparams",
70 | "model_config.json":
71 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_mini/model_config.json",
72 | "vocab.txt":
73 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
74 | "special_tokens_map.json":
75 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
76 | "tokenizer_config.json":
77 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
78 | }
79 | },
80 | "uie-micro": {
81 | "resource_file_urls": {
82 | "model_state.pdparams":
83 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_micro_v1.0/model_state.pdparams",
84 | "model_config.json":
85 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_micro/model_config.json",
86 | "vocab.txt":
87 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
88 | "special_tokens_map.json":
89 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
90 | "tokenizer_config.json":
91 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
92 | }
93 | },
94 | "uie-nano": {
95 | "resource_file_urls": {
96 | "model_state.pdparams":
97 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_nano_v1.0/model_state.pdparams",
98 | "model_config.json":
99 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_nano/model_config.json",
100 | "vocab.txt":
101 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
102 | "special_tokens_map.json":
103 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
104 | "tokenizer_config.json":
105 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
106 | }
107 | },
108 | "uie-medical-base": {
109 | "resource_file_urls": {
110 | "model_state.pdparams":
111 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medical_base_v0.1/model_state.pdparams",
112 | "model_config.json":
113 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/model_config.json",
114 | "vocab.txt":
115 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
116 | "special_tokens_map.json":
117 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
118 | "tokenizer_config.json":
119 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
120 | }
121 | },
122 | "uie-base-en": {
123 | "resource_file_urls": {
124 | "model_state.pdparams":
125 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base_en_v1.1/model_state.pdparams",
126 | "model_config.json":
127 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base_en/model_config.json",
128 | "vocab.txt":
129 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base_en/vocab.txt",
130 | "special_tokens_map.json":
131 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base_en/special_tokens_map.json",
132 | "tokenizer_config.json":
133 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base_en/tokenizer_config.json",
134 | }
135 | },
136 | # uie-m模型需要Ernie-M模型
137 | "uie-m-base": {
138 | "resource_file_urls": {
139 | "model_state.pdparams":
140 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_m_base_v1.0/model_state.pdparams",
141 | "model_config.json":
142 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_m_base/model_config.json",
143 | "vocab.txt":
144 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_m_base/vocab.txt",
145 | "special_tokens_map.json":
146 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_m_base/special_tokens_map.json",
147 | "tokenizer_config.json":
148 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_m_base/tokenizer_config.json",
149 | "sentencepiece.bpe.model":
150 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_m_base/sentencepiece.bpe.model"
151 |
152 | }
153 | },
154 | "uie-m-large": {
155 | "resource_file_urls": {
156 | "model_state.pdparams":
157 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_m_large_v1.0/model_state.pdparams",
158 | "model_config.json":
159 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_m_large/model_config.json",
160 | "vocab.txt":
161 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_m_large/vocab.txt",
162 | "special_tokens_map.json":
163 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_m_large/special_tokens_map.json",
164 | "tokenizer_config.json":
165 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_m_large/tokenizer_config.json",
166 | "sentencepiece.bpe.model":
167 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_m_base/sentencepiece.bpe.model"
168 | }
169 | },
170 | # Rename to `uie-medium` and the name of `uie-tiny` will be deprecated in future.
171 | "uie-tiny": {
172 | "resource_file_urls": {
173 | "model_state.pdparams":
174 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny_v0.1/model_state.pdparams",
175 | "model_config.json":
176 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/model_config.json",
177 | "vocab.txt":
178 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/vocab.txt",
179 | "special_tokens_map.json":
180 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/special_tokens_map.json",
181 | "tokenizer_config.json":
182 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/tokenizer_config.json"
183 | }
184 | },
185 | "ernie-3.0-base-zh": {
186 | "resource_file_urls": {
187 | "model_state.pdparams":
188 | "https://bj.bcebos.com/paddlenlp/models/transformers/ernie_3.0/ernie_3.0_base_zh.pdparams",
189 | "model_config.json":
190 | "base64:ew0KICAiYXR0ZW50aW9uX3Byb2JzX2Ryb3BvdXRfcHJvYiI6IDAuMSwNCiAgImhpZGRlbl9hY3QiOiAiZ2VsdSIsDQogICJoaWRkZW5fZHJvcG91dF9wcm9iIjogMC4xLA0KICAiaGlkZGVuX3NpemUiOiA3NjgsDQogICJpbml0aWFsaXplcl9yYW5nZSI6IDAuMDIsDQogICJtYXhfcG9zaXRpb25fZW1iZWRkaW5ncyI6IDIwNDgsDQogICJudW1fYXR0ZW50aW9uX2hlYWRzIjogMTIsDQogICJudW1faGlkZGVuX2xheWVycyI6IDEyLA0KICAidGFza190eXBlX3ZvY2FiX3NpemUiOiAzLA0KICAidHlwZV92b2NhYl9zaXplIjogNCwNCiAgInVzZV90YXNrX2lkIjogdHJ1ZSwNCiAgInZvY2FiX3NpemUiOiA0MDAwMCwNCiAgImluaXRfY2xhc3MiOiAiRXJuaWVNb2RlbCINCn0=",
191 | "vocab.txt":
192 | "https://bj.bcebos.com/paddlenlp/models/transformers/ernie_3.0/ernie_3.0_base_zh_vocab.txt",
193 | "special_tokens_map.json":
194 | "base64:eyJ1bmtfdG9rZW4iOiAiW1VOS10iLCAic2VwX3Rva2VuIjogIltTRVBdIiwgInBhZF90b2tlbiI6ICJbUEFEXSIsICJjbHNfdG9rZW4iOiAiW0NMU10iLCAibWFza190b2tlbiI6ICJbTUFTS10ifQ==",
195 | "tokenizer_config.json":
196 | "base64:eyJkb19sb3dlcl9jYXNlIjogdHJ1ZSwgInVua190b2tlbiI6ICJbVU5LXSIsICJzZXBfdG9rZW4iOiAiW1NFUF0iLCAicGFkX3Rva2VuIjogIltQQURdIiwgImNsc190b2tlbiI6ICJbQ0xTXSIsICJtYXNrX3Rva2VuIjogIltNQVNLXSIsICJ0b2tlbml6ZXJfY2xhc3MiOiAiRXJuaWVUb2tlbml6ZXIifQ=="
197 | }
198 | }
199 | }
200 |
201 |
202 | def build_params_map(model_prefix='encoder', attention_num=12):
203 | """
204 | build params map from paddle-paddle's ERNIE to transformer's BERT
205 | :return:
206 | """
207 | weight_map = collections.OrderedDict({
208 | f'{model_prefix}.embeddings.word_embeddings.weight': "encoder.embeddings.word_embeddings.weight",
209 | f'{model_prefix}.embeddings.position_embeddings.weight': "encoder.embeddings.position_embeddings.weight",
210 | f'{model_prefix}.embeddings.token_type_embeddings.weight': "encoder.embeddings.token_type_embeddings.weight",
211 | f'{model_prefix}.embeddings.task_type_embeddings.weight': "encoder.embeddings.task_type_embeddings.weight",
212 | f'{model_prefix}.embeddings.layer_norm.weight': 'encoder.embeddings.LayerNorm.gamma',
213 | f'{model_prefix}.embeddings.layer_norm.bias': 'encoder.embeddings.LayerNorm.beta',
214 | })
215 | # add attention layers
216 | for i in range(attention_num):
217 | weight_map[f'{model_prefix}.encoder.layers.{i}.self_attn.q_proj.weight'] = f'encoder.encoder.layer.{i}.attention.self.query.weight'
218 | weight_map[f'{model_prefix}.encoder.layers.{i}.self_attn.q_proj.bias'] = f'encoder.encoder.layer.{i}.attention.self.query.bias'
219 | weight_map[f'{model_prefix}.encoder.layers.{i}.self_attn.k_proj.weight'] = f'encoder.encoder.layer.{i}.attention.self.key.weight'
220 | weight_map[f'{model_prefix}.encoder.layers.{i}.self_attn.k_proj.bias'] = f'encoder.encoder.layer.{i}.attention.self.key.bias'
221 | weight_map[f'{model_prefix}.encoder.layers.{i}.self_attn.v_proj.weight'] = f'encoder.encoder.layer.{i}.attention.self.value.weight'
222 | weight_map[f'{model_prefix}.encoder.layers.{i}.self_attn.v_proj.bias'] = f'encoder.encoder.layer.{i}.attention.self.value.bias'
223 | weight_map[f'{model_prefix}.encoder.layers.{i}.self_attn.out_proj.weight'] = f'encoder.encoder.layer.{i}.attention.output.dense.weight'
224 | weight_map[f'{model_prefix}.encoder.layers.{i}.self_attn.out_proj.bias'] = f'encoder.encoder.layer.{i}.attention.output.dense.bias'
225 | weight_map[f'{model_prefix}.encoder.layers.{i}.norm1.weight'] = f'encoder.encoder.layer.{i}.attention.output.LayerNorm.gamma'
226 | weight_map[f'{model_prefix}.encoder.layers.{i}.norm1.bias'] = f'encoder.encoder.layer.{i}.attention.output.LayerNorm.beta'
227 | weight_map[f'{model_prefix}.encoder.layers.{i}.linear1.weight'] = f'encoder.encoder.layer.{i}.intermediate.dense.weight'
228 | weight_map[f'{model_prefix}.encoder.layers.{i}.linear1.bias'] = f'encoder.encoder.layer.{i}.intermediate.dense.bias'
229 | weight_map[f'{model_prefix}.encoder.layers.{i}.linear2.weight'] = f'encoder.encoder.layer.{i}.output.dense.weight'
230 | weight_map[f'{model_prefix}.encoder.layers.{i}.linear2.bias'] = f'encoder.encoder.layer.{i}.output.dense.bias'
231 | weight_map[f'{model_prefix}.encoder.layers.{i}.norm2.weight'] = f'encoder.encoder.layer.{i}.output.LayerNorm.gamma'
232 | weight_map[f'{model_prefix}.encoder.layers.{i}.norm2.bias'] = f'encoder.encoder.layer.{i}.output.LayerNorm.beta'
233 | # add pooler
234 | weight_map.update(
235 | {
236 | f'{model_prefix}.pooler.dense.weight': 'encoder.pooler.dense.weight',
237 | f'{model_prefix}.pooler.dense.bias': 'encoder.pooler.dense.bias',
238 | 'linear_start.weight': 'linear_start.weight',
239 | 'linear_start.bias': 'linear_start.bias',
240 | 'linear_end.weight': 'linear_end.weight',
241 | 'linear_end.bias': 'linear_end.bias',
242 | }
243 | )
244 | return weight_map
245 |
246 |
247 | def extract_and_convert(input_dir, output_dir, verbose=False):
248 | if not os.path.exists(output_dir):
249 | os.makedirs(output_dir)
250 | if verbose:
251 | logger.info('=' * 20 + 'save config file' + '=' * 20)
252 | config = json.load(
253 | open(os.path.join(input_dir, 'model_config.json'), 'rt', encoding='utf-8'))
254 | if 'init_args' in config:
255 | config = config['init_args'][0]
256 | config["architectures"] = ["UIE"]
257 | config['layer_norm_eps'] = 1e-12
258 | del config['init_class']
259 | if 'sent_type_vocab_size' in config:
260 | config['type_vocab_size'] = config['sent_type_vocab_size']
261 | config['intermediate_size'] = 4 * config['hidden_size']
262 | json.dump(config, open(os.path.join(output_dir, 'config.json'),
263 | 'wt', encoding='utf-8'), indent=4)
264 | if verbose:
265 | logger.info('=' * 20 + 'save vocab file' + '=' * 20)
266 | shutil.copy(os.path.join(input_dir, 'vocab.txt'),
267 | os.path.join(output_dir, 'vocab.txt'))
268 | special_tokens_map = json.load(open(os.path.join(
269 | input_dir, 'special_tokens_map.json'), 'rt', encoding='utf-8'))
270 | json.dump(special_tokens_map, open(os.path.join(output_dir, 'special_tokens_map.json'),
271 | 'wt', encoding='utf-8'))
272 | tokenizer_config = json.load(
273 | open(os.path.join(input_dir, 'tokenizer_config.json'), 'rt', encoding='utf-8'))
274 | if tokenizer_config['tokenizer_class'] == 'ErnieTokenizer':
275 | tokenizer_config['tokenizer_class'] = "BertTokenizer"
276 | json.dump(tokenizer_config, open(os.path.join(output_dir, 'tokenizer_config.json'),
277 | 'wt', encoding='utf-8'))
278 | spm_file = os.path.join(input_dir, 'sentencepiece.bpe.model')
279 | if os.path.exists(spm_file):
280 | shutil.copy(spm_file, os.path.join(
281 | output_dir, 'sentencepiece.bpe.model'))
282 | if verbose:
283 | logger.info('=' * 20 + 'extract weights' + '=' * 20)
284 | state_dict = collections.OrderedDict()
285 | weight_map = build_params_map(attention_num=config['num_hidden_layers'])
286 | weight_map.update(build_params_map(
287 | 'ernie', attention_num=config['num_hidden_layers']))
288 | if paddle_installed:
289 | import paddle.fluid.dygraph as D
290 | from paddle import fluid
291 | with fluid.dygraph.guard():
292 | paddle_paddle_params, _ = D.load_dygraph(
293 | os.path.join(input_dir, 'model_state'))
294 | else:
295 | paddle_paddle_params = pickle.load(
296 | open(os.path.join(input_dir, 'model_state.pdparams'), 'rb'))
297 | del paddle_paddle_params['StructuredToParameterName@@']
298 | for weight_name, weight_value in paddle_paddle_params.items():
299 | transposed = ''
300 | if 'weight' in weight_name:
301 | if '.encoder' in weight_name or 'pooler' in weight_name or 'linear' in weight_name:
302 | weight_value = weight_value.transpose()
303 | transposed = '.T'
304 | # Fix: embedding error
305 | if 'word_embeddings.weight' in weight_name:
306 | weight_value[0, :] = 0
307 | if weight_name not in weight_map:
308 | if verbose:
309 | logger.info(f"{'='*20} [SKIP] {weight_name} {'='*20}")
310 | continue
311 | state_dict[weight_map[weight_name]] = torch.FloatTensor(weight_value)
312 | if verbose:
313 | logger.info(
314 | f"{weight_name}{transposed} -> {weight_map[weight_name]} {weight_value.shape}")
315 | torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin"))
316 |
317 |
318 | def check_model(input_model):
319 | if not os.path.exists(input_model):
320 | if input_model not in MODEL_MAP:
321 |             raise ValueError('input_model does not exist and is not a known model name!')
322 |
323 | resource_file_urls = MODEL_MAP[input_model]['resource_file_urls']
324 | logger.info("Downloading resource files...")
325 |
326 | for key, val in resource_file_urls.items():
327 | file_path = os.path.join(input_model, key)
328 | if not os.path.exists(file_path):
329 | if val.startswith('base64:'):
330 | base64data = b64decode(val.replace(
331 | 'base64:', '').encode('utf-8'))
332 | with open(file_path, 'wb') as f:
333 | f.write(base64data)
334 | else:
335 | download_path = get_path_from_url(val, input_model)
336 | if download_path != file_path:
337 | shutil.move(download_path, file_path)
338 |
339 |
340 | def validate_model(tokenizer, pt_model, pd_model, model_type='uie', atol: float = 1e-5):
341 |
342 | logger.info("Validating PyTorch model...")
343 |
344 | batch_size = 2
345 | seq_length = 6
346 | seq_length_with_token = seq_length+2
347 | max_seq_length = 512
348 | dummy_input = [" ".join([tokenizer.unk_token])
349 | * seq_length] * batch_size
350 | encoded_inputs = dict(tokenizer(dummy_input, pad_to_max_seq_len=True, max_seq_len=512, return_attention_mask=True,
351 | return_position_ids=True))
352 | paddle_inputs = {}
353 | for name, value in encoded_inputs.items():
354 | if name == "attention_mask":
355 | if model_type == 'uie-m':
356 | continue
357 | name = "att_mask"
358 | if name == "position_ids":
359 | name = "pos_ids"
360 | paddle_inputs[name] = paddle.to_tensor(value, dtype=paddle.int64)
361 |
362 | paddle_named_outputs = ['start_prob', 'end_prob']
363 | paddle_outputs = pd_model(**paddle_inputs)
364 |
365 | torch_inputs = {}
366 | for name, value in encoded_inputs.items():
367 | if name == "attention_mask":
368 | if model_type == 'uie-m':
369 | continue
370 | torch_inputs[name] = torch.tensor(value, dtype=torch.int64)
371 | torch_outputs = pt_model(**torch_inputs)
372 | torch_outputs_dict = {}
373 |
374 | for name, value in torch_outputs.items():
375 | torch_outputs_dict[name] = value
376 |
377 | torch_outputs_set, ref_outputs_set = set(
378 | torch_outputs_dict.keys()), set(paddle_named_outputs)
379 | if not torch_outputs_set.issubset(ref_outputs_set):
380 | logger.info(
381 | f"\t-[x] Pytorch model output names {torch_outputs_set} do not match reference model {ref_outputs_set}"
382 | )
383 |
384 | raise ValueError(
385 |             "Outputs don't match between the reference model and the converted PyTorch model: "
386 | f"{torch_outputs_set.difference(ref_outputs_set)}"
387 | )
388 | else:
389 | logger.info(
390 | f"\t-[✓] Pytorch model output names match reference model ({torch_outputs_set})")
391 |
392 | # Check the shape and values match
393 | for name, ref_value in zip(paddle_named_outputs, paddle_outputs):
394 | ref_value = ref_value.numpy()
395 | pt_value = torch_outputs_dict[name].detach().numpy()
396 | logger.info(f'\t- Validating PyTorch Model output "{name}":')
397 |
398 | # Shape
399 | if not pt_value.shape == ref_value.shape:
400 | logger.info(
401 | f"\t\t-[x] shape {pt_value.shape} doesn't match {ref_value.shape}")
402 | raise ValueError(
403 |                 "Output shapes don't match between the reference model and the converted PyTorch model: "
404 | f"Got {ref_value.shape} (reference) and {pt_value.shape} (PyTorch)"
405 | )
406 | else:
407 | logger.info(
408 | f"\t\t-[✓] {pt_value.shape} matches {ref_value.shape}")
409 |
410 | # Values
411 | if not np.allclose(ref_value, pt_value, atol=atol):
412 | logger.info(
413 | f"\t\t-[x] values not close enough (atol: {atol})")
414 | raise ValueError(
415 |                 "Output values don't match between the reference model and the converted PyTorch model: "
416 | f"Got max absolute difference of: {np.amax(np.abs(ref_value - pt_value))}"
417 | )
418 | else:
419 | logger.info(
420 | f"\t\t-[✓] all values close (atol: {atol})")
421 |
422 |
423 | def do_main():
424 | if args.output_model is None:
425 | args.output_model = args.input_model.replace('-', '_')+'_pytorch'
426 | check_model(args.input_model)
427 | extract_and_convert(args.input_model, args.output_model, verbose=True)
428 | if not (args.no_validate_output or 'ernie' in args.input_model):
429 | if paddle_installed:
430 | try:
431 | from paddlenlp.transformers import ErnieTokenizer, ErnieMTokenizer
432 | from paddlenlp.taskflow.models import UIE as UIEPaddle, UIEM as UIEMPaddle
433 | except (ImportError, ModuleNotFoundError) as e:
434 | raise ModuleNotFoundError(
435 |                     'Module PaddleNLP is not installed. Try installing paddlenlp or run convert.py with --no_validate_output.') from e
436 | if 'uie-m' in args.input_model:
437 | tokenizer: ErnieMTokenizer = ErnieMTokenizer.from_pretrained(
438 | args.input_model)
439 | model = UIEM.from_pretrained(args.output_model)
440 | model.eval()
441 | paddle_model = UIEMPaddle.from_pretrained(args.input_model)
442 | paddle_model.eval()
443 | model_type = 'uie-m'
444 | else:
445 | tokenizer: ErnieTokenizer = ErnieTokenizer.from_pretrained(
446 | args.input_model)
447 | model = UIE.from_pretrained(args.output_model)
448 | model.eval()
449 | paddle_model = UIEPaddle.from_pretrained(args.input_model)
450 | paddle_model.eval()
451 | model_type = 'uie'
452 | validate_model(tokenizer, model, paddle_model, model_type)
453 | else:
454 |             logger.warning("Skipping PyTorch model validation because paddle is not installed. "
455 |                            "The outputs of the converted model may differ from those of the Paddle model.")
456 |
457 |
458 | if __name__ == '__main__':
459 | parser = argparse.ArgumentParser()
460 |     parser.add_argument("-i", "--input_model", default="uie-base", type=str,
461 |                         help="Directory of the input Paddle model.\n Known model names [uie-base/uie-tiny] will be downloaded automatically.")
462 |     parser.add_argument("-o", "--output_model", default=None, type=str,
463 |                         help="Directory of the output PyTorch model")
464 |     parser.add_argument("--no_validate_output", action="store_true",
465 |                         help="Skip validating the outputs of the converted PyTorch model against the Paddle model.")
466 | args = parser.parse_args()
467 |
468 | do_main()
469 |
--------------------------------------------------------------------------------
/doccano.md:
--------------------------------------------------------------------------------
1 | # doccano
2 |
3 | **Table of Contents**
4 |
5 | * [1. Installation](#1-installation)
6 | * [2. Project Creation](#2-project-creation)
7 | * [3. Data Upload](#3-data-upload)
8 | * [4. Label Construction](#4-label-construction)
9 | * [5. Task Annotation](#5-task-annotation)
10 | * [6. Data Export](#6-data-export)
11 | * [7. Data Conversion](#7-data-conversion)
12 |
13 |
14 |
15 | ## 1. Installation
16 |
17 | Refer to the [official doccano documentation](https://github.com/doccano/doccano) to install doccano and finish its initial configuration.
18 |
19 | **Environment used for the annotation examples below:**
20 |
21 | - doccano 1.6.2
22 |
23 |
24 |
25 | ## 2. Project Creation
26 |
27 | UIE supports two task types, extraction and classification. Create a new project according to your actual needs:
28 |
29 | #### 2.1 Creating a project for extraction tasks
30 |
31 | When creating the project, choose the **Sequence Labeling** task and check **Allow overlapping entity** and **Use relation Labeling**. This setup covers tasks such as **named entity recognition, relation extraction, event extraction, and aspect-opinion extraction**.
32 |
33 |
34 |

35 |
36 |
37 | #### 2.2 Creating a project for classification tasks
38 |
39 | When creating the project, choose the **Text Classification** task. This setup covers tasks such as **text classification and sentence-level sentiment classification**.
40 |
41 |
42 |

43 |
44 |
45 |
46 |
47 | ## 3. Data Upload
48 |
49 | The uploaded file must be in txt format, with one text to be annotated per line, for example:
50 |
51 | ```text
52 | 2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌
53 | 第十四届全运会在西安举办
54 | ```
55 |
56 | For the upload data type, **select TextLine**:
57 |
58 |
59 |

60 |
61 |
62 | **NOTE**: doccano supports four upload formats: `TextFile`, `TextLine`, `JSONL`, and `CoNLL`. UIE custom training **always uses TextLine**, i.e. the uploaded file must be a txt file, and during annotation each line of that file is displayed as one page of content.
63 |
64 |
65 |
66 | ## 4. Label Construction
67 |
68 | #### 4.1 Building labels for extraction tasks
69 |
70 | Extraction tasks use two label types, **Span** and **Relation**. A Span is a **target information fragment in the original text**, such as an entity of a given type in entity recognition, or a trigger word or argument in event extraction. A Relation is a **relationship between Spans in the original text**, such as the relation between two entities (Subject & Object) in relation extraction, or the relation between an argument and a trigger word in event extraction.
71 |
72 | Example of building Span-type labels:
73 |
74 |
75 |

76 |
77 |
78 | Example of building Relation-type labels:
79 |
80 |
81 |

82 |
83 |
84 | #### 4.2 Building labels for classification tasks
85 |
86 | Add the classification category labels:
87 |
88 |
89 |

90 |
91 |
92 |
93 |
94 | ## 5. Task Annotation
95 |
96 | #### 5.1 Named entity recognition
97 |
98 | Named Entity Recognition (NER) identifies entities with specific meaning in text. In open-domain information extraction, **the extracted categories are not restricted and can be defined by the user**.
99 |
100 | Annotation example:
101 |
102 |
103 |

104 |
105 |
106 | The example defines four Span-type labels: `时间`, `选手`, `赛事名称`, and `得分`.
107 |
108 | ```text
109 | schema = [
110 | '时间',
111 | '选手',
112 | '赛事名称',
113 | '得分'
114 | ]
115 | ```
116 |
117 | #### 5.2 Relation extraction
118 |
119 | Relation Extraction (RE) identifies entities in text and extracts the semantic relations between them, i.e. it extracts triples of the form (entity 1, relation type, entity 2).
120 |
121 | Annotation example:
122 |
123 |
124 |

125 |
126 |
127 | The example defines three Span-type labels, `作品名`, `人物名`, and `时间`, and three Relation labels, `歌手`, `发行时间`, and `所属专辑`. A Relation label **points from the Subject entity to the Object entity**.
128 |
129 | The schema corresponding to this annotation example is:
130 |
131 | ```text
132 | schema = {
133 | '作品名': [
134 | '歌手',
135 | '发行时间',
136 | '所属专辑'
137 | ]
138 | }
139 | ```
140 |
141 | #### 5.3 Event extraction
142 |
143 | Event Extraction (EE) extracts events from natural-language text and identifies the event type and event arguments. The event extraction task covered by UIE extracts the arguments of an event given a known event type.
144 |
145 | Annotation example:
146 |
147 |
148 |

149 |
150 |
151 | The example defines three Span labels, `地震触发词` (trigger word), `等级` (event argument), and `时间` (event argument), and two Relation labels, `时间` and `震级`. Trigger-word labels **follow the unified format `XX触发词`**, where `XX` is the concrete event type; in the example above the event type is `地震`, so the corresponding trigger-word label is `地震触发词`. A Relation label **points from the trigger word to the corresponding event argument**.
152 |
153 | The schema corresponding to this annotation example is:
154 |
155 | ```text
156 | schema = {
157 | '地震触发词': [
158 | '时间',
159 | '震级'
160 | ]
161 | }
162 | ```
163 |
164 | #### 5.4 Aspect-opinion extraction
165 |
166 | Aspect-opinion extraction extracts the evaluated aspects and the opinion words contained in a text.
167 |
168 | Annotation example:
169 |
170 |
171 |

172 |
173 |
174 | The example defines two Span labels, `评价维度` and `观点词`, and one Relation label, `观点词`. A Relation label **points from the aspect to the opinion word**.
175 |
176 | The schema corresponding to this annotation example is:
177 |
178 | ```text
179 | schema = {
180 | '评价维度': '观点词'
181 | }
182 | ```
183 |
184 | #### 5.5 Sentence-level classification
185 |
186 | Annotation example:
187 |
188 |
189 |

190 |
191 |
192 | The example defines two category labels, `正向` and `负向`, to classify the sentiment polarity of the text.
193 |
194 | The schema corresponding to this annotation example is:
195 |
196 | ```text
197 | schema = '情感倾向[正向,负向]'
198 | ```
199 |
200 | #### 5.6 Entity/aspect-level classification
201 |
202 |
203 |

204 |
205 |
206 | Annotation example:
207 |
208 | The example defines three Span labels, `评价维度##正向`, `评价维度##负向`, and `观点词`, and one Relation label, `观点词`. Here `##` is the separator between the entity category/aspect and the classification label (customizable through the separator argument of doccano.py); the sketch after the schema shows how such a label splits.
209 |
210 | The schema corresponding to this annotation example is:
211 |
212 | ```text
213 | schema = {
214 | '评价维度': [
215 | '观点词',
216 | '情感倾向[正向,负向]'
217 | ]
218 | }
219 | ```
220 |
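Below is a minimal sketch (not one of the project scripts) of how a combined label such as `评价维度##正向` can be split back into its Span category and classification option using the separator described above:

```python
# Minimal sketch: splitting an entity/aspect-level label of the form
# "<span category>##<classification option>" with the default "##" separator.
SEPARATOR = "##"

def split_label(label, separator=SEPARATOR):
    """Return (span_category, classification_option or None)."""
    if separator in label:
        category, option = label.split(separator, 1)
        return category, option
    return label, None

print(split_label("评价维度##正向"))  # ('评价维度', '正向')
print(split_label("观点词"))          # ('观点词', None)
```
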
221 |
222 |
223 | ## 6. Data Export
224 |
225 | #### 6.1 Exporting data for extraction and entity/aspect-level classification tasks
226 |
227 | Choose ``JSONL(relation)`` as the export file type. Example of exported data:
228 |
229 | ```text
230 | {
231 | "id": 38,
232 | "text": "百科名片你知道我要什么,是歌手高明骏演唱的一首歌曲,1989年发行,收录于个人专辑《丛林男孩》中",
233 | "relations": [
234 | {
235 | "id": 20,
236 | "from_id": 51,
237 | "to_id": 53,
238 | "type": "歌手"
239 | },
240 | {
241 | "id": 21,
242 | "from_id": 51,
243 | "to_id": 55,
244 | "type": "发行时间"
245 | },
246 | {
247 | "id": 22,
248 | "from_id": 51,
249 | "to_id": 54,
250 | "type": "所属专辑"
251 | }
252 | ],
253 | "entities": [
254 | {
255 | "id": 51,
256 | "start_offset": 4,
257 | "end_offset": 11,
258 | "label": "作品名"
259 | },
260 | {
261 | "id": 53,
262 | "start_offset": 15,
263 | "end_offset": 18,
264 | "label": "人物名"
265 | },
266 | {
267 | "id": 54,
268 | "start_offset": 42,
269 | "end_offset": 46,
270 | "label": "作品名"
271 | },
272 | {
273 | "id": 55,
274 | "start_offset": 26,
275 | "end_offset": 31,
276 | "label": "时间"
277 | }
278 | ]
279 | }
280 | ```
281 |
282 | The annotated data is stored in a single text file, one example per line in ``json`` format, with the following fields (a small parsing sketch follows this list):
283 | - ``id``: unique ID of the example within the dataset.
284 | - ``text``: the original text.
285 | - ``entities``: the Span labels contained in the example; each Span label has four fields:
286 |     - ``id``: unique ID of the Span within the dataset.
287 |     - ``start_offset``: index of the Span's start token in the text.
288 |     - ``end_offset``: the position one past the index of the Span's end token in the text.
289 |     - ``label``: the Span type.
290 | - ``relations``: the Relation labels contained in the example; each Relation label has four fields:
291 |     - ``id``: unique ID of the (Span1, Relation, Span2) triple within the dataset; identical triples in different examples share the same ID.
292 |     - ``from_id``: ID of Span1.
293 |     - ``to_id``: ID of Span2.
294 |     - ``type``: the Relation type.
295 |
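Below is a minimal sketch (not one of the project scripts) of how one line of such an export can be parsed using the fields listed above; the path `./data/doccano_ext.json` is the file name used later in section 7.1:

```python
# Minimal sketch: reading a doccano JSONL(relation) export and recovering
# the annotated spans and (subject, relation, object) triples.
import json

def read_relation_export(path):
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            example = json.loads(line)
            text = example["text"]
            # Map entity id -> (surface text, label) via the start/end offsets.
            entities = {
                e["id"]: (text[e["start_offset"]:e["end_offset"]], e["label"])
                for e in example["entities"]
            }
            triples = [
                (entities[r["from_id"]][0], r["type"], entities[r["to_id"]][0])
                for r in example["relations"]
            ]
            yield text, entities, triples

# Usage:
# for text, entities, triples in read_relation_export("./data/doccano_ext.json"):
#     print(triples)
```
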
296 | #### 6.2 Exporting data for sentence-level classification tasks
297 |
298 | Choose ``JSONL`` as the export file type. Example of exported data:
299 |
300 | ```text
301 | {
302 | "id": 41,
303 | "data": "大年初一就把车前保险杠给碰坏了,保险杠和保险公司 真够倒霉的,我决定步行反省。",
304 | "label": [
305 | "负向"
306 | ]
307 | }
308 | ```
309 |
310 | The annotated data is stored in a single text file, one example per line in ``json`` format, with the following fields (a small counting sketch follows this list):
311 | - ``id``: unique ID of the example within the dataset.
312 | - ``data``: the original text.
313 | - ``label``: the category label(s) of the text.
314 |
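Below is a minimal sketch (not one of the project scripts) of how the class distribution of such a classification export can be inspected; the file name `doccano_cls.json` matches the one used in section 7.2:

```python
# Minimal sketch: counting the label distribution of a doccano JSONL
# classification export (one json object per line, fields: id / data / label).
import json
from collections import Counter

def label_distribution(path):
    counts = Counter()
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            example = json.loads(line)
            counts.update(example["label"])  # "label" is a list of categories
    return counts

# Usage:
# print(label_distribution("./data/doccano_cls.json"))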
315 |
316 |
317 | ## 7. Data Conversion
318 |
319 | This section explains how to convert the annotated data exported from the doccano platform with the `doccano.py` script, generating the training/dev/test sets in one step.
320 |
321 | #### 7.1 Converting extraction task data
322 |
323 | - After annotation is finished, export a `JSONL(relation)` file from the doccano platform, rename it to `doccano_ext.json`, and put it in the `./data` directory.
324 | - Convert the data with the [doccano.py](./doccano.py) script, after which you can start training the corresponding model.
325 |
326 | ```shell
327 | python doccano.py \
328 | --doccano_file ./data/doccano_ext.json \
329 | --task_type "ext" \
330 | --save_dir ./data \
331 | --negative_ratio 5
332 | ```
333 |
334 | #### 7.2 Converting sentence-level classification task data
335 |
336 | - After annotation is finished, export a `JSON` file from the doccano platform, rename it to `doccano_cls.json`, and put it in the `./data` directory.
337 | - During data conversion, the prompt used for model training is constructed automatically. In sentence-level sentiment classification, for example, the prompt is ``情感倾向[正向,负向]``, which can be declared through the `prompt_prefix` and `options` arguments.
338 | - Convert the data with the [doccano.py](./doccano.py) script, after which you can start training the corresponding model.
339 |
340 | ```shell
341 | python doccano.py \
342 | --doccano_file ./data/doccano_cls.json \
343 | --task_type "cls" \
344 | --save_dir ./data \
345 | --splits 0.8 0.1 0.1 \
346 | --prompt_prefix "情感倾向" \
347 | --options "正向" "负向"
348 | ```
349 |
350 | #### 7.3 Converting entity/aspect-level classification task data
351 |
352 | - After annotation is finished, export a `JSONL(relation)` file from the doccano platform, rename it to `doccano_ext.json`, and put it in the `./data` directory.
353 | - During data conversion, the prompt used for model training is constructed automatically. In aspect-level sentiment classification, for example, the prompt is ``XXX的情感倾向[正向,负向]``, which can be declared through the `prompt_prefix` and `options` arguments.
354 | - Convert the data with the [doccano.py](./doccano.py) script, after which you can start training the corresponding model.
355 |
356 | ```shell
357 | python doccano.py \
358 | --doccano_file ./data/doccano_ext.json \
359 | --task_type "ext" \
360 | --save_dir ./data \
361 | --splits 0.8 0.1 0.1 \
362 | --prompt_prefix "情感倾向" \
363 | --options "正向" "负向" \
364 | --separator "##"
365 | ```
366 |
367 | Configurable arguments:
368 |
369 | - ``doccano_file``: the annotation file exported from doccano.
370 | - ``save_dir``: directory where the training data is saved; defaults to the ``data`` directory.
371 | - ``negative_ratio``: maximum negative-example ratio. Only effective for extraction tasks; constructing a moderate number of negative examples can improve model performance. The number of negatives depends on the actual number of labels: maximum negatives = negative_ratio * number of positives. This argument only affects the training set and defaults to 5. To keep evaluation metrics accurate, the dev and test sets are built with all negatives by default.
372 | - ``splits``: proportions of the training, dev, and test sets when splitting the dataset. The default [0.8, 0.1, 0.1] splits the data into train/dev/test at a ratio of ``8:1:1``.
373 | - ``task_type``: task type; extraction and classification are supported.
374 | - ``options``: category labels for the classification task. Only effective for classification tasks; defaults to ["正向", "负向"].
375 | - ``prompt_prefix``: prompt prefix for the classification task. Only effective for classification tasks; defaults to "情感倾向".
376 | - ``is_shuffle``: whether to shuffle the dataset; defaults to True.
377 | - ``seed``: random seed; defaults to 1000.
378 | - ``separator``: separator between the entity category/aspect and the classification label. Only effective for entity/aspect-level classification tasks; defaults to "##".
379 |
380 | Notes:
381 | - By default, the [doccano.py](./doccano.py) script splits the data into train/dev/test sets according to the given proportions (see the split sketch below).
382 | - Each run of the [doccano.py](./doccano.py) script overwrites existing data files with the same names.
383 | - We recommend constructing some negative examples during training to improve model performance, and this is built into the data conversion step. The proportion of automatically constructed negatives can be controlled with `negative_ratio`; number of negatives = negative_ratio * number of positives.
384 | - Files exported from doccano are assumed to contain only correctly, manually annotated examples.
385 |
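The following minimal sketch (not one of the project scripts) illustrates how the `splits` proportions map to train/dev/test sizes, using the same boundary computation as doccano.py:

```python
# Minimal sketch: how doccano.py turns the --splits proportions into
# train/dev/test slices of the raw examples.
def split_counts(num_examples: int, splits=(0.8, 0.1, 0.1)):
    i1, i2, _ = splits
    p1 = int(num_examples * i1)          # end of the training slice
    p2 = int(num_examples * (i1 + i2))   # end of the dev slice
    return p1, p2 - p1, num_examples - p2

# 1000 annotated lines with the default 0.8/0.1/0.1 split:
print(split_counts(1000))  # (800, 100, 100)
```
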
386 | ## References
387 | - **[doccano](https://github.com/doccano/doccano)**
388 |
--------------------------------------------------------------------------------
/doccano.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import os
17 | import time
18 | import argparse
19 | import json
20 | from decimal import Decimal
21 | import numpy as np
22 |
23 | from utils import set_seed, convert_ext_examples, convert_cls_examples, logger
24 |
25 |
26 | def do_convert():
27 | set_seed(args.seed)
28 |
29 | tic_time = time.time()
30 | if not os.path.exists(args.doccano_file):
31 | raise ValueError("Please input the correct path of doccano file.")
32 |
33 | if not os.path.exists(args.save_dir):
34 | os.makedirs(args.save_dir)
35 |
36 | if len(args.splits) != 0 and len(args.splits) != 3:
37 |         raise ValueError("splits must either be empty or contain exactly 3 elements.")
38 |
39 | def _check_sum(splits):
40 | return Decimal(str(splits[0])) + Decimal(str(splits[1])) + Decimal(
41 | str(splits[2])) == Decimal("1")
42 |
43 | if len(args.splits) == 3 and not _check_sum(args.splits):
44 | raise ValueError(
45 | "Please set correct splits, sum of elements in splits should be equal to 1."
46 | )
47 |
48 | with open(args.doccano_file, "r", encoding="utf-8") as f:
49 | raw_examples = f.readlines()
50 |
51 | def _create_ext_examples(examples,
52 | negative_ratio,
53 | prompt_prefix="情感倾向",
54 | options=["正向", "负向"],
55 | separator="##",
56 | shuffle=False,
57 | is_train=True):
58 | entities, relations, aspects = convert_ext_examples(
59 | examples, negative_ratio, prompt_prefix, options, separator,
60 | is_train)
61 | examples = entities + relations + aspects
62 | if shuffle:
63 | indexes = np.random.permutation(len(examples))
64 | examples = [examples[i] for i in indexes]
65 | return examples
66 |
67 | def _create_cls_examples(examples, prompt_prefix, options, shuffle=False):
68 | examples = convert_cls_examples(examples, prompt_prefix, options)
69 | if shuffle:
70 | indexes = np.random.permutation(len(examples))
71 | examples = [examples[i] for i in indexes]
72 | return examples
73 |
74 | def _save_examples(save_dir, file_name, examples):
75 | count = 0
76 | save_path = os.path.join(save_dir, file_name)
77 | if not examples:
78 | logger.info("Skip saving %d examples to %s." % (0, save_path))
79 | return
80 | with open(save_path, "w", encoding="utf-8") as f:
81 | for example in examples:
82 | f.write(json.dumps(example, ensure_ascii=False) + "\n")
83 | count += 1
84 | logger.info("Save %d examples to %s." % (count, save_path))
85 |
86 | if len(args.splits) == 0:
87 | if args.task_type == "ext":
88 | examples = _create_ext_examples(raw_examples, args.negative_ratio,
89 | args.prompt_prefix, args.options,
90 | args.separator, args.is_shuffle)
91 | else:
92 | examples = _create_cls_examples(raw_examples, args.prompt_prefix,
93 | args.options, args.is_shuffle)
94 | _save_examples(args.save_dir, "train.txt", examples)
95 | else:
96 |         index_list = list(range(len(raw_examples)))  # defined even when shuffling is disabled
97 |         if args.is_shuffle:
98 |             index_list = np.random.permutation(len(raw_examples)).tolist()
99 |             raw_examples = [raw_examples[i] for i in index_list]
100 |
101 | i1, i2, _ = args.splits
102 | p1 = int(len(raw_examples) * i1)
103 | p2 = int(len(raw_examples) * (i1 + i2))
104 |
105 | train_ids = index_list[:p1]
106 | dev_ids = index_list[p1:p2]
107 | test_ids = index_list[p2:]
108 |
109 | with open(os.path.join(args.save_dir, "sample_index.json"), "w") as fp:
110 | maps = {
111 | "train_ids": train_ids,
112 | "dev_ids": dev_ids,
113 | "test_ids": test_ids
114 | }
115 | fp.write(json.dumps(maps))
116 |
117 | if args.task_type == "ext":
118 | train_examples = _create_ext_examples(raw_examples[:p1],
119 | args.negative_ratio,
120 | args.prompt_prefix,
121 | args.options, args.separator,
122 | args.is_shuffle)
123 | dev_examples = _create_ext_examples(raw_examples[p1:p2],
124 | -1,
125 | args.prompt_prefix,
126 | args.options,
127 | args.separator,
128 | is_train=False)
129 | test_examples = _create_ext_examples(raw_examples[p2:],
130 | -1,
131 | args.prompt_prefix,
132 | args.options,
133 | args.separator,
134 | is_train=False)
135 | else:
136 | train_examples = _create_cls_examples(raw_examples[:p1],
137 | args.prompt_prefix,
138 | args.options)
139 | dev_examples = _create_cls_examples(raw_examples[p1:p2],
140 | args.prompt_prefix,
141 | args.options)
142 | test_examples = _create_cls_examples(raw_examples[p2:],
143 | args.prompt_prefix,
144 | args.options)
145 |
146 | _save_examples(args.save_dir, "train.txt", train_examples)
147 | _save_examples(args.save_dir, "dev.txt", dev_examples)
148 | _save_examples(args.save_dir, "test.txt", test_examples)
149 |
150 | logger.info('Finished! It takes %.2f seconds' % (time.time() - tic_time))
151 |
152 |
153 | if __name__ == "__main__":
154 | # yapf: disable
155 | parser = argparse.ArgumentParser()
156 |
157 | parser.add_argument("-d", "--doccano_file", default="./data/doccano.json",
158 | type=str, help="The doccano file exported from doccano platform.")
159 | parser.add_argument("-s", "--save_dir", default="./data",
160 | type=str, help="The path of data that you wanna save.")
161 | parser.add_argument("--negative_ratio", default=5, type=int,
162 |                         help="Used only for the extraction task, the ratio of positive and negative samples; number of negative samples = negative_ratio * number of positive samples")
163 | parser.add_argument("--splits", default=[0.8, 0.1, 0.1], type=float, nargs="*",
164 | help="The ratio of samples in datasets. [0.6, 0.2, 0.2] means 60%% samples used for training, 20%% for evaluation and 20%% for test.")
165 | parser.add_argument("--task_type", choices=['ext', 'cls'], default="ext", type=str,
166 | help="Select task type, ext for the extraction task and cls for the classification task, defaults to ext.")
167 | parser.add_argument("--options", default=["正向", "负向"], type=str, nargs="+",
168 | help="Used only for the classification task, the options for classification")
169 | parser.add_argument("--prompt_prefix", default="情感倾向", type=str,
170 | help="Used only for the classification task, the prompt prefix for classification")
171 | parser.add_argument("--is_shuffle", default=True, type=bool,
172 | help="Whether to shuffle the labeled dataset, defaults to True.")
173 | parser.add_argument("--seed", type=int, default=1000,
174 | help="Random seed for initialization")
175 | parser.add_argument("--separator", type=str, default='##',
176 | help="Used only for entity/aspect-level classification task, separator for entity label and classification label")
177 |
178 | args = parser.parse_args()
179 | # yapf: enable
180 |
181 | do_convert()
182 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from model import UIE
16 | import argparse
17 | from functools import partial
18 |
19 | import torch
20 | from transformers import BertTokenizerFast
21 | from torch.utils.data import DataLoader
22 |
23 | from utils import IEMapDataset, SpanEvaluator, IEDataset, convert_example, get_relation_type_dict, logger, tqdm, unify_prompt_name
24 |
25 |
26 | @torch.no_grad()
27 | def evaluate(model, metric, data_loader, device='gpu', loss_fn=None, show_bar=True):
28 | """
29 | Given a dataset, it evals model and computes the metric.
30 | Args:
31 | model(obj:`torch.nn.Module`): A model to classify texts.
32 | metric(obj:`Metric`): The evaluation metric.
33 | data_loader(obj:`torch.utils.data.DataLoader`): The dataset loader which generates batches.
34 | """
35 | return_loss = False
36 | if loss_fn is not None:
37 | return_loss = True
38 | model.eval()
39 | metric.reset()
40 | loss_list = []
41 | loss_sum = 0
42 | loss_num = 0
43 | if show_bar:
44 | data_loader = tqdm(
45 | data_loader, desc="Evaluating", unit='batch')
46 | for batch in data_loader:
47 | input_ids, token_type_ids, att_mask, start_ids, end_ids = batch
48 | if device == 'gpu':
49 | input_ids = input_ids.cuda()
50 | token_type_ids = token_type_ids.cuda()
51 | att_mask = att_mask.cuda()
52 | outputs = model(input_ids=input_ids,
53 | token_type_ids=token_type_ids,
54 | attention_mask=att_mask)
55 | start_prob, end_prob = outputs[0], outputs[1]
56 | if device == 'gpu':
57 | start_prob, end_prob = start_prob.cpu(), end_prob.cpu()
58 | start_ids = start_ids.type(torch.float32)
59 | end_ids = end_ids.type(torch.float32)
60 |
61 | if return_loss:
62 | # Calculate loss
63 | loss_start = loss_fn(start_prob, start_ids)
64 | loss_end = loss_fn(end_prob, end_ids)
65 | loss = (loss_start + loss_end) / 2.0
66 | loss = float(loss)
67 | loss_list.append(loss)
68 | loss_sum += loss
69 | loss_num += 1
70 | if show_bar:
71 | data_loader.set_postfix(
72 | {
73 | 'dev loss': f'{loss_sum / loss_num:.5f}'
74 | }
75 | )
76 |
77 |             # Calculate metric
78 | num_correct, num_infer, num_label = metric.compute(start_prob, end_prob,
79 | start_ids, end_ids)
80 | metric.update(num_correct, num_infer, num_label)
81 | precision, recall, f1 = metric.accumulate()
82 | model.train()
83 | if return_loss:
84 | loss_avg = sum(loss_list) / len(loss_list)
85 | return loss_avg, precision, recall, f1
86 | else:
87 | return precision, recall, f1
88 |
89 |
90 | def do_eval():
91 | tokenizer = BertTokenizerFast.from_pretrained(args.model_path)
92 | model = UIE.from_pretrained(args.model_path)
93 | if args.device == 'gpu':
94 | model = model.cuda()
95 |
96 | test_ds = IEDataset(args.test_path, tokenizer=tokenizer,
97 | max_seq_len=args.max_seq_len)
98 |
99 | test_data_loader = DataLoader(
100 | test_ds, batch_size=args.batch_size, shuffle=False)
101 | class_dict = {}
102 | relation_data = []
103 | if args.debug:
104 | for data in test_ds.dataset:
105 | class_name = unify_prompt_name(data['prompt'])
106 | # Only positive examples are evaluated in debug mode
107 | if len(data['result_list']) != 0:
108 | if "的" not in data['prompt']:
109 | class_dict.setdefault(class_name, []).append(data)
110 | else:
111 | relation_data.append((data['prompt'], data))
112 | relation_type_dict = get_relation_type_dict(relation_data)
113 | else:
114 | class_dict["all_classes"] = test_ds
115 |
116 | for key in class_dict.keys():
117 | if args.debug:
118 | test_ds = IEMapDataset(class_dict[key], tokenizer=tokenizer,
119 | max_seq_len=args.max_seq_len)
120 | else:
121 | test_ds = class_dict[key]
122 |
123 | test_data_loader = DataLoader(
124 | test_ds, batch_size=args.batch_size, shuffle=False)
125 | metric = SpanEvaluator()
126 | precision, recall, f1 = evaluate(
127 | model, metric, test_data_loader, args.device)
128 | logger.info("-----------------------------")
129 | logger.info("Class Name: %s" % key)
130 | logger.info("Evaluation Precision: %.5f | Recall: %.5f | F1: %.5f" %
131 | (precision, recall, f1))
132 |
133 | if args.debug and len(relation_type_dict.keys()) != 0:
134 | for key in relation_type_dict.keys():
135 | test_ds = IEMapDataset(relation_type_dict[key], tokenizer=tokenizer,
136 | max_seq_len=args.max_seq_len)
137 |
138 | test_data_loader = DataLoader(
139 | test_ds, batch_size=args.batch_size, shuffle=False)
140 | metric = SpanEvaluator()
141 | precision, recall, f1 = evaluate(
142 | model, metric, test_data_loader, args.device)
143 | logger.info("-----------------------------")
144 | logger.info("Class Name: X的%s" % key)
145 | logger.info("Evaluation Precision: %.5f | Recall: %.5f | F1: %.5f" %
146 | (precision, recall, f1))
147 |
148 |
149 | if __name__ == "__main__":
150 | # yapf: disable
151 | parser = argparse.ArgumentParser()
152 |
153 | parser.add_argument("-m", "--model_path", type=str, required=True,
154 | help="The path of saved model that you want to load.")
155 | parser.add_argument("-t", "--test_path", type=str, required=True,
156 | help="The path of test set.")
157 | parser.add_argument("-b", "--batch_size", type=int, default=16,
158 | help="Batch size per GPU/CPU for training.")
159 | parser.add_argument("--max_seq_len", type=int, default=512,
160 | help="The maximum total input sequence length after tokenization.")
161 | parser.add_argument("-D", '--device', choices=['cpu', 'gpu'], default="gpu",
162 | help="Select which device to run model, defaults to gpu.")
163 | parser.add_argument("--debug", action='store_true',
164 | help="Precision, recall and F1 score are calculated for each class separately if this option is enabled.")
165 |
166 | args = parser.parse_args()
167 | # yapf: enable
168 |
169 | do_eval()
170 |
--------------------------------------------------------------------------------
/export_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | import os
17 | from itertools import chain
18 | from typing import List, Union
19 | import shutil
20 | from pathlib import Path
21 |
22 | import numpy as np
23 | import torch
24 | from transformers import (BertTokenizer, PreTrainedModel,
25 | PreTrainedTokenizerBase)
26 |
27 | from model import UIE
28 | from utils import logger
29 |
30 |
31 | def validate_onnx(tokenizer: PreTrainedTokenizerBase, pt_model: PreTrainedModel, onnx_path: Union[Path, str], strict: bool = True, atol: float = 1e-05):
32 |
33 |     # Validate the exported ONNX model against the PyTorch reference model
34 | from onnxruntime import InferenceSession, SessionOptions
35 | from transformers import AutoTokenizer
36 |
37 | logger.info("Validating ONNX model...")
38 | if strict:
39 | ref_inputs = tokenizer('装备', "印媒所称的“印度第一艘国产航母”—“维克兰特”号",
40 | add_special_tokens=True,
41 | truncation=True,
42 | max_length=512,
43 | return_tensors="pt")
44 | else:
45 | batch_size = 2
46 | seq_length = 6
47 | dummy_input = [" ".join([tokenizer.unk_token])
48 | * seq_length] * batch_size
49 | ref_inputs = dict(tokenizer(dummy_input, return_tensors="pt"))
50 | # ref_inputs =
51 | ref_outputs = pt_model(**ref_inputs)
52 | ref_outputs_dict = {}
53 |
54 | # We flatten potential collection of outputs (i.e. past_keys) to a flat structure
55 | for name, value in ref_outputs.items():
56 | # Overwriting the output name as "present" since it is the name used for the ONNX outputs
57 | # ("past_key_values" being taken for the ONNX inputs)
58 | if name == "past_key_values":
59 | name = "present"
60 | ref_outputs_dict[name] = value
61 |
62 | # Create ONNX Runtime session
63 | options = SessionOptions()
64 | session = InferenceSession(str(onnx_path), options, providers=[
65 | "CPUExecutionProvider"])
66 |
67 | # We flatten potential collection of inputs (i.e. past_keys)
68 | onnx_inputs = {}
69 | for name, value in ref_inputs.items():
70 | onnx_inputs[name] = value.numpy()
71 | onnx_named_outputs = ['start_prob', 'end_prob']
72 | # Compute outputs from the ONNX model
73 | onnx_outputs = session.run(onnx_named_outputs, onnx_inputs)
74 |
75 | # Check we have a subset of the keys into onnx_outputs against ref_outputs
76 | ref_outputs_set, onnx_outputs_set = set(
77 | ref_outputs_dict.keys()), set(onnx_named_outputs)
78 | if not onnx_outputs_set.issubset(ref_outputs_set):
79 | logger.info(
80 | f"\t-[x] ONNX model output names {onnx_outputs_set} do not match reference model {ref_outputs_set}"
81 | )
82 |
83 | raise ValueError(
84 |             "Outputs don't match between the reference model and the exported ONNX model: "
85 | f"{onnx_outputs_set.difference(ref_outputs_set)}"
86 | )
87 | else:
88 | logger.info(
89 | f"\t-[✓] ONNX model output names match reference model ({onnx_outputs_set})")
90 |
91 | # Check the shape and values match
92 | for name, ort_value in zip(onnx_named_outputs, onnx_outputs):
93 | ref_value = ref_outputs_dict[name].detach().numpy()
94 |
95 | logger.info(f'\t- Validating ONNX Model output "{name}":')
96 |
97 | # Shape
98 | if not ort_value.shape == ref_value.shape:
99 | logger.info(
100 | f"\t\t-[x] shape {ort_value.shape} doesn't match {ref_value.shape}")
101 | raise ValueError(
102 |                 "Output shapes don't match between the reference model and the exported ONNX model: "
103 | f"Got {ref_value.shape} (reference) and {ort_value.shape} (ONNX)"
104 | )
105 | else:
106 | logger.info(
107 | f"\t\t-[✓] {ort_value.shape} matches {ref_value.shape}")
108 |
109 | # Values
110 | if not np.allclose(ref_value, ort_value, atol=atol):
111 | logger.info(f"\t\t-[x] values not close enough (atol: {atol})")
112 | raise ValueError(
113 |                 "Output values don't match between the reference model and the exported ONNX model: "
114 | f"Got max absolute difference of: {np.amax(np.abs(ref_value - ort_value))}"
115 | )
116 | else:
117 | logger.info(f"\t\t-[✓] all values close (atol: {atol})")
118 |
119 |
120 | def export_onnx(output_path: Union[Path, str], tokenizer: PreTrainedTokenizerBase, model: PreTrainedModel, device: torch.device, input_names: List[str], output_names: List[str]):
121 | with torch.no_grad():
122 | model = model.to(device)
123 | model.eval()
124 | model.config.return_dict = True
125 | model.config.use_cache = False
126 |
127 | output_path = Path(output_path)
128 |
129 | # Create folder
130 | if not output_path.exists():
131 | output_path.mkdir(parents=True)
132 | save_path = output_path / "inference.onnx"
133 |
134 | dynamic_axes = {name: {0: 'batch', 1: 'sequence'}
135 | for name in chain(input_names, output_names)}
136 |
137 | # Generate dummy input
138 | batch_size = 2
139 | seq_length = 6
140 | dummy_input = [" ".join([tokenizer.unk_token])
141 | * seq_length] * batch_size
142 | inputs = dict(tokenizer(dummy_input, return_tensors="pt"))
143 |
144 | if save_path.exists():
145 | logger.warning(f'Overwrite model {save_path.as_posix()}')
146 | save_path.unlink()
147 |
148 | torch.onnx.export(model,
149 | (inputs,),
150 | save_path,
151 | input_names=input_names,
152 | output_names=output_names,
153 | dynamic_axes=dynamic_axes,
154 | do_constant_folding=True,
155 | opset_version=11
156 | )
157 |
158 | if not os.path.exists(save_path):
159 | logger.error(f'Export Failed!')
160 |
161 | return save_path
162 |
163 |
164 | def main():
165 | parser = argparse.ArgumentParser()
166 | parser.add_argument("-m", "--model_path", type=Path, required=True,
167 | default='./checkpoint/model_best', help="The path to model parameters to be loaded.")
168 | parser.add_argument("-o", "--output_path", type=Path, default=None,
169 | help="The path of model parameter in static graph to be saved.")
170 | args = parser.parse_args()
171 |
172 | if args.output_path is None:
173 | args.output_path = args.model_path
174 |
175 | tokenizer = BertTokenizer.from_pretrained(args.model_path)
176 | model = UIE.from_pretrained(args.model_path)
177 | device = torch.device('cpu')
178 | input_names = [
179 | 'input_ids',
180 | 'token_type_ids',
181 | 'attention_mask',
182 | ]
183 | output_names = [
184 | 'start_prob',
185 | 'end_prob'
186 | ]
187 |
188 | logger.info("Export Tokenizer Config...")
189 |
190 | export_tokenizer(args)
191 |
192 | logger.info("Export ONNX Model...")
193 |
194 | save_path = export_onnx(
195 | args.output_path, tokenizer, model, device, input_names, output_names)
196 | validate_onnx(tokenizer, model, save_path)
197 |
198 | logger.info(f"All good, model saved at: {save_path.as_posix()}")
199 |
200 |
201 | def export_tokenizer(args):
202 |     for tokenizer_file in ['tokenizer_config.json', 'special_tokens_map.json', 'vocab.txt']:
203 |         file_from = args.model_path / tokenizer_file
204 |         file_to = args.output_path / tokenizer_file
205 | if file_from.resolve() == file_to.resolve():
206 | continue
207 | shutil.copyfile(file_from, file_to)
208 |
209 |
210 | if __name__ == "__main__":
211 |
212 | main()
213 |
--------------------------------------------------------------------------------
/finetune.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | import shutil
17 | import sys
18 | import time
19 | import os
20 | import torch
21 | from torch.utils.data import DataLoader
22 | from transformers import BertTokenizerFast
23 |
24 | from utils import IEDataset, logger, tqdm
25 | from model import UIE
26 | from evaluate import evaluate
27 | from utils import set_seed, SpanEvaluator, EarlyStopping, logging_redirect_tqdm
28 |
29 |
30 | def do_train():
31 |
32 | set_seed(args.seed)
33 | show_bar = True
34 |
35 | tokenizer = BertTokenizerFast.from_pretrained(args.model)
36 | model = UIE.from_pretrained(args.model)
37 | if args.device == 'gpu':
38 | model = model.cuda()
39 | train_ds = IEDataset(args.train_path, tokenizer=tokenizer,
40 | max_seq_len=args.max_seq_len)
41 | dev_ds = IEDataset(args.dev_path, tokenizer=tokenizer,
42 | max_seq_len=args.max_seq_len)
43 |
44 | train_data_loader = DataLoader(
45 | train_ds, batch_size=args.batch_size, shuffle=True)
46 | dev_data_loader = DataLoader(
47 | dev_ds, batch_size=args.batch_size, shuffle=True)
48 |
49 | optimizer = torch.optim.AdamW(
50 | lr=args.learning_rate, params=model.parameters())
51 |
52 | criterion = torch.nn.functional.binary_cross_entropy
53 | metric = SpanEvaluator()
54 |
55 | if args.early_stopping:
56 | early_stopping_save_dir = os.path.join(
57 | args.save_dir, "early_stopping")
58 | if not os.path.exists(early_stopping_save_dir):
59 | os.makedirs(early_stopping_save_dir)
60 | if show_bar:
61 | def trace_func(*args, **kwargs):
62 | with logging_redirect_tqdm([logger.logger]):
63 | logger.info(*args, **kwargs)
64 | else:
65 | trace_func = logger.info
66 | early_stopping = EarlyStopping(
67 | patience=7, verbose=True, trace_func=trace_func,
68 | save_dir=early_stopping_save_dir)
69 |
70 | loss_list = []
71 | loss_sum = 0
72 | loss_num = 0
73 | global_step = 0
74 | best_step = 0
75 | best_f1 = 0
76 | tic_train = time.time()
77 | epoch_iterator = range(1, args.num_epochs + 1)
78 | if show_bar:
79 | train_postfix_info = {'loss': 'unknown'}
80 | epoch_iterator = tqdm(
81 | epoch_iterator, desc='Training', unit='epoch')
82 | for epoch in epoch_iterator:
83 | train_data_iterator = train_data_loader
84 | if show_bar:
85 | train_data_iterator = tqdm(train_data_iterator,
86 | desc=f'Training Epoch {epoch}', unit='batch')
87 | train_data_iterator.set_postfix(train_postfix_info)
88 | for batch in train_data_iterator:
89 | if show_bar:
90 | epoch_iterator.refresh()
91 | input_ids, token_type_ids, att_mask, start_ids, end_ids = batch
92 | if args.device == 'gpu':
93 | input_ids = input_ids.cuda()
94 | token_type_ids = token_type_ids.cuda()
95 | att_mask = att_mask.cuda()
96 | start_ids = start_ids.cuda()
97 | end_ids = end_ids.cuda()
98 | outputs = model(input_ids=input_ids,
99 | token_type_ids=token_type_ids,
100 | attention_mask=att_mask)
101 | start_prob, end_prob = outputs[0], outputs[1]
102 |
103 | start_ids = start_ids.type(torch.float32)
104 | end_ids = end_ids.type(torch.float32)
105 | loss_start = criterion(start_prob, start_ids)
106 | loss_end = criterion(end_prob, end_ids)
107 | loss = (loss_start + loss_end) / 2.0
108 | loss.backward()
109 | optimizer.step()
110 | optimizer.zero_grad()
111 | loss_list.append(float(loss))
112 | loss_sum += float(loss)
113 | loss_num += 1
114 |
115 | if show_bar:
116 | loss_avg = loss_sum / loss_num
117 | train_postfix_info.update({
118 | 'loss': f'{loss_avg:.5f}'
119 | })
120 | train_data_iterator.set_postfix(train_postfix_info)
121 |
122 | global_step += 1
123 | if global_step % args.logging_steps == 0:
124 | time_diff = time.time() - tic_train
125 | loss_avg = loss_sum / loss_num
126 |
127 | if show_bar:
128 | with logging_redirect_tqdm([logger.logger]):
129 | logger.info(
130 | "global step %d, epoch: %d, loss: %.5f, speed: %.2f step/s"
131 | % (global_step, epoch, loss_avg,
132 | args.logging_steps / time_diff))
133 | else:
134 | logger.info(
135 | "global step %d, epoch: %d, loss: %.5f, speed: %.2f step/s"
136 | % (global_step, epoch, loss_avg,
137 | args.logging_steps / time_diff))
138 | tic_train = time.time()
139 |
140 | if global_step % args.valid_steps == 0:
141 | save_dir = os.path.join(
142 | args.save_dir, "model_%d" % global_step)
143 | if not os.path.exists(save_dir):
144 | os.makedirs(save_dir)
145 | model_to_save = model
146 | model_to_save.save_pretrained(save_dir)
147 | tokenizer.save_pretrained(save_dir)
148 | if args.max_model_num:
149 | model_to_delete = global_step-args.max_model_num*args.valid_steps
150 | model_to_delete_path = os.path.join(
151 | args.save_dir, "model_%d" % model_to_delete)
152 | if model_to_delete > 0 and os.path.exists(model_to_delete_path):
153 | shutil.rmtree(model_to_delete_path)
154 |
155 | dev_loss_avg, precision, recall, f1 = evaluate(
156 | model, metric, data_loader=dev_data_loader, device=args.device, loss_fn=criterion)
157 |
158 | if show_bar:
159 | train_postfix_info.update({
160 | 'F1': f'{f1:.3f}',
161 | 'dev loss': f'{dev_loss_avg:.5f}'
162 | })
163 | train_data_iterator.set_postfix(train_postfix_info)
164 | with logging_redirect_tqdm([logger.logger]):
165 | logger.info("Evaluation precision: %.5f, recall: %.5f, F1: %.5f, dev loss: %.5f"
166 | % (precision, recall, f1, dev_loss_avg))
167 | else:
168 | logger.info("Evaluation precision: %.5f, recall: %.5f, F1: %.5f, dev loss: %.5f"
169 | % (precision, recall, f1, dev_loss_avg))
170 | # Save model which has best F1
171 | if f1 > best_f1:
172 | if show_bar:
173 | with logging_redirect_tqdm([logger.logger]):
174 | logger.info(
175 |                                 f"best F1 performance has been updated: {best_f1:.5f} --> {f1:.5f}"
176 | )
177 | else:
178 | logger.info(
179 |                             f"best F1 performance has been updated: {best_f1:.5f} --> {f1:.5f}"
180 | )
181 | best_f1 = f1
182 | save_dir = os.path.join(args.save_dir, "model_best")
183 | model_to_save = model
184 | model_to_save.save_pretrained(save_dir)
185 | tokenizer.save_pretrained(save_dir)
186 | tic_train = time.time()
187 |
188 | if args.early_stopping:
189 | dev_loss_avg, precision, recall, f1 = evaluate(
190 | model, metric, data_loader=dev_data_loader, device=args.device, loss_fn=criterion)
191 |
192 | if show_bar:
193 | train_postfix_info.update({
194 | 'F1': f'{f1:.3f}',
195 | 'dev loss': f'{dev_loss_avg:.5f}'
196 | })
197 | train_data_iterator.set_postfix(train_postfix_info)
198 | with logging_redirect_tqdm([logger.logger]):
199 | logger.info("Evaluation precision: %.5f, recall: %.5f, F1: %.5f, dev loss: %.5f"
200 | % (precision, recall, f1, dev_loss_avg))
201 | else:
202 | logger.info("Evaluation precision: %.5f, recall: %.5f, F1: %.5f, dev loss: %.5f"
203 | % (precision, recall, f1, dev_loss_avg))
204 |
205 | # Early Stopping
206 | early_stopping(dev_loss_avg, model)
207 | if early_stopping.early_stop:
208 | if show_bar:
209 | with logging_redirect_tqdm([logger.logger]):
210 | logger.info("Early stopping")
211 | else:
212 | logger.info("Early stopping")
213 | tokenizer.save_pretrained(early_stopping_save_dir)
214 | sys.exit(0)
215 |
216 |
217 | if __name__ == "__main__":
218 | # yapf: disable
219 | parser = argparse.ArgumentParser()
220 |
221 | parser.add_argument("-b", "--batch_size", default=16, type=int,
222 | help="Batch size per GPU/CPU for training.")
223 | parser.add_argument("--learning_rate", default=1e-5,
224 | type=float, help="The initial learning rate for Adam.")
225 | parser.add_argument("-t", "--train_path", default=None, required=True,
226 | type=str, help="The path of train set.")
227 | parser.add_argument("-d", "--dev_path", default=None, required=True,
228 | type=str, help="The path of dev set.")
229 | parser.add_argument("-s", "--save_dir", default='./checkpoint', type=str,
230 | help="The output directory where the model checkpoints will be written.")
231 | parser.add_argument("--max_seq_len", default=512, type=int, help="The maximum input sequence length. "
232 | "Sequences longer than this will be split automatically.")
233 | parser.add_argument("--num_epochs", default=100, type=int,
234 | help="Total number of training epochs to perform.")
235 | parser.add_argument("--seed", default=1000, type=int,
236 | help="Random seed for initialization")
237 | parser.add_argument("--logging_steps", default=10,
238 | type=int, help="The interval steps to logging.")
239 | parser.add_argument("--valid_steps", default=100, type=int,
240 | help="The interval steps to evaluate model performance.")
241 | parser.add_argument("-D", '--device', choices=['cpu', 'gpu'], default="gpu",
242 | help="Select which device to train model, defaults to gpu.")
243 | parser.add_argument("-m", "--model", default="uie_base_pytorch", type=str,
244 | help="Select the pretrained model for few-shot learning.")
245 | parser.add_argument("--max_model_num", default=5, type=int,
246 |                         help="Maximum number of saved checkpoints. The best model and the early-stopping model are not included.")
247 | parser.add_argument("--early_stopping", action='store_true', default=False,
248 | help="Use early stopping while training")
249 |
250 | args = parser.parse_args()
251 | # yapf: enable
252 |
253 | do_train()
254 |
--------------------------------------------------------------------------------
/labelstudio2doccano.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | import os
17 | import json
18 |
19 |
20 | def append_attrs(data, item, label_id, relation_id):
21 |
22 | mapp = {}
23 |
24 | for anno in data["annotations"][0]["result"]:
25 | if anno["type"] == "labels":
26 | label_id += 1
27 | item["entities"].append({
28 | "id": label_id,
29 | "label": anno["value"]["labels"][0],
30 | "start_offset": anno["value"]["start"],
31 | "end_offset": anno["value"]["end"]
32 | })
33 | mapp[anno["id"]] = label_id
34 |
35 | for anno in data["annotations"][0]["result"]:
36 | if anno["type"] == "relation":
37 | relation_id += 1
38 | item["relations"].append({
39 | "id": relation_id,
40 | "from_id": mapp[anno["from_id"]],
41 | "to_id": mapp[anno["to_id"]],
42 | "type": anno["labels"][0]
43 | })
44 |
45 | return item, label_id, relation_id
46 |
47 |
48 | def convert(dataset, task_type):
49 | results = []
50 | outer_id = 0
51 | if task_type == "ext":
52 | label_id = 0
53 | relation_id = 0
54 | for data in dataset:
55 | outer_id += 1
56 | item = {
57 | "id": outer_id,
58 | "text": data["data"]["text"],
59 | "entities": [],
60 | "relations": []
61 | }
62 | item, label_id, relation_id = append_attrs(data, item, label_id,
63 | relation_id)
64 | results.append(item)
65 | # for the classification task
66 | else:
67 | for data in dataset:
68 | outer_id += 1
69 | results.append({
70 | "id":
71 | outer_id,
72 | "text":
73 | data["data"]["text"],
74 | "label":
75 | data["annotations"][0]["result"][0]["value"]["choices"]
76 | })
77 | return results
78 |
79 |
80 | def do_convert(args):
81 |
82 | if not os.path.exists(args.labelstudio_file):
83 | raise ValueError("Please input the correct path of label studio file.")
84 |
85 | with open(args.labelstudio_file, "r", encoding="utf-8") as infile:
86 | for content in infile:
87 | dataset = json.loads(content)
88 | results = convert(dataset, args.task_type)
89 |
90 | with open(args.doccano_file, "w", encoding="utf-8") as outfile:
91 | for item in results:
92 | outline = json.dumps(item, ensure_ascii=False)
93 | outfile.write(outline + "\n")
94 |
95 |
96 | if __name__ == "__main__":
97 |
98 | parser = argparse.ArgumentParser()
99 |
100 | parser.add_argument(
101 | '--labelstudio_file',
102 | type=str,
103 | help=
104 | 'The export file path of label studio, only support the JSON format.')
105 | parser.add_argument('--doccano_file',
106 | type=str,
107 | default='doccano_ext.jsonl',
108 | help='Saving path in doccano format.')
109 | parser.add_argument(
110 | '--task_type',
111 | type=str,
112 | choices=['ext', 'cls'],
113 | default='ext',
114 | help=
115 | 'Select task type, ext for the extraction task and cls for the classification task, defaults to ext.'
116 | )
117 |
118 | args = parser.parse_args()
119 |
120 | do_convert(args)
121 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.nn as nn
17 | from dataclasses import dataclass
18 | from transformers import PretrainedConfig
19 | from transformers.utils import ModelOutput
20 | from typing import Optional, Tuple
21 |
22 | from ernie import ErnieModel, ErniePreTrainedModel
23 | from ernie_m import ErnieMModel, ErnieMPreTrainedModel
24 |
25 |
26 | @dataclass
27 | class UIEModelOutput(ModelOutput):
28 | """
29 | Output class for outputs of UIE.
30 | Args:
31 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
32 | Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
33 | start_prob (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
34 | Span-start scores (after Sigmoid).
35 | end_prob (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
36 | Span-end scores (after Sigmoid).
37 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
38 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
39 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
40 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
41 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
42 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
43 | sequence_length)`.
44 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
45 | heads.
46 | """
47 | loss: Optional[torch.FloatTensor] = None
48 | start_prob: torch.FloatTensor = None
49 | end_prob: torch.FloatTensor = None
50 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None
51 | attentions: Optional[Tuple[torch.FloatTensor]] = None
52 |
53 |
54 | class UIE(ErniePreTrainedModel):
55 | """
56 |     UIE model based on the ERNIE encoder.
57 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
58 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
59 | etc.)
60 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
61 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
62 | and behavior.
63 | Parameters:
64 | config ([`PretrainedConfig`]): Model configuration class with all the parameters of the model.
65 | Initializing with a config file does not load the weights associated with the model, only the
66 | configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
67 | """
68 |
69 | def __init__(self, config: PretrainedConfig):
70 | super(UIE, self).__init__(config)
71 | self.encoder = ErnieModel(config)
72 | self.config = config
73 | hidden_size = self.config.hidden_size
74 |
75 | self.linear_start = nn.Linear(hidden_size, 1)
76 | self.linear_end = nn.Linear(hidden_size, 1)
77 | self.sigmoid = nn.Sigmoid()
78 |
79 | # if hasattr(config, 'use_task_id') and config.use_task_id:
80 | # # Add task type embedding to BERT
81 | # task_type_embeddings = nn.Embedding(
82 | # config.task_type_vocab_size, config.hidden_size)
83 | # self.encoder.embeddings.task_type_embeddings = task_type_embeddings
84 |
85 | # def hook(module, input, output):
86 | # input = input[0]
87 | # return output+task_type_embeddings(torch.zeros(input.size(), dtype=torch.int64, device=input.device))
88 | # self.encoder.embeddings.word_embeddings.register_forward_hook(hook)
89 |
90 | self.post_init()
91 |
92 | def forward(self, input_ids: Optional[torch.Tensor] = None,
93 | token_type_ids: Optional[torch.Tensor] = None,
94 | position_ids: Optional[torch.Tensor] = None,
95 | attention_mask: Optional[torch.Tensor] = None,
96 | head_mask: Optional[torch.Tensor] = None,
97 | inputs_embeds: Optional[torch.Tensor] = None,
98 | start_positions: Optional[torch.Tensor] = None,
99 | end_positions: Optional[torch.Tensor] = None,
100 | output_attentions: Optional[bool] = None,
101 | output_hidden_states: Optional[bool] = None,
102 | return_dict: Optional[bool] = None
103 | ):
104 | """
105 | Args:
106 | input_ids (`torch.LongTensor` of shape `({0})`):
107 | Indices of input sequence tokens in the vocabulary.
108 | Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and
109 | [`PreTrainedTokenizer.__call__`] for details.
110 | [What are input IDs?](../glossary#input-ids)
111 | attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
112 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
113 | - 1 for tokens that are **not masked**,
114 | - 0 for tokens that are **masked**.
115 | [What are attention masks?](../glossary#attention-mask)
116 | token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
117 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
118 | 1]`:
119 | - 0 corresponds to a *sentence A* token,
120 | - 1 corresponds to a *sentence B* token.
121 | [What are token type IDs?](../glossary#token-type-ids)
122 | position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
123 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
124 | config.max_position_embeddings - 1]`.
125 | [What are position IDs?](../glossary#position-ids)
126 | head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
127 | Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
128 | - 1 indicates the head is **not masked**,
129 | - 0 indicates the head is **masked**.
130 | inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
131 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
132 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
133 | model's internal embedding lookup matrix.
134 | start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
135 | Labels for position (index) of the start of the labelled span for computing the token classification loss.
136 | Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
137 | are not taken into account for computing the loss.
138 | end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
139 | Labels for position (index) of the end of the labelled span for computing the token classification loss.
140 |             Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
141 | are not taken into account for computing the loss.
142 | output_attentions (`bool`, *optional*):
143 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
144 | tensors for more detail.
145 | output_hidden_states (`bool`, *optional*):
146 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
147 | more detail.
148 | return_dict (`bool`, *optional*):
149 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
150 | """
151 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
152 |         outputs = self.encoder(
153 | input_ids=input_ids,
154 | token_type_ids=token_type_ids,
155 | position_ids=position_ids,
156 | attention_mask=attention_mask,
157 | head_mask=head_mask,
158 | inputs_embeds=inputs_embeds,
159 | output_attentions=output_attentions,
160 | output_hidden_states=output_hidden_states,
161 | return_dict=return_dict
162 | )
163 |         sequence_output = outputs[0]
164 |
165 |         start_logits = self.linear_start(sequence_output)
166 |         start_logits = torch.squeeze(start_logits, -1)
167 |         start_prob = self.sigmoid(start_logits)
168 |         end_logits = self.linear_end(sequence_output)
169 |         end_logits = torch.squeeze(end_logits, -1)
170 |         end_prob = self.sigmoid(end_logits)
171 |
172 |         total_loss = None
173 |         if start_positions is not None and end_positions is not None:
174 | loss_fct = nn.BCELoss()
175 | start_loss = loss_fct(start_prob, start_positions)
176 | end_loss = loss_fct(end_prob, end_positions)
177 | total_loss = (start_loss + end_loss) / 2.0
178 |
179 |         if not return_dict:
180 | output = (start_prob, end_prob) + outputs[2:]
181 | return ((total_loss,) + output) if total_loss is not None else output
182 |
183 |         return UIEModelOutput(
184 | loss=total_loss,
185 | start_prob=start_prob,
186 | end_prob=end_prob,
187 | hidden_states=outputs.hidden_states,
188 | attentions=outputs.attentions,
189 | )
190 |
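# A minimal, hedged usage sketch for the span-pointer head above (the checkpoint
# path, prompt and text are illustrative assumptions, not part of this file):
#
#   from transformers import BertTokenizerFast
#   tokenizer = BertTokenizerFast.from_pretrained("./uie_base_pytorch")
#   model = UIE.from_pretrained("./uie_base_pytorch")
#   encoded = tokenizer(text=["时间"], text_pair=["会议于2023年3月1日召开"], return_tensors="pt")
#   outputs = model(**encoded)
#   # outputs.start_prob / outputs.end_prob have shape (batch_size, seq_len);
#   # positions above a threshold are later paired into (start, end) spans.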
191 |
192 | class UIEM(ErnieMPreTrainedModel):
193 | """
194 |     UIE-M model based on the ERNIE-M model.
195 |     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
196 |     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
197 |     etc.)
198 |     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
199 |     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
200 | and behavior.
201 | Parameters:
202 | config ([`PretrainedConfig`]): Model configuration class with all the parameters of the model.
203 | Initializing with a config file does not load the weights associated with the model, only the
204 | configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
205 | """
206 |
207 | def __init__(self, config: PretrainedConfig):
208 | super(UIEM, self).__init__(config)
209 | self.encoder = ErnieMModel(config)
210 | self.config = config
211 | hidden_size = self.config.hidden_size
212 |
213 | self.linear_start = nn.Linear(hidden_size, 1)
214 | self.linear_end = nn.Linear(hidden_size, 1)
215 | self.sigmoid = nn.Sigmoid()
216 |
217 | self.post_init()
218 |
219 | def forward(self, input_ids: Optional[torch.Tensor] = None,
220 | position_ids: Optional[torch.Tensor] = None,
221 | attention_mask: Optional[torch.Tensor] = None,
222 | head_mask: Optional[torch.Tensor] = None,
223 | start_positions: Optional[torch.Tensor] = None,
224 | end_positions: Optional[torch.Tensor] = None,
225 | output_attentions: Optional[bool] = None,
226 | output_hidden_states: Optional[bool] = None,
227 | return_dict: Optional[bool] = None
228 | ):
229 | """
230 | Args:
231 | input_ids (`torch.LongTensor` of shape `({0})`):
232 | Indices of input sequence tokens in the vocabulary.
233 |             Indices can be obtained using [`ErnieMTokenizer`]. See [`PreTrainedTokenizer.encode`] and
234 | [`PreTrainedTokenizer.__call__`] for details.
235 | [What are input IDs?](../glossary#input-ids)
236 | attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
237 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
238 | - 1 for tokens that are **not masked**,
239 | - 0 for tokens that are **masked**.
240 | [What are attention masks?](../glossary#attention-mask)
241 | position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
242 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
243 | config.max_position_embeddings - 1]`.
244 | [What are position IDs?](../glossary#position-ids)
245 | head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
246 | Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
247 | - 1 indicates the head is **not masked**,
248 | - 0 indicates the head is **masked**.
249 | inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
250 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
251 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
252 | model's internal embedding lookup matrix.
253 | start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
254 | Labels for position (index) of the start of the labelled span for computing the token classification loss.
255 |             Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
256 | are not taken into account for computing the loss.
257 | end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
258 | Labels for position (index) of the end of the labelled span for computing the token classification loss.
259 |             Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
260 | are not taken into account for computing the loss.
261 | output_attentions (`bool`, *optional*):
262 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
263 | tensors for more detail.
264 | output_hidden_states (`bool`, *optional*):
265 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
266 | more detail.
267 | return_dict (`bool`, *optional*):
268 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
269 | """
270 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
271 | outputs = self.encoder(
272 | input_ids=input_ids,
273 | position_ids=position_ids,
274 | # attention_mask=attention_mask,
275 | # head_mask=head_mask,
276 | output_attentions=output_attentions,
277 | output_hidden_states=output_hidden_states,
278 | return_dict=return_dict
279 | )
280 | sequence_output = outputs[0]
281 |
282 | start_logits = self.linear_start(sequence_output)
283 | start_logits = torch.squeeze(start_logits, -1)
284 | start_prob = self.sigmoid(start_logits)
285 | end_logits = self.linear_end(sequence_output)
286 | end_logits = torch.squeeze(end_logits, -1)
287 | end_prob = self.sigmoid(end_logits)
288 |
289 | total_loss = None
290 | if start_positions is not None and end_positions is not None:
291 | loss_fct = nn.BCELoss()
292 | start_loss = loss_fct(start_prob, start_positions)
293 | end_loss = loss_fct(end_prob, end_positions)
294 | total_loss = (start_loss + end_loss) / 2.0
295 |
296 | if not return_dict:
297 | output = (start_prob, end_prob) + outputs[2:]
298 | return ((total_loss,) + output) if total_loss is not None else output
299 |
300 | return UIEModelOutput(
301 | loss=total_loss,
302 | start_prob=start_prob,
303 | end_prob=end_prob,
304 | hidden_states=outputs.hidden_states,
305 | attentions=outputs.attentions,
306 | )
307 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | numpy >=1.22
3 | six
4 | colorlog
5 | torch >=1.10,<2.0 # CPU build
6 | transformers >=4.18,<5.0
7 | packaging
8 | tqdm
9 | sentencepiece
10 | protobuf==3.19.0
11 | onnxruntime # CPU build
12 |
13 | # faster-tokenizer==0.2.0 # faster-tokenizer is deprecated
14 | fast-tokenizer-python==1.0.0
15 |
16 | # paddlepaddle # paddlepaddle is optional
--------------------------------------------------------------------------------
/tokenizer.py:
--------------------------------------------------------------------------------
1 |
2 | import collections
3 | import json
4 | import os
5 | import unicodedata
6 | from shutil import copyfile
7 | from typing import Any, Dict, List, Optional, Tuple
8 |
9 | import sentencepiece as spm
10 | from transformers import SLOW_TO_FAST_CONVERTERS, PreTrainedTokenizerFast, requires_backends
11 | from transformers.convert_slow_tokenizer import Converter, SentencePieceExtractor
12 |
13 | try:
14 | from fast_tokenizer import Tokenizer, normalizers, pretokenizers, postprocessors
15 | from fast_tokenizer.models import BPE, Unigram
16 | except ImportError as e:
17 |     print('fast_tokenizer is not installed! pip install fast-tokenizer-python==1.0.0')
18 | print(e)
19 | from faster_tokenizer import Tokenizer, normalizers, pretokenizers, postprocessors
20 | from faster_tokenizer.models import BPE, Unigram
21 |
22 | from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
23 | from transformers.utils import SPIECE_UNDERLINE
24 |
25 | from utils import logger
26 |
27 |
28 | VOCAB_FILES_NAMES = {
29 | "sentencepiece_model_file": "sentencepiece.bpe.model",
30 | "vocab_file": "vocab.txt",
31 | }
32 |
33 | PRETRAINED_VOCAB_FILES_MAP = {
34 | "vocab_file": {
35 | "ernie-m-base": "https://bj.bcebos.com/paddlenlp/models/transformers/ernie_m/ernie_m.vocab.txt",
36 | "ernie-m-large": "https://bj.bcebos.com/paddlenlp/models/transformers/ernie_m/ernie_m.vocab.txt"
37 | },
38 | "sentencepiece_model_file": {
39 | "ernie-m-base": "https://bj.bcebos.com/paddlenlp/models/transformers/ernie_m/ernie_m.sentencepiece.bpe.model",
40 | "ernie-m-large": "https://bj.bcebos.com/paddlenlp/models/transformers/ernie_m/ernie_m.sentencepiece.bpe.model",
41 | }
42 | }
43 |
44 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
45 | "ernie-m-base": 514,
46 | "ernie-m-large": 514,
47 | }
48 |
49 |
50 | def load_vocab(vocab_file):
51 | """Loads a vocabulary file into a dictionary."""
52 | vocab = collections.OrderedDict()
53 | with open(vocab_file, "r", encoding="utf-8") as reader:
54 | tokens = reader.readlines()
55 | for index, token in enumerate(tokens):
56 | token = token.rstrip("\n")
57 | if token in vocab:
58 |             print(f'{token} is duplicated!')
59 | vocab[token] = index
60 | return vocab
61 |
62 |
63 | class ErnieMTokenizer(PreTrainedTokenizer):
64 | """
65 |     Construct an ERNIE-M tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
66 |
67 | This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
68 | this superclass for more information regarding those methods.
69 |
70 | Args:
71 | vocab_file (`str`):
72 | File containing the vocabulary.
73 | sentencepiece_model_file (`str`):
74 | [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .spm extension) that
75 | contains the vocabulary necessary to instantiate a tokenizer.
76 |         do_lower_case (`bool`, *optional*, defaults to `False`):
77 | Whether to lowercase the input when tokenizing.
78 |         unk_token (`str`, *optional*, defaults to `"[UNK]"`):
79 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
80 | token instead.
81 |         sep_token (`str`, *optional*, defaults to `"[SEP]"`):
82 | The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
83 | sequence classification or for a text and a question for question answering. It is also used as the last
84 | token of a sequence built with special tokens.
85 |         pad_token (`str`, *optional*, defaults to `"[PAD]"`):
86 | The token used for padding, for example when batching sequences of different lengths.
87 |         cls_token (`str`, *optional*, defaults to `"[CLS]"`):
88 | The classifier token which is used when doing sequence classification (classification of the whole sequence
89 | instead of per-token classification). It is the first token of the sequence when built with special tokens.
90 |         mask_token (`str`, *optional*, defaults to `"[MASK]"`):
91 | The token used for masking values. This is the token used when training this model with masked language
92 | modeling. This is the token which the model will try to predict.
93 | sp_model_kwargs (`dict`, *optional*):
94 | Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
95 | SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
96 | to set:
97 |
98 | - `enable_sampling`: Enable subword regularization.
99 | - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
100 |
101 | - `nbest_size = {0,1}`: No sampling is performed.
102 | - `nbest_size > 1`: samples from the nbest_size results.
103 | - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
104 | using forward-filtering-and-backward-sampling algorithm.
105 |
106 | - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
107 | BPE-dropout.
108 |
109 | Attributes:
110 | sp_model (`SentencePieceProcessor`):
111 | The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
112 | """
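    # A hedged usage sketch (file names follow VOCAB_FILES_NAMES above; paths are
    # illustrative assumptions):
    #   tok = ErnieMTokenizer(vocab_file="vocab.txt",
    #                         sentencepiece_model_file="sentencepiece.bpe.model")
    #   tok.tokenize("ERNIE-M是多语言模型")
    #   # CJK characters become single tokens and digit runs are split off from
    #   # the surrounding text (see _tokenize below).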
113 |
114 | vocab_files_names = VOCAB_FILES_NAMES
115 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
116 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
117 | padding_side = "left"
118 |
119 | def __init__(
120 | self,
121 | vocab_file,
122 | sentencepiece_model_file,
123 | do_lower_case=False,
124 | unk_token="[UNK]",
125 | sep_token="[SEP]",
126 | pad_token="[PAD]",
127 | cls_token="[CLS]",
128 | mask_token="[MASK]",
129 | sp_model_kwargs: Optional[Dict[str, Any]] = None,
130 | **kwargs
131 | ) -> None:
132 |
133 | # Mask token behave like a normal word, i.e. include the space before it
134 | mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(
135 | mask_token, str) else mask_token
136 |
137 | self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
138 |
139 | super().__init__(
140 | do_lower_case=do_lower_case,
141 | unk_token=unk_token,
142 | sep_token=sep_token,
143 | pad_token=pad_token,
144 | cls_token=cls_token,
145 | mask_token=mask_token,
146 | sp_model_kwargs=self.sp_model_kwargs,
147 | **kwargs,
148 | )
149 |
150 | self.do_lower_case = do_lower_case
151 | self.sentencepiece_model_file = sentencepiece_model_file
152 | if not os.path.isfile(sentencepiece_model_file):
153 | raise ValueError(
154 | f"Can't find a vocabulary file at path '{sentencepiece_model_file}'. To load the vocabulary from a Google pretrained "
155 | "model use `tokenizer = ErnieMTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
156 | )
157 |
158 | self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
159 | self.sp_model.Load(sentencepiece_model_file)
160 |
161 | if not os.path.isfile(vocab_file):
162 | raise ValueError(
163 | f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained "
164 | "model use `tokenizer = ErnieMTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
165 | )
166 | self.vocab = load_vocab(vocab_file)
167 | self.ids_to_tokens = collections.OrderedDict(
168 | [(ids, tok) for tok, ids in self.vocab.items()])
169 |
170 | self.SP_CHAR_MAPPING = {}
171 |
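        # Map full-width ASCII forms (U+FF01–U+FF5E) to their half-width
        # counterparts, keeping the full-width tilde '～' as-is.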
172 | for ch in range(65281, 65375):
173 |             if ch in [ord(u'～')]:
174 | self.SP_CHAR_MAPPING[chr(ch)] = chr(ch)
175 | continue
176 | self.SP_CHAR_MAPPING[chr(ch)] = chr(ch - 65248)
177 |
178 | @property
179 | def vocab_size(self):
180 | return len(self.sp_model)
181 |
182 | def get_vocab(self):
183 | return dict(self.vocab, **self.added_tokens_encoder)
184 |
185 | def __getstate__(self):
186 | state = self.__dict__.copy()
187 | state["sp_model"] = None
188 | return state
189 |
190 | def __setstate__(self, d):
191 | self.__dict__ = d
192 |
193 | # for backward compatibility
194 | if not hasattr(self, "sp_model_kwargs"):
195 | self.sp_model_kwargs = {}
196 |
197 | self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
198 |         self.sp_model.Load(self.sentencepiece_model_file)
199 |
200 | def preprocess_text(self, inputs):
201 | outputs = ''.join((self.SP_CHAR_MAPPING.get(c, c) for c in inputs))
202 | outputs = outputs.replace("``", '"').replace("''", '"')
203 |
204 | outputs = unicodedata.normalize("NFKD", outputs)
205 | outputs = "".join(
206 | [c for c in outputs if not unicodedata.combining(c)])
207 | if self.do_lower_case:
208 | outputs = outputs.lower()
209 |
210 | return outputs
211 |
212 | def _tokenize(self, text: str) -> List[str]:
213 | """Tokenize a string."""
214 | text = self.preprocess_text(text)
215 |
216 | pieces = self.sp_model.EncodeAsPieces(text)
217 |
218 | new_pieces = []
219 | for piece in pieces:
220 | if piece == SPIECE_UNDERLINE:
221 | continue
222 | lst_i = 0
223 | for i, c in enumerate(piece):
224 | if c == SPIECE_UNDERLINE:
225 | continue
226 | if self.is_ch_char(c) or self.is_punct(c):
227 | if i > lst_i and piece[lst_i:i] != SPIECE_UNDERLINE:
228 | new_pieces.append(piece[lst_i:i])
229 | new_pieces.append(c)
230 | lst_i = i + 1
231 | elif c.isdigit() and i > 0 and not piece[i - 1].isdigit():
232 | if i > lst_i and piece[lst_i:i] != SPIECE_UNDERLINE:
233 | new_pieces.append(piece[lst_i:i])
234 | lst_i = i
235 | elif not c.isdigit() and i > 0 and piece[i - 1].isdigit():
236 | if i > lst_i and piece[lst_i:i] != SPIECE_UNDERLINE:
237 | new_pieces.append(piece[lst_i:i])
238 | lst_i = i
239 | if len(piece) > lst_i:
240 | new_pieces.append(piece[lst_i:])
241 |
242 | return new_pieces
243 |
244 | def _convert_token_to_id(self, token):
245 | """Converts a token (str) in an id using the vocab."""
246 |
247 | return self.vocab.get(token, self.vocab.get(self.unk_token))
248 |
249 | def _convert_id_to_token(self, index):
250 | """Converts an index (integer) in a token (str) using the vocab."""
251 | return self.ids_to_tokens.get(index, self.unk_token)
252 |
253 | def convert_tokens_to_string(self, tokens):
254 | """Converts a sequence of tokens (strings for sub-words) in a single string."""
255 | out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
256 | return out_string
257 |
258 | def build_inputs_with_special_tokens(
259 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
260 | ) -> List[int]:
261 | """
262 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
263 |         adding special tokens. An ERNIE-M sequence has the following format:
264 |
265 |         - single sequence: `[CLS] X [SEP]`
266 |         - pair of sequences: `[CLS] A [SEP] [SEP] B [SEP]`
267 |
268 | Args:
269 | token_ids_0 (`List[int]`):
270 | List of IDs to which the special tokens will be added.
271 | token_ids_1 (`List[int]`, *optional*):
272 | Optional second list of IDs for sequence pairs.
273 |
274 | Returns:
275 | `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
276 | """
277 | sep = [self.sep_token_id]
278 | cls = [self.cls_token_id]
279 |
280 | if token_ids_1 is None:
281 | return cls + token_ids_0 + sep
282 |
283 | return cls + token_ids_0 + sep + sep + token_ids_1 + sep
284 |
285 | def get_special_tokens_mask(
286 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
287 | ) -> List[int]:
288 | """
289 | Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
290 | special tokens using the tokenizer `prepare_for_model` method.
291 |
292 | Args:
293 | token_ids_0 (`List[int]`):
294 | List of IDs.
295 | token_ids_1 (`List[int]`, *optional*):
296 | Optional second list of IDs for sequence pairs.
297 | already_has_special_tokens (`bool`, *optional*, defaults to `False`):
298 | Whether or not the token list is already formatted with special tokens for the model.
299 |
300 | Returns:
301 | `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
302 | """
303 |
304 | if already_has_special_tokens:
305 | return super().get_special_tokens_mask(
306 | token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
307 | )
308 |
309 | if token_ids_1 is not None:
310 | return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
311 | return [1] + ([0] * len(token_ids_0)) + [1]
312 |
313 | def create_token_type_ids_from_sequences(
314 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
315 | ) -> List[int]:
316 | """
317 |         Create a mask from the two sequences passed to be used in a sequence-pair classification task. An ERNIE-M
318 | sequence pair mask has the following format:
319 |
320 | ```
321 | 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
322 | | first sequence | second sequence |
323 | ```
324 |
325 | If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
326 |
327 | Args:
328 | token_ids_0 (`List[int]`):
329 | List of IDs.
330 | token_ids_1 (`List[int]`, *optional*):
331 | Optional second list of IDs for sequence pairs.
332 |
333 | Returns:
334 | `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
335 | """
336 |
337 | if token_ids_1 is None:
338 | # [CLS] X [SEP]
339 | return (len(token_ids_0) + 2) * [0]
340 |
341 | # [CLS] A [SEP] [SEP] B [SEP]
342 | return [0] * (len(token_ids_0) + 1) + [1] * (len(token_ids_1) + 3)
343 |
344 | def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
345 | if not os.path.isdir(save_directory):
346 | logger.error(
347 | f"Vocabulary path ({save_directory}) should be a directory")
348 | return
349 | sentencepiece_model_file = os.path.join(
350 | save_directory, (filename_prefix + "-" if filename_prefix else "") +
351 | VOCAB_FILES_NAMES["sentencepiece_model_file"]
352 | )
353 |         vocab_file = os.path.join(
354 |             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"])
355 |
356 |         if os.path.abspath(self.sentencepiece_model_file) != os.path.abspath(sentencepiece_model_file) and os.path.isfile(self.sentencepiece_model_file):
357 |             copyfile(self.sentencepiece_model_file, sentencepiece_model_file)
358 |         elif not os.path.isfile(self.sentencepiece_model_file):
359 | with open(sentencepiece_model_file, "wb") as fi:
360 | content_spiece_model = self.sp_model.serialized_model_proto()
361 | fi.write(content_spiece_model)
362 |
363 | index = 0
364 | with open(vocab_file, "w", encoding="utf-8") as writer:
365 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
366 | if index != token_index:
367 | logger.warning(
368 | f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
369 | " Please check that the vocabulary is not corrupted!"
370 | )
371 | index = token_index
372 | writer.write(token + "\n")
373 | index += 1
374 | return vocab_file, sentencepiece_model_file
375 |
376 | def is_ch_char(self, char):
377 | """
378 | is_ch_char
379 | """
380 | if u'\u4e00' <= char <= u'\u9fff':
381 | return True
382 | return False
383 |
384 | def is_alpha(self, char):
385 | """
386 | is_alpha
387 | """
388 | if 'a' <= char <= 'z':
389 | return True
390 | if 'A' <= char <= 'Z':
391 | return True
392 | return False
393 |
394 | def is_punct(self, char):
395 | """
396 | is_punct
397 | """
398 |         if char in u",;:.?!~，；：。？！《》【】":
399 | return True
400 | return False
401 |
402 | def is_whitespace(self, char):
403 | """
404 | is whitespace
405 | """
406 | if char == " " or char == "\t" or char == "\n" or char == "\r":
407 | return True
408 | if len(char) == 1:
409 | cat = unicodedata.category(char)
410 | if cat == "Zs":
411 | return True
412 | return False
413 |
414 |
415 | class ErnieMTokenizerFast(PreTrainedTokenizerFast):
416 | r"""
417 |     Construct a "fast" ERNIE-M tokenizer (backed by HuggingFace's *tokenizers* library). Based on SentencePiece.
418 | This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
419 | refer to this superclass for more information regarding those methods.
420 | Args:
421 | vocab_file (`str`):
422 | File containing the vocabulary.
423 | sentencepiece_model_file (`str`):
424 | [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .spm extension) that
425 | contains the vocabulary necessary to instantiate a tokenizer.
426 | do_lower_case (`bool`, *optional*, defaults to `True`):
427 | Whether or not to lowercase the input when tokenizing.
428 | unk_token (`str`, *optional*, defaults to `"[UNK]"`):
429 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
430 | token instead.
431 | sep_token (`str`, *optional*, defaults to `"[SEP]"`):
432 | The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
433 | sequence classification or for a text and a question for question answering. It is also used as the last
434 | token of a sequence built with special tokens.
435 | pad_token (`str`, *optional*, defaults to `"[PAD]"`):
436 | The token used for padding, for example when batching sequences of different lengths.
437 | cls_token (`str`, *optional*, defaults to `"[CLS]"`):
438 | The classifier token which is used when doing sequence classification (classification of the whole sequence
439 | instead of per-token classification). It is the first token of the sequence when built with special tokens.
440 | mask_token (`str`, *optional*, defaults to `"[MASK]"`):
441 | The token used for masking values. This is the token used when training this model with masked language
442 | modeling. This is the token which the model will try to predict.
443 | clean_text (`bool`, *optional*, defaults to `True`):
444 | Whether or not to clean the text before tokenization by removing any control characters and replacing all
445 | whitespaces by the classic one.
446 | tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
447 | Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see [this
448 | issue](https://github.com/huggingface/transformers/issues/328)).
449 | strip_accents (`bool`, *optional*):
450 | Whether or not to strip all accents. If this option is not specified, then it will be determined by the
451 | value for `lowercase` (as in the original ERNIE-M).
452 | wordpieces_prefix (`str`, *optional*, defaults to `"##"`):
453 | The prefix for subwords.
454 | """
455 |
456 | vocab_files_names = VOCAB_FILES_NAMES
457 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
458 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
459 | slow_tokenizer_class = ErnieMTokenizer
460 |
461 | def __init__(
462 | self,
463 | vocab_file=None,
464 | sentencepiece_model_file=None,
465 | tokenizer_file=None,
466 | do_lower_case=True,
467 | unk_token="[UNK]",
468 | sep_token="[SEP]",
469 | pad_token="[PAD]",
470 | cls_token="[CLS]",
471 | mask_token="[MASK]",
472 | tokenize_chinese_chars=True,
473 | strip_accents=None,
474 | **kwargs
475 | ):
476 | super().__init__(
477 | vocab_file,
478 | sentencepiece_model_file,
479 | tokenizer_file=tokenizer_file,
480 | do_lower_case=do_lower_case,
481 | unk_token=unk_token,
482 | sep_token=sep_token,
483 | pad_token=pad_token,
484 | cls_token=cls_token,
485 | mask_token=mask_token,
486 | tokenize_chinese_chars=tokenize_chinese_chars,
487 | strip_accents=strip_accents,
488 | **kwargs,
489 | )
490 |
491 | normalizer_state = json.loads(
492 | self.backend_tokenizer.normalizer.__getstate__())
493 | if (
494 | normalizer_state.get("lowercase", do_lower_case) != do_lower_case
495 | or normalizer_state.get("strip_accents", strip_accents) != strip_accents
496 | or normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars
497 | ):
498 | normalizer_class = getattr(
499 | normalizers, normalizer_state.pop("type"))
500 | normalizer_state["lowercase"] = do_lower_case
501 | normalizer_state["strip_accents"] = strip_accents
502 | normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars
503 | self.backend_tokenizer.normalizer = normalizer_class(
504 | **normalizer_state)
505 |
506 | self.do_lower_case = do_lower_case
507 |
508 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
509 | """
510 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
511 | adding special tokens. A ERNIE-M sequence has the following format:
512 | - single sequence: `[CLS] X [SEP]`
513 | - pair of sequences: `[CLS] A [SEP] B [SEP]`
514 | Args:
515 | token_ids_0 (`List[int]`):
516 | List of IDs to which the special tokens will be added.
517 | token_ids_1 (`List[int]`, *optional*):
518 | Optional second list of IDs for sequence pairs.
519 | Returns:
520 | `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
521 | """
522 | output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
523 |
524 | if token_ids_1:
525 | output += [self.sep_token_id] + token_ids_1 + [self.sep_token_id]
526 |
527 | return output
528 |
529 | def create_token_type_ids_from_sequences(
530 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
531 | ) -> List[int]:
532 | """
533 | Create a mask from the two sequences passed to be used in a sequence-pair classification task. A ERNIE-M sequence
534 | pair mask has the following format:
535 | ```
536 | 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
537 | | first sequence | second sequence |
538 | ```
539 | If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
540 | Args:
541 | token_ids_0 (`List[int]`):
542 | List of IDs.
543 | token_ids_1 (`List[int]`, *optional*):
544 | Optional second list of IDs for sequence pairs.
545 | Returns:
546 | `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
547 | """
548 |
549 | if token_ids_1 is None:
550 | return (len(token_ids_0) + 2) * [0]
551 | return [0] * (len(token_ids_0) + 1) + [1] * (len(token_ids_1) + 3)
552 |
553 | def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
554 | files = self._tokenizer.model.save(
555 | save_directory, name=filename_prefix)
556 | return tuple(files)
557 |
558 |
559 | class TokenizerProxy:
560 | def __init__(self, tokenizer):
561 | self._tokenizer = tokenizer
562 | self.no_padding = self._tokenizer.disable_padding
563 | self.no_truncation = self._tokenizer.disable_truncation
564 |
565 | def __getattr__(self, __name: str) -> Any:
566 | return getattr(self._tokenizer, __name)
567 |
568 |
569 | class ErnieMConverter(Converter):
570 | def __init__(self, *args):
571 | requires_backends(self, "protobuf")
572 |
573 | super().__init__(*args)
574 |
575 | from transformers.utils import sentencepiece_model_pb2 as model_pb2
576 |
577 | m = model_pb2.ModelProto()
578 | with open(self.original_tokenizer.sentencepiece_model_file, "rb") as f:
579 | m.ParseFromString(f.read())
580 | self.proto = m
581 |
582 | def vocab(self, proto):
583 | word_score_dict = {}
584 | for piece in proto.pieces:
585 | word_score_dict[piece.piece] = piece.score
586 | vocab_list = [None] * len(self.original_tokenizer.ids_to_tokens)
587 | original_vocab = self.original_tokenizer.vocab
588 | for _token, _id in original_vocab.items():
589 | if _token in word_score_dict:
590 | vocab_list[_id] = (_token, word_score_dict[_token])
591 | else:
592 | vocab_list[_id] = (_token, 0.0)
593 | return vocab_list
594 |
595 | def post_processor(self):
596 | '''
597 | An ERNIE-M sequence has the following format:
598 | - single sequence: ``[CLS] X [SEP]``
599 | - pair of sequences: ``[CLS] A [SEP] [SEP] B [SEP]``
600 | '''
601 | return postprocessors.TemplatePostProcessor(
602 | single="[CLS]:0 $A:0 [SEP]:0",
603 | pair="[CLS]:0 $A:0 [SEP]:0 [SEP]:1 $B:1 [SEP]:1",
604 | special_tokens=[
605 | ("[CLS]",
606 | self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
607 | ("[SEP]",
608 | self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
609 | ],
610 | )
611 |
612 | def normalizer(self, proto):
613 | list_normalizers = []
614 | precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
615 | list_normalizers.append(
616 | normalizers.PrecompiledNormalizer(precompiled_charsmap))
617 | return normalizers.SequenceNormalizer(list_normalizers)
618 |
619 | def unk_id(self, proto):
620 | return self.original_tokenizer.convert_tokens_to_ids(
621 | str(self.original_tokenizer.unk_token))
622 |
623 | def pre_tokenizer(self, replacement, add_prefix_space):
624 | return pretokenizers.SequencePreTokenizer([
625 | pretokenizers.WhitespacePreTokenizer(),
626 | pretokenizers.MetaSpacePreTokenizer(
627 | replacement=replacement, add_prefix_space=add_prefix_space)
628 | ])
629 |
630 | def converted(self) -> Tokenizer:
631 | tokenizer = self.tokenizer(self.proto)
632 |
633 | SPLICE_UNDERLINE = SPIECE_UNDERLINE
634 | tokenizer.model.set_filter_token(SPLICE_UNDERLINE)
635 | chinese_chars = r"\x{4e00}-\x{9fff}"
636 |         punc_chars = r",;:.?!~，；：。？！《》【】"
637 | digits = r"0-9"
638 | tokenizer.model.set_split_rule(
639 | fr"[{chinese_chars}]|[{punc_chars}]|[{digits}]+|[^{chinese_chars}{punc_chars}{digits}]+"
640 | )
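        # The split rule mirrors ErnieMTokenizer._tokenize: single CJK characters,
        # single punctuation marks, digit runs, and everything else grouped together.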
641 |
642 | # Tokenizer assemble
643 | tokenizer.normalizer = self.normalizer(self.proto)
644 |
645 | replacement = "▁"
646 | add_prefix_space = True
647 | tokenizer.pretokenizer = self.pre_tokenizer(
648 | replacement, add_prefix_space)
649 |
650 | post_processor = self.post_processor()
651 | if post_processor:
652 | tokenizer.postprocessor = post_processor
653 |
654 | tokenizer = TokenizerProxy(tokenizer)
655 | return tokenizer
656 |
657 | def tokenizer(self, proto):
658 | model_type = proto.trainer_spec.model_type
659 | vocab = self.vocab(proto)
660 | unk_id = self.unk_id(proto)
661 |
662 | if model_type == 1:
663 | tokenizer = Tokenizer(Unigram(vocab, unk_id))
664 | elif model_type == 2:
665 | _, merges = SentencePieceExtractor(
666 | self.original_tokenizer.sentencepiece_model_file).extract()
667 | bpe_vocab = {word: i for i, (word, score) in enumerate(vocab)}
668 | tokenizer = Tokenizer(
669 | BPE(
670 | bpe_vocab,
671 | merges,
672 | unk_token=proto.trainer_spec.unk_piece,
673 | fuse_unk=True,
674 | )
675 | )
676 | else:
677 | raise Exception(
678 |                 "You're trying to run a `Unigram` model but your file was trained with a different algorithm"
679 | )
680 |
681 | return tokenizer
682 |
683 |
684 | SLOW_TO_FAST_CONVERTERS["ErnieMTokenizer"] = ErnieMConverter
685 |
--------------------------------------------------------------------------------
/uie_predictor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import re
16 | import numpy as np
17 | import six
18 |
19 |
20 | import math
21 | import argparse
22 | import os.path
23 |
24 | from utils import logger, get_bool_ids_greater_than, get_span, get_id_and_prob, cut_chinese_sent, dbc2sbc
25 |
26 |
27 | class ONNXInferBackend(object):
28 | def __init__(self,
29 | model_path_prefix,
30 | device='cpu',
31 | use_fp16=False):
32 | from onnxruntime import InferenceSession, SessionOptions
33 | logger.info(">>> [ONNXInferBackend] Creating Engine ...")
34 | onnx_model = float_onnx_file = os.path.join(
35 | model_path_prefix, "inference.onnx")
36 | if not os.path.exists(onnx_model):
37 |             raise OSError(f'{onnx_model} does not exist!')
38 | infer_model_dir = model_path_prefix
39 |
40 | if device == "gpu":
41 | providers = ['CUDAExecutionProvider']
42 | logger.info(">>> [ONNXInferBackend] Use GPU to inference ...")
43 | if use_fp16:
44 | logger.info(">>> [ONNXInferBackend] Use FP16 to inference ...")
45 | from onnxconverter_common import float16
46 | import onnx
47 | fp16_model_file = os.path.join(infer_model_dir,
48 | "fp16_model.onnx")
49 | onnx_model = onnx.load_model(float_onnx_file)
50 | trans_model = float16.convert_float_to_float16(
51 | onnx_model, keep_io_types=True)
52 | onnx.save_model(trans_model, fp16_model_file)
53 | onnx_model = fp16_model_file
54 | else:
55 | providers = ['CPUExecutionProvider']
56 | logger.info(">>> [ONNXInferBackend] Use CPU to inference ...")
57 |
58 | sess_options = SessionOptions()
59 | self.predictor = InferenceSession(
60 | onnx_model, sess_options=sess_options, providers=providers)
61 | if device == "gpu":
62 | try:
63 | assert 'CUDAExecutionProvider' in self.predictor.get_providers()
64 | except AssertionError:
65 | raise AssertionError(
66 | f"The environment for GPU inference is not set properly. "
67 |                 "A possible cause is that you have installed both onnxruntime and onnxruntime-gpu. "
68 | "Please run the following commands to reinstall: \n "
69 | "1) pip uninstall -y onnxruntime onnxruntime-gpu \n 2) pip install onnxruntime-gpu"
70 | )
71 | logger.info(">>> [InferBackend] Engine Created ...")
72 |
73 | def infer(self, input_dict: dict):
74 | result = self.predictor.run(None, dict(input_dict))
75 | return result
76 |
77 |
78 | class PyTorchInferBackend:
79 | def __init__(self,
80 | model_path_prefix,
81 | multilingual=False,
82 | device='cpu',
83 | use_fp16=False):
84 | from model import UIE, UIEM
85 | logger.info(">>> [PyTorchInferBackend] Creating Engine ...")
86 | if multilingual:
87 | self.model = UIEM.from_pretrained(model_path_prefix)
88 | else:
89 | self.model = UIE.from_pretrained(model_path_prefix)
90 | self.model.eval()
91 | self.device = device
92 | if self.device == 'gpu':
93 | logger.info(">>> [PyTorchInferBackend] Use GPU to inference ...")
94 | if use_fp16:
95 | logger.info(
96 | ">>> [PyTorchInferBackend] Use FP16 to inference ...")
97 | self.model = self.model.half()
98 | self.model = self.model.cuda()
99 | else:
100 | logger.info(">>> [PyTorchInferBackend] Use CPU to inference ...")
101 | logger.info(">>> [PyTorchInferBackend] Engine Created ...")
102 |
103 | def infer(self, input_dict):
104 | import torch
105 | for input_name, input_value in input_dict.items():
106 | input_value = torch.LongTensor(input_value)
107 | if self.device == 'gpu':
108 | input_value = input_value.cuda()
109 | input_dict[input_name] = input_value
110 |
111 | outputs = self.model(**input_dict)
112 | start_prob, end_prob = outputs[0], outputs[1]
113 | if self.device == 'gpu':
114 | start_prob, end_prob = start_prob.cpu(), end_prob.cpu()
115 | start_prob = start_prob.detach().numpy()
116 | end_prob = end_prob.detach().numpy()
117 | return start_prob, end_prob
118 |
119 |
120 | class UIEPredictor(object):
121 |
122 | def __init__(self, model, schema, task_path=None, schema_lang="zh", engine='pytorch', device='cpu', position_prob=0.5, max_seq_len=512, batch_size=64, split_sentence=False, use_fp16=False):
123 |
124 | assert isinstance(
125 | device, six.string_types
126 | ), "The type of device must be string."
127 | assert device in [
128 | 'cpu', 'gpu'], "The device must be cpu or gpu."
129 | if model in ['uie-m-base', 'uie-m-large']:
130 | self._multilingual = True
131 | else:
132 | self._multilingual = False
133 | self._model = model
134 | self._engine = engine
135 | self._task_path = task_path
136 | self._device = device
137 | self._position_prob = position_prob
138 | self._max_seq_len = max_seq_len
139 | self._batch_size = batch_size
140 | self._split_sentence = split_sentence
141 | self._use_fp16 = use_fp16
142 |
143 | self._schema_tree = None
144 |         self._is_en = model in ['uie-base-en'] or schema_lang == 'en'
146 | self.set_schema(schema)
147 | self._prepare_predictor()
148 |
149 | def _prepare_predictor(self):
150 | assert self._engine in ['pytorch',
151 | 'onnx'], "engine must be pytorch or onnx!"
152 |
153 | if self._task_path is None:
154 | self._task_path = self._model.replace('-', '_')+'_pytorch'
155 | if not os.path.exists(self._task_path):
156 | from convert import check_model, extract_and_convert
157 | check_model(self._model)
158 | extract_and_convert(self._model, self._task_path)
159 |
160 | if self._multilingual:
161 | from tokenizer import ErnieMTokenizerFast
162 | self._tokenizer = ErnieMTokenizerFast.from_pretrained(
163 | self._task_path)
164 | else:
165 | from transformers import BertTokenizerFast
166 | self._tokenizer = BertTokenizerFast.from_pretrained(
167 | self._task_path)
168 |
169 | if self._engine == 'pytorch':
170 | self.inference_backend = PyTorchInferBackend(
171 | self._task_path, multilingual=self._multilingual, device=self._device, use_fp16=self._use_fp16)
172 |
173 | if self._engine == 'onnx':
174 | if os.path.exists(os.path.join(self._task_path, "pytorch_model.bin")) and not os.path.exists(os.path.join(self._task_path, "inference.onnx")):
175 | from export_model import export_onnx
176 | from model import UIE, UIEM
177 | if self._multilingual:
178 | model = UIEM.from_pretrained(self._task_path)
179 | else:
180 | model = UIE.from_pretrained(self._task_path)
181 | input_names = [
182 | 'input_ids',
183 | 'token_type_ids',
184 | 'attention_mask',
185 | ]
186 | output_names = [
187 | 'start_prob',
188 | 'end_prob'
189 | ]
190 | logger.info(
191 |                     "Converting to the inference model will take a little time.")
192 | save_path = export_onnx(
193 | self._task_path, self._tokenizer, model, 'cpu', input_names, output_names)
194 | logger.info(
195 |                     "The inference model is saved at: {}".format(save_path))
196 | del model
197 | self.inference_backend = ONNXInferBackend(
198 | self._task_path, device=self._device, use_fp16=self._use_fp16)
199 |
200 | def set_schema(self, schema):
201 | if isinstance(schema, dict) or isinstance(schema, str):
202 | schema = [schema]
203 | self._schema_tree = self._build_tree(schema)
204 |
205 | def __call__(self, inputs):
206 | texts = inputs
207 | if isinstance(texts, str):
208 | texts = [texts]
209 | results = self._multi_stage_predict(texts)
210 | return results
211 |
212 | def _multi_stage_predict(self, datas):
213 | """
214 |         Traverse the schema tree and do multi-stage prediction.
215 | Args:
216 | datas (list): a list of strings
217 | Returns:
218 | list: a list of predictions, where the list's length
219 |             equals the length of `datas`
220 | """
221 | results = [{} for _ in range(len(datas))]
222 | # input check to early return
223 | if len(datas) < 1 or self._schema_tree is None:
224 | return results
225 |
226 | # copy to stay `self._schema_tree` unchanged
227 | schema_list = self._schema_tree.children[:]
228 | while len(schema_list) > 0:
229 | node = schema_list.pop(0)
230 | examples = []
231 | input_map = {}
232 | cnt = 0
233 | idx = 0
234 | if not node.prefix:
235 | for data in datas:
236 | examples.append({
237 | "text": data,
238 | "prompt": dbc2sbc(node.name)
239 | })
240 | input_map[cnt] = [idx]
241 | idx += 1
242 | cnt += 1
243 | else:
244 | for pre, data in zip(node.prefix, datas):
245 | if len(pre) == 0:
246 | input_map[cnt] = []
247 | else:
248 | for p in pre:
249 | if self._is_en:
250 | if re.search(r'\[.*?\]$', node.name):
251 | prompt_prefix = node.name[:node.name.find(
252 | "[", 1)].strip()
253 | cls_options = re.search(
254 | r'\[.*?\]$', node.name).group()
255 | # Sentiment classification of xxx [positive, negative]
256 | prompt = prompt_prefix + p + " " + cls_options
257 | else:
258 | prompt = node.name + p
259 | else:
260 | prompt = p + node.name
261 | examples.append({
262 | "text": data,
263 | "prompt": dbc2sbc(prompt)
264 | })
265 | input_map[cnt] = [i + idx for i in range(len(pre))]
266 | idx += len(pre)
267 | cnt += 1
268 | if len(examples) == 0:
269 | result_list = []
270 | else:
271 | result_list = self._single_stage_predict(examples)
272 |
273 | if not node.parent_relations:
274 | relations = [[] for i in range(len(datas))]
275 | for k, v in input_map.items():
276 | for idx in v:
277 | if len(result_list[idx]) == 0:
278 | continue
279 | if node.name not in results[k].keys():
280 | results[k][node.name] = result_list[idx]
281 | else:
282 | results[k][node.name].extend(result_list[idx])
283 | if node.name in results[k].keys():
284 | relations[k].extend(results[k][node.name])
285 | else:
286 | relations = node.parent_relations
287 | for k, v in input_map.items():
288 | for i in range(len(v)):
289 | if len(result_list[v[i]]) == 0:
290 | continue
291 | if "relations" not in relations[k][i].keys():
292 | relations[k][i]["relations"] = {
293 | node.name: result_list[v[i]]
294 | }
295 | elif node.name not in relations[k][i]["relations"].keys(
296 | ):
297 | relations[k][i]["relations"][
298 | node.name] = result_list[v[i]]
299 | else:
300 | relations[k][i]["relations"][node.name].extend(
301 | result_list[v[i]])
302 |
303 | new_relations = [[] for i in range(len(datas))]
304 | for i in range(len(relations)):
305 | for j in range(len(relations[i])):
306 | if "relations" in relations[i][j].keys(
307 | ) and node.name in relations[i][j]["relations"].keys():
308 | for k in range(
309 | len(relations[i][j]["relations"][
310 | node.name])):
311 | new_relations[i].append(relations[i][j][
312 | "relations"][node.name][k])
313 | relations = new_relations
314 |
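            # Extracted parent spans become prompt prefixes for the child nodes:
            # "<parent>的<child>" for Chinese prompts, "<child> of <parent>" for English.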
315 | prefix = [[] for _ in range(len(datas))]
316 | for k, v in input_map.items():
317 | for idx in v:
318 | for i in range(len(result_list[idx])):
319 | if self._is_en:
320 | prefix[k].append(" of " +
321 | result_list[idx][i]["text"])
322 | else:
323 | prefix[k].append(result_list[idx][i]["text"] + "的")
324 |
325 | for child in node.children:
326 | child.prefix = prefix
327 | child.parent_relations = relations
328 | schema_list.append(child)
329 | return results
330 |
331 | def _convert_ids_to_results(self, examples, sentence_ids, probs):
332 | """
333 | Convert ids to raw text in a single stage.
334 | """
335 | results = []
336 | for example, sentence_id, prob in zip(examples, sentence_ids, probs):
337 | if len(sentence_id) == 0:
338 | results.append([])
339 | continue
340 | result_list = []
341 | text = example["text"]
342 | prompt = example["prompt"]
343 | for i in range(len(sentence_id)):
344 | start, end = sentence_id[i]
345 | if start < 0 and end >= 0:
346 | continue
347 | if end < 0:
348 | start += (len(prompt) + 1)
349 | end += (len(prompt) + 1)
350 | result = {"text": prompt[start:end],
351 | "probability": prob[i]}
352 | result_list.append(result)
353 | else:
354 | result = {
355 | "text": text[start:end],
356 | "start": start,
357 | "end": end,
358 | "probability": prob[i]
359 | }
360 | result_list.append(result)
361 | results.append(result_list)
362 | return results
363 |
364 | def _auto_splitter(self, input_texts, max_text_len, split_sentence=False):
365 | '''
366 | Split the raw texts automatically for model inference.
367 | Args:
368 | input_texts (List[str]): input raw texts.
369 | max_text_len (int): cutting length.
370 | split_sentence (bool): If True, sentence-level split will be performed.
371 | return:
372 | short_input_texts (List[str]): the short input texts for model inference.
373 | input_mapping (dict): mapping between raw text and short input texts.
374 | '''
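        # Illustrative example: with max_text_len=5 and input_texts=['abcdefg', 'hi'],
        # short_input_texts == ['abcde', 'fg', 'hi'] and input_mapping == {0: [0, 1], 1: [2]}.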
375 | input_mapping = {}
376 | short_input_texts = []
377 | cnt_org = 0
378 | cnt_short = 0
379 | for text in input_texts:
380 | if not split_sentence:
381 | sens = [text]
382 | else:
383 | sens = cut_chinese_sent(text)
384 | for sen in sens:
385 | lens = len(sen)
386 | if lens <= max_text_len:
387 | short_input_texts.append(sen)
388 | if cnt_org not in input_mapping.keys():
389 | input_mapping[cnt_org] = [cnt_short]
390 | else:
391 | input_mapping[cnt_org].append(cnt_short)
392 | cnt_short += 1
393 | else:
394 | temp_text_list = [
395 | sen[i:i + max_text_len]
396 | for i in range(0, lens, max_text_len)
397 | ]
398 | short_input_texts.extend(temp_text_list)
399 | short_idx = cnt_short
400 | cnt_short += math.ceil(lens / max_text_len)
401 | temp_text_id = [
402 | short_idx + i for i in range(cnt_short - short_idx)
403 | ]
404 | if cnt_org not in input_mapping.keys():
405 | input_mapping[cnt_org] = temp_text_id
406 | else:
407 | input_mapping[cnt_org].extend(temp_text_id)
408 | cnt_org += 1
409 | return short_input_texts, input_mapping
410 |
411 | def _single_stage_predict(self, inputs):
412 | input_texts = []
413 | prompts = []
414 | for i in range(len(inputs)):
415 | input_texts.append(inputs[i]["text"])
416 | prompts.append(inputs[i]["prompt"])
417 |         # max predict length should exclude the length of the longest prompt and summary tokens
418 |         max_predict_len = self._max_seq_len - max(len(p) for p in prompts) - 3
419 |
420 | short_input_texts, self.input_mapping = self._auto_splitter(
421 | input_texts, max_predict_len, split_sentence=self._split_sentence)
422 |
423 | short_texts_prompts = []
424 | for k, v in self.input_mapping.items():
425 | short_texts_prompts.extend([prompts[k] for i in range(len(v))])
426 | short_inputs = [{
427 | "text": short_input_texts[i],
428 | "prompt": short_texts_prompts[i]
429 | } for i in range(len(short_input_texts))]
430 |
431 | sentence_ids = []
432 | probs = []
433 |
434 | input_ids = []
435 | token_type_ids = []
436 | attention_mask = []
437 | offset_maps = []
438 |
439 | if self._multilingual:
440 | padding_type = "max_length"
441 | else:
442 | padding_type = "longest"
443 | encoded_inputs = self._tokenizer(
444 | text=short_texts_prompts,
445 | text_pair=short_input_texts,
446 | stride=2,
447 | truncation=True,
448 | max_length=self._max_seq_len,
449 | padding=padding_type,
450 | add_special_tokens=True,
451 | return_offsets_mapping=True,
452 | return_tensors="np")
453 |
454 | start_prob_concat, end_prob_concat = [], []
455 | for batch_start in range(0, len(short_input_texts), self._batch_size):
456 | input_ids = encoded_inputs["input_ids"][batch_start:batch_start+self._batch_size]
457 | token_type_ids = encoded_inputs["token_type_ids"][batch_start:batch_start+self._batch_size]
458 | attention_mask = encoded_inputs["attention_mask"][batch_start:batch_start+self._batch_size]
459 | offset_maps = encoded_inputs["offset_mapping"][batch_start:batch_start+self._batch_size]
460 | if self._multilingual:
461 | input_ids = np.array(
462 | input_ids, dtype="int64")
463 | attention_mask = np.array(
464 | attention_mask, dtype="int64")
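                # ERNIE-M takes no token_type_ids; build 0-based position ids and
                # zero them out on padding positions instead.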
465 | position_ids = (np.cumsum(np.ones_like(input_ids), axis=1)
466 | - np.ones_like(input_ids))*attention_mask
467 | input_dict = {
468 | "input_ids": input_ids,
469 | "attention_mask": attention_mask,
470 | "position_ids": position_ids
471 | }
472 | else:
473 | input_dict = {
474 | "input_ids": np.array(
475 | input_ids, dtype="int64"),
476 | "token_type_ids": np.array(
477 | token_type_ids, dtype="int64"),
478 | "attention_mask": np.array(
479 | attention_mask, dtype="int64")
480 | }
481 |
482 | outputs = self.inference_backend.infer(input_dict)
483 | start_prob, end_prob = outputs[0], outputs[1]
484 | start_prob_concat.append(start_prob)
485 | end_prob_concat.append(end_prob)
486 | start_prob_concat = np.concatenate(start_prob_concat)
487 | end_prob_concat = np.concatenate(end_prob_concat)
488 |
489 | start_ids_list = get_bool_ids_greater_than(
490 | start_prob_concat, limit=self._position_prob, return_prob=True)
491 | end_ids_list = get_bool_ids_greater_than(
492 | end_prob_concat, limit=self._position_prob, return_prob=True)
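        # Each entry of start_ids_list / end_ids_list holds the token positions whose
        # probability exceeds position_prob (together with that probability);
        # get_span below pairs them into candidate (start, end) spans.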
493 |
494 | input_ids = input_dict['input_ids']
495 | sentence_ids = []
496 | probs = []
497 | for start_ids, end_ids, ids, offset_map in zip(start_ids_list,
498 | end_ids_list,
499 | input_ids.tolist(),
500 | offset_maps):
501 | for i in reversed(range(len(ids))):
502 | if ids[i] != 0:
503 | ids = ids[:i]
504 | break
505 | span_list = get_span(start_ids, end_ids, with_prob=True)
506 | sentence_id, prob = get_id_and_prob(span_list, offset_map.tolist())
507 | sentence_ids.append(sentence_id)
508 | probs.append(prob)
509 |
510 | results = self._convert_ids_to_results(short_inputs, sentence_ids,
511 | probs)
512 | results = self._auto_joiner(results, short_input_texts,
513 | self.input_mapping)
514 | return results
515 |
516 | def _auto_joiner(self, short_results, short_inputs, input_mapping):
517 | concat_results = []
518 | is_cls_task = False
519 | for short_result in short_results:
520 | if short_result == []:
521 | continue
522 | elif 'start' not in short_result[0].keys(
523 | ) and 'end' not in short_result[0].keys():
524 | is_cls_task = True
525 | break
526 | else:
527 | break
528 | for k, vs in input_mapping.items():
529 | if is_cls_task:
530 | cls_options = {}
531 | single_results = []
532 | for v in vs:
533 | if len(short_results[v]) == 0:
534 | continue
535 | if short_results[v][0]['text'] not in cls_options.keys():
536 | cls_options[short_results[v][0][
537 | 'text']] = [1, short_results[v][0]['probability']]
538 | else:
539 | cls_options[short_results[v][0]['text']][0] += 1
540 | cls_options[short_results[v][0]['text']][
541 | 1] += short_results[v][0]['probability']
542 | if len(cls_options) != 0:
543 | cls_res, cls_info = max(cls_options.items(),
544 | key=lambda x: x[1])
545 | concat_results.append([{
546 | 'text': cls_res,
547 | 'probability': cls_info[1] / cls_info[0]
548 | }])
549 | else:
550 | concat_results.append([])
551 | else:
552 | offset = 0
553 | single_results = []
554 | for v in vs:
555 | if v == 0:
556 | single_results = short_results[v]
557 | offset += len(short_inputs[v])
558 | else:
559 | for i in range(len(short_results[v])):
560 | if 'start' not in short_results[v][
561 | i] or 'end' not in short_results[v][i]:
562 | continue
563 | short_results[v][i]['start'] += offset
564 | short_results[v][i]['end'] += offset
565 | offset += len(short_inputs[v])
566 | single_results.extend(short_results[v])
567 | concat_results.append(single_results)
568 | return concat_results
569 |
570 | def predict(self, input_data):
571 | results = self._multi_stage_predict(input_data)
572 | return results
573 |
574 | @classmethod
575 | def _build_tree(cls, schema, name='root'):
576 | """
577 | Build the schema tree.
578 | """
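        # A hedged example (names are illustrative): the schema
        #   [{'person': ['birthplace', 'birthdate']}]
        # yields a root with one child node 'person' that itself has two children;
        # _multi_stage_predict extracts 'person' spans first and reuses them as
        # prompt prefixes for the child prompts.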
579 | schema_tree = SchemaTree(name)
580 | for s in schema:
581 | if isinstance(s, str):
582 | schema_tree.add_child(SchemaTree(s))
583 | elif isinstance(s, dict):
584 | for k, v in s.items():
585 | if isinstance(v, str):
586 | child = [v]
587 | elif isinstance(v, list):
588 | child = v
589 | else:
590 | raise TypeError(
591 | "Invalid schema, value for each key:value pairs should be list or string"
592 | "but {} received".format(type(v)))
593 | schema_tree.add_child(cls._build_tree(child, name=k))
594 | else:
595 | raise TypeError(
596 | "Invalid schema, element should be string or dict, "
597 | "but {} received".format(type(s)))
598 | return schema_tree
599 |
600 |
601 | class SchemaTree(object):
602 | """
603 |     Implementation of SchemaTree
604 | """
605 |
606 | def __init__(self, name='root', children=None):
607 | self.name = name
608 | self.children = []
609 | self.prefix = None
610 | self.parent_relations = None
611 | if children is not None:
612 | for child in children:
613 | self.add_child(child)
614 |
615 | def __repr__(self):
616 | return self.name
617 |
618 | def add_child(self, node):
619 | assert isinstance(
620 | node, SchemaTree
621 | ), "The children of a node should be an instacne of SchemaTree."
622 | self.children.append(node)
623 |
624 |
625 | def parse_args():
626 | parser = argparse.ArgumentParser()
627 | # Required parameters
628 | parser.add_argument(
629 | "-m",
630 | "--model",
631 | type=str,
632 | default='uie-base',
633 | help="The model to be used.", )
634 | parser.add_argument(
635 | "-t",
636 | "--task_path",
637 | type=str,
638 | default=None,
639 | help="The path prefix of custom inference model to be used.", )
640 | parser.add_argument(
641 | "-p",
642 | "--position_prob",
643 | default=0.5,
644 | type=float,
645 | help="Probability threshold for start/end index probabiliry.", )
646 | parser.add_argument(
647 | "--use_fp16",
648 | action='store_true',
649 | help="Whether to use fp16 inference, only takes effect when deploying on gpu.",
650 | )
651 | parser.add_argument(
652 | "--max_seq_len",
653 | default=512,
654 | type=int,
655 | help="The maximum input sequence length. Sequences longer than this will be split automatically.",
656 | )
657 | parser.add_argument(
658 | "-D",
659 | "--device",
660 | choices=['cpu', 'gpu'],
661 | default="gpu",
662 | help="Select which device to run model, defaults to gpu."
663 | )
664 | parser.add_argument(
665 | "-e",
666 | "--engine",
667 | choices=['pytorch', 'onnx'],
668 | default="pytorch",
669 | help="Select which engine to run model, defaults to pytorch."
670 | )
671 | args = parser.parse_args()
672 | return args
673 |
674 |
675 | if __name__ == '__main__':
676 | args = parse_args()
677 | args.schema = ['航母']
678 | args.schema_lang = "en"
679 | uie = UIEPredictor(model=args.model, task_path=args.task_path, schema_lang=args.schema_lang, schema=args.schema, engine=args.engine, device=args.device,
680 | position_prob=args.position_prob, max_seq_len=args.max_seq_len, batch_size=64, split_sentence=False, use_fp16=args.use_fp16)
681 | print(uie("印媒所称的“印度第一艘国产航母”—“维克兰特”号"))
682 |
--------------------------------------------------------------------------------