├── .gitignore
├── LICENSE
├── README.md
├── example.ipynb
├── prompt_uie.py
├── prompt_uie_information_extraction_dataset.py
├── prompt_uie_information_extraction_predictor.py
├── prompt_uie_information_extraction_task.py
├── tokenizer.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # prompt_uie_torch
2 |
3 | Medical named entity recognition based on the extractive UIE open-sourced by PaddleNLP
4 |
5 | ## Introduction
6 |
7 | [UIE (Universal Information Extraction)](https://arxiv.org/pdf/2203.12277.pdf) is a unified framework for universal information extraction proposed by Yaojie Lu et al. at ACL 2022. Following the approach of that paper, PaddleNLP open-sourced a [prompt-based extractive UIE](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/model_zoo/uie) built on the ERNIE 3.0 knowledge-enhanced pre-trained model.
8 |
9 | 
10 |
11 | This project reproduces the fine-tuning in PyTorch and evaluates it on the CMeEE dataset. Only the named entity recognition part is implemented here; relation extraction, event extraction and other tasks will be added to the ark-nlp project later.
12 |
13 | **Data download**
14 |
15 | * CMeEE:https://tianchi.aliyun.com/dataset/dataDetail?dataId=95414
16 |
17 |
18 | ## Environment
19 |
20 | ```
21 | pip install ark-nlp
22 | pip install pandas
23 | ```
24 |
25 | ## Usage
26 |
27 | Lay out the project directory as follows:
28 |
29 | ```shell
30 | │
31 | ├── data # data folder
32 | │ ├── source_datasets
33 | │ ├── task_datasets
34 | │ └── output_datasets
35 | │
36 | ├── checkpoint # trained model checkpoints
37 | │ ├── ...
38 | │ └── ...
39 | │
40 | └── example.ipynb # code
41 | ```
42 | Download the data, extract it into `data/source_datasets`, and run `example.ipynb`.
43 |
44 |
45 | ## Pre-trained weights
46 | For convenience, the Paddle model weights have been converted to the Hugging Face format and uploaded to the Hugging Face Hub: https://huggingface.co/freedomking/prompt-uie-base
47 |
48 |
49 | ## Results
50 |
51 | After one or two training epochs, a submission to the CBLUE benchmark scores roughly 65-66, which is already higher than most of the baseline models.
52 |
53 |
54 | ## Acknowledgements
55 |
56 | Thanks to PaddleNLP for the open-source release.
57 |
--------------------------------------------------------------------------------
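The README above describes loading the converted checkpoint and running extraction through `example.ipynb`. The sketch below condenses that workflow into a few lines; it is a minimal illustration only, assuming `ark-nlp`, `torch` and `transformers` are installed and the `freedomking/prompt-uie-base` checkpoint can be downloaded from the Hugging Face Hub. Predictions from the checkpoint without CMeEE fine-tuning are only indicative.

```python
# Minimal sketch of the workflow in example.ipynb (assumes ark-nlp is installed
# and the Hugging Face Hub is reachable).
from tokenizer import TransfomerTokenizer
from prompt_uie import PromptUIE
from prompt_uie_information_extraction_predictor import PromptUIEPredictor
from ark_nlp.nn import BertConfig

model_path = 'freedomking/prompt-uie-base'

tokenizer = TransfomerTokenizer(vocab=model_path, max_seq_len=100)
config = BertConfig.from_pretrained(model_path)
module = PromptUIE.from_pretrained(model_path, config=config)

predictor = PromptUIEPredictor(module, tokenizer)
# the second element is the prompt, i.e. the entity type to extract
print(predictor.predict_one_sample(['患者出现发热、咳嗽等症状', '临床表现']))
```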
/example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import warnings\n",
10 | "warnings.filterwarnings(\"ignore\")\n",
11 | "\n",
12 | "import os\n",
13 | "import jieba\n",
14 | "import torch\n",
15 | "import pickle\n",
16 | "import torch.nn as nn\n",
17 | "import torch.optim as optim\n",
18 | "import pandas as pd\n",
19 | "\n",
20 | "from tokenizer import TransfomerTokenizer as Tokenizer\n",
21 | "from utils import convert_ner_task_uie_df\n",
22 | "from prompt_uie import PromptUIE as Module\n",
23 | "from prompt_uie_information_extraction_dataset import PromptUIEDataset as Dataset\n",
24 | "from prompt_uie_information_extraction_task import PromptUIETask as Task\n",
25 | "from ark_nlp.nn import BertConfig as ModuleConfig\n",
26 | "from ark_nlp.factory.optimizer import get_default_bert_optimizer as get_default_model_optimizer"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": null,
32 | "metadata": {},
33 | "outputs": [],
34 | "source": [
35 | "train_data_path = './data/source_datasets/CMeEE/CMeEE_train.json'\n",
36 | "dev_data_path = './data/source_datasets/CMeEE/CMeEE_dev.json'"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "model_path = 'freedomking/prompt-uie-base'"
46 | ]
47 | },
48 | {
49 | "cell_type": "markdown",
50 | "metadata": {},
51 | "source": [
52 |     "### Part 1: Data loading and preprocessing"
53 | ]
54 | },
55 | {
56 | "cell_type": "markdown",
57 | "metadata": {},
58 | "source": [
59 |     "#### 1. Load the data"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": null,
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "train_data_df = pd.read_json(train_data_path)\n",
69 | "dev_data_df = pd.read_json(dev_data_path)"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "metadata": {},
76 | "outputs": [],
77 | "source": [
78 | "train_data_df = train_data_df.rename(columns={'entities': 'label'})\n",
79 | "dev_data_df = dev_data_df.rename(columns={'entities': 'label'})"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": null,
85 | "metadata": {},
86 | "outputs": [],
87 | "source": [
88 | "type2name = {\n",
89 | " 'dis': '疾病',\n",
90 | " 'sym': '临床表现',\n",
91 | " 'pro': '医疗程序',\n",
92 | " 'equ': '医疗设备',\n",
93 | " 'dru': '药物',\n",
94 | " 'ite': '医学检验项目',\n",
95 | " 'bod': '身体',\n",
96 | " 'dep': '科室',\n",
97 | " 'mic': '微生物类'\n",
98 | "}"
99 | ]
100 | },
101 | {
102 | "cell_type": "code",
103 | "execution_count": null,
104 | "metadata": {},
105 | "outputs": [],
106 | "source": [
107 | "def convert_entity_type(labels):\n",
108 | " \n",
109 | " converted_labels = []\n",
110 | " for label in labels:\n",
111 | " converted_labels.append({\n",
112 | " 'start_idx': label['start_idx'],\n",
113 | " 'end_idx': label['end_idx'],\n",
114 | " 'type': type2name[label['type']],\n",
115 | " 'entity': label['entity']\n",
116 | " })\n",
117 | " \n",
118 | " return converted_labels"
119 | ]
120 | },
121 | {
122 | "cell_type": "code",
123 | "execution_count": null,
124 | "metadata": {},
125 | "outputs": [],
126 | "source": [
127 | "train_data_df['label'] = train_data_df['label'].apply(lambda x: convert_entity_type(x))\n",
128 | "dev_data_df['label'] = dev_data_df['label'].apply(lambda x: convert_entity_type(x))"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": null,
134 | "metadata": {},
135 | "outputs": [],
136 | "source": [
137 | "train_data_df = convert_ner_task_uie_df(train_data_df, negative_ratio=2)"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": null,
143 | "metadata": {},
144 | "outputs": [],
145 | "source": [
146 | "dev_data_df = convert_ner_task_uie_df(dev_data_df, negative_ratio=0)"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": null,
152 | "metadata": {},
153 | "outputs": [],
154 | "source": [
155 | "ner_train_dataset = Dataset(train_data_df)\n",
156 | "ner_dev_dataset = Dataset(dev_data_df)"
157 | ]
158 | },
159 | {
160 | "cell_type": "markdown",
161 | "metadata": {},
162 | "source": [
163 |     "#### 2. Build the vocabulary and tokenizer"
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": null,
169 | "metadata": {},
170 | "outputs": [],
171 | "source": [
172 | "tokenizer = Tokenizer(vocab=model_path, max_seq_len=100)"
173 | ]
174 | },
175 | {
176 | "cell_type": "markdown",
177 | "metadata": {},
178 | "source": [
179 |     "#### 3. Convert to IDs"
180 | ]
181 | },
182 | {
183 | "cell_type": "code",
184 | "execution_count": null,
185 | "metadata": {},
186 | "outputs": [],
187 | "source": [
188 | "ner_train_dataset.convert_to_ids(tokenizer)\n",
189 | "ner_dev_dataset.convert_to_ids(tokenizer)"
190 | ]
191 | },
192 | {
193 | "cell_type": "markdown",
194 | "metadata": {},
195 | "source": [
196 |     "\n",
197 |     "\n",
198 |     "### Part 2: Model construction"
199 | ]
200 | },
201 | {
202 | "cell_type": "markdown",
203 | "metadata": {},
204 | "source": [
205 |     "#### 1. Model configuration"
206 | ]
207 | },
208 | {
209 | "cell_type": "code",
210 | "execution_count": null,
211 | "metadata": {},
212 | "outputs": [],
213 | "source": [
214 | "config = ModuleConfig.from_pretrained(model_path)"
215 | ]
216 | },
217 | {
218 | "cell_type": "markdown",
219 | "metadata": {},
220 | "source": [
221 |     "#### 2. Create the model"
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": null,
227 | "metadata": {},
228 | "outputs": [],
229 | "source": [
230 | "torch.cuda.empty_cache()"
231 | ]
232 | },
233 | {
234 | "cell_type": "code",
235 | "execution_count": null,
236 | "metadata": {},
237 | "outputs": [],
238 | "source": [
239 | "dl_module = Module.from_pretrained(model_path, config=config)"
240 | ]
241 | },
242 | {
243 | "cell_type": "markdown",
244 | "metadata": {},
245 | "source": [
246 |     "\n",
247 |     "\n",
248 |     "### Part 3: Task construction"
249 | ]
250 | },
251 | {
252 | "cell_type": "markdown",
253 | "metadata": {},
254 | "source": [
255 |     "#### 1. Task parameters and components"
256 | ]
257 | },
258 | {
259 | "cell_type": "code",
260 | "execution_count": null,
261 | "metadata": {},
262 | "outputs": [],
263 | "source": [
264 |     "# set the number of epochs and batch size\n",
265 | "num_epoches = 5\n",
266 | "batch_size = 32"
267 | ]
268 | },
269 | {
270 | "cell_type": "code",
271 | "execution_count": null,
272 | "metadata": {},
273 | "outputs": [],
274 | "source": [
275 | "optimizer = get_default_model_optimizer(dl_module)"
276 | ]
277 | },
278 | {
279 | "cell_type": "markdown",
280 | "metadata": {},
281 | "source": [
282 |     "#### 2. Create the task"
283 | ]
284 | },
285 | {
286 | "cell_type": "code",
287 | "execution_count": null,
288 | "metadata": {},
289 | "outputs": [],
290 | "source": [
291 | "model = Task(dl_module, optimizer, None, cuda_device=0)"
292 | ]
293 | },
294 | {
295 | "cell_type": "markdown",
296 | "metadata": {},
297 | "source": [
298 |     "#### 3. Training"
299 | ]
300 | },
301 | {
302 | "cell_type": "code",
303 | "execution_count": null,
304 | "metadata": {},
305 | "outputs": [],
306 | "source": [
307 | "model.fit(ner_train_dataset, \n",
308 | " ner_dev_dataset,\n",
309 | " lr=1e-5,\n",
310 | " epochs=num_epoches, \n",
311 | " batch_size=batch_size\n",
312 | " )"
313 | ]
314 | },
315 | {
316 | "cell_type": "markdown",
317 | "metadata": {},
318 | "source": [
319 |     "\n",
320 |     "\n",
321 |     "### Part 4: Generate the submission file"
322 | ]
323 | },
324 | {
325 | "cell_type": "code",
326 | "execution_count": null,
327 | "metadata": {},
328 | "outputs": [],
329 | "source": [
330 | "import json\n",
331 | "\n",
332 | "from tqdm import tqdm\n",
333 | "from prompt_uie_information_extraction_predictor import PromptUIEPredictor as Predictor"
334 | ]
335 | },
336 | {
337 | "cell_type": "code",
338 | "execution_count": null,
339 | "metadata": {},
340 | "outputs": [],
341 | "source": [
342 | "ner_predictor_instance = Predictor(model.module, tokenizer)"
343 | ]
344 | },
345 | {
346 | "cell_type": "code",
347 | "execution_count": null,
348 | "metadata": {},
349 | "outputs": [],
350 | "source": [
351 | "test_df = pd.read_json('./data/source_datasets/CMeEE/CMeEE_test.json')\n",
352 | "\n",
353 | "submit = []\n",
354 | "for _text in tqdm(test_df['text'].to_list()):\n",
355 | " \n",
356 | " entities = []\n",
357 | " for source_type, prompt_type in type2name.items():\n",
358 | " \n",
359 | " for entity in ner_predictor_instance.predict_one_sample([_text, prompt_type]):\n",
360 | " \n",
361 | " entities.append({\n",
362 | " 'start_idx': entity['start_idx'],\n",
363 | " 'end_idx': entity['end_idx'],\n",
364 | " 'type': source_type,\n",
365 | " 'entity': entity['entity'],\n",
366 | " })\n",
367 | " \n",
368 | " submit.append({\n",
369 | " 'text': _text,\n",
370 | " 'entities': entities\n",
371 | " })"
372 | ]
373 | },
374 | {
375 | "cell_type": "code",
376 | "execution_count": null,
377 | "metadata": {},
378 | "outputs": [],
379 | "source": [
380 | "output_path = './submit_CMeEE_test.json'\n",
381 | "\n",
382 | "with open(output_path,'w', encoding='utf-8') as f:\n",
383 | " f.write(json.dumps(submit, ensure_ascii=False))"
384 | ]
385 | }
386 | ],
387 | "metadata": {
388 | "kernelspec": {
389 | "display_name": "Python 3 (ipykernel)",
390 | "language": "python",
391 | "name": "python3"
392 | },
393 | "language_info": {
394 | "codemirror_mode": {
395 | "name": "ipython",
396 | "version": 3
397 | },
398 | "file_extension": ".py",
399 | "mimetype": "text/x-python",
400 | "name": "python",
401 | "nbconvert_exporter": "python",
402 | "pygments_lexer": "ipython3",
403 | "version": "3.8.10"
404 | }
405 | },
406 | "nbformat": 4,
407 | "nbformat_minor": 4
408 | }
409 |
--------------------------------------------------------------------------------
/prompt_uie.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 DataArk Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # Author: Xiang Wang, xiangking1995@163.com
16 | # Status: Active
17 |
18 | import torch
19 | import torch.nn as nn
20 |
21 | from transformers import BertModel
22 | from transformers import BertPreTrainedModel
23 |
24 |
25 | class BertEmbeddings(torch.nn.Module):
26 | """
27 |     BERT embedding layer: word, position and token type embeddings
28 |     """  # noqa
29 |
30 | def __init__(self, config):
31 | super().__init__()
32 |
33 | self.use_task_id = config.use_task_id
34 |
35 |         # BERT input is the sum of three parts: word, position and token type embeddings
36 |         # (token type embeddings distinguish which sentence a token belongs to when several sentences are concatenated)
37 | self.word_embeddings = nn.Embedding(config.vocab_size,
38 | config.hidden_size,
39 | padding_idx=config.pad_token_id)
40 | self.position_embeddings = nn.Embedding(config.max_position_embeddings,
41 | config.hidden_size)
42 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
43 | config.hidden_size)
44 |
45 | if self.use_task_id:
46 | self.task_type_embeddings = nn.Embedding(
47 | config.task_type_vocab_size, config.hidden_size)
48 |
49 | self.LayerNorm = nn.LayerNorm(config.hidden_size,
50 | eps=config.layer_norm_eps)
51 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
52 |
53 |         # BERT uses absolute position embeddings, i.e. positions are numbered from the start of the sequence
54 | self.position_embedding_type = getattr(config,
55 | "position_embedding_type",
56 | "absolute")
57 |         # position_ids is created with length max_position_embeddings and truncated to the input length in forward
58 | self.register_buffer(
59 | "position_ids",
60 | torch.arange(config.max_position_embeddings).expand((1, -1)))
61 |         # token_type_ids is created with the same size as position_ids and truncated to the input length in forward
62 | self.register_buffer(
63 | "token_type_ids",
64 | torch.zeros(self.position_ids.size(), dtype=torch.long),
65 | persistent=False,
66 | )
67 |
68 | if self.use_task_id:
69 | self.register_buffer(
70 | "task_type_ids",
71 | torch.zeros(self.position_ids.size(), dtype=torch.long),
72 | persistent=False,
73 | )
74 |
75 | def forward(self,
76 | input_ids=None,
77 | token_type_ids=None,
78 | position_ids=None,
79 | task_type_ids=None,
80 | **kwargs):
81 |         # the transformers library also accepts pre-computed embeddings instead of input_ids
82 |         # ark-nlp does not support that here; users who need it are expected to define such a model themselves
83 | input_shape = input_ids.size()
84 |
85 | seq_length = input_shape[1]
86 |
87 | if position_ids is None:
88 | position_ids = self.position_ids[:, :seq_length]
89 |
90 | if token_type_ids is None:
91 | if hasattr(self, "token_type_ids"):
92 | buffered_token_type_ids = self.token_type_ids[:, :seq_length]
93 | buffered_token_type_ids_expanded = buffered_token_type_ids.expand(
94 | input_shape[0], seq_length)
95 | token_type_ids = buffered_token_type_ids_expanded
96 | else:
97 | token_type_ids = torch.zeros(input_shape,
98 | dtype=torch.long,
99 | device=self.position_ids.device)
100 |
101 | if task_type_ids is None:
102 | if hasattr(self, "task_type_ids"):
103 | buffered_task_type_ids = self.task_type_ids[:, :seq_length]
104 | buffered_task_type_ids_expanded = buffered_task_type_ids.expand(
105 | input_shape[0], seq_length)
106 | task_type_ids = buffered_task_type_ids_expanded
107 | else:
108 | task_type_ids = torch.zeros(input_shape,
109 | dtype=torch.long,
110 | device=self.position_ids.device)
111 |
112 |         # word embeddings
113 |         input_embeddings = self.word_embeddings(input_ids)
114 |         # token type embeddings
115 |         token_type_embeddings = self.token_type_embeddings(token_type_ids)
116 |         # position embeddings
117 |         # the condition below mirrors the transformers implementation, but has no practical effect here:
118 |         # only absolute position encoding is used in this module
119 |         if self.position_embedding_type == "absolute":
120 |             position_embeddings = self.position_embeddings(position_ids)
121 |
122 |         # sum the three embeddings
123 |         embeddings = input_embeddings + position_embeddings + token_type_embeddings
124 |
125 |         # task type embeddings
126 | if self.use_task_id:
127 | task_type_embeddings = self.task_type_embeddings(task_type_ids)
128 | embeddings = embeddings + task_type_embeddings
129 |
130 | embeddings = self.LayerNorm(embeddings)
131 | embeddings = self.dropout(embeddings)
132 |
133 | return embeddings
134 |
135 |
136 | class PromptUIE(BertPreTrainedModel):
137 | """
138 |     Universal Information Extraction (UIE), implemented as an MRC-style span extractor
139 |
140 |     Args:
141 |         config: model configuration object
142 |         encoder_trained (bool, optional): whether the BERT parameters are trainable, defaults to True
143 |
144 |     Reference:
145 |         [1] https://github.com/PaddlePaddle/PaddleNLP/tree/develop/model_zoo/uie
146 |     """  # noqa
147 |
148 | def __init__(self,
149 | config,
150 | encoder_trained=True):
151 | super(PromptUIE, self).__init__(config)
152 | self.bert = BertModel(config)
153 |
154 | self.bert.embeddings = BertEmbeddings(config)
155 |
156 | for param in self.bert.parameters():
157 | param.requires_grad = encoder_trained
158 |
159 | self.num_labels = config.num_labels
160 |
161 | self.start_linear = torch.nn.Linear(config.hidden_size, 1)
162 | self.end_linear = torch.nn.Linear(config.hidden_size, 1)
163 |
164 | self.sigmoid = nn.Sigmoid()
165 |
166 | def forward(self,
167 | input_ids,
168 | token_type_ids=None,
169 | pos_ids=None,
170 | attention_mask=None,
171 | **kwargs):
172 | sequence_feature = self.bert(input_ids,
173 | attention_mask=attention_mask,
174 | token_type_ids=token_type_ids,
175 | return_dict=True,
176 | output_hidden_states=True).hidden_states
177 |
178 | sequence_feature = sequence_feature[-1]
179 |
180 | start_logits = self.start_linear(sequence_feature)
181 | start_logits = torch.squeeze(start_logits, -1)
182 | start_prob = self.sigmoid(start_logits)
183 |
184 | end_logits = self.end_linear(sequence_feature)
185 |
186 | end_logits = torch.squeeze(end_logits, -1)
187 | end_prob = self.sigmoid(end_logits)
188 |
189 | return start_prob, end_prob
190 |
--------------------------------------------------------------------------------
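The model above scores every token with independent start/end probabilities. A quick shape check with a toy, randomly initialised configuration (the small config values below are made up for the sketch; the real checkpoint is `freedomking/prompt-uie-base`):

```python
# Toy smoke test for PromptUIE: random weights, made-up small config values,
# just to show the input/output shapes of the forward pass.
import torch
from transformers import BertConfig
from prompt_uie import PromptUIE

config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)
config.use_task_id = True          # read by the custom BertEmbeddings above
config.task_type_vocab_size = 3    # idem

model = PromptUIE(config).eval()
input_ids = torch.randint(0, 100, (1, 16))
attention_mask = torch.ones(1, 16, dtype=torch.long)

with torch.no_grad():
    start_prob, end_prob = model(input_ids, attention_mask=attention_mask)

print(start_prob.shape, end_prob.shape)  # torch.Size([1, 16]) twice: per-token probabilities
```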
/prompt_uie_information_extraction_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 DataArk Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # Author: Xiang Wang, xiangking1995@163.com
16 | # Status: Active
17 |
18 | import torch
19 |
20 | from ark_nlp.dataset import TokenClassificationDataset
21 |
22 |
23 | class PromptUIEDataset(TokenClassificationDataset):
24 | """
25 |     Dataset for the universal information extraction (UIE) task
26 |
27 |     Args:
28 |         data (:obj:`DataFrame` or :obj:`string`): the data, or a path to it
29 |         categories (:obj:`list`, optional, defaults to `None`): label categories
30 |         is_retain_df (:obj:`bool`, optional, defaults to False): whether to keep the raw DataFrame in the attribute retain_df
31 |         is_retain_dataset (:obj:`bool`, optional, defaults to False): whether to keep the processed data in the attribute retain_dataset
32 |         is_train (:obj:`bool`, optional, defaults to True): whether this is the training split
33 |         is_test (:obj:`bool`, optional, defaults to False): whether this is the test split
34 |     """  # noqa
35 |
36 | def _convert_to_transfomer_ids(self, bert_tokenizer):
37 |
38 | features = []
39 | for (index_, row_) in enumerate(self.dataset):
40 |
41 | prompt_tokens = bert_tokenizer.tokenize(row_['condition'])
42 | tokens = bert_tokenizer.tokenize(row_['text'])[:bert_tokenizer.max_seq_len - 3 - len(prompt_tokens)]
43 | token_mapping = bert_tokenizer.get_token_mapping(row_['text'], tokens)
44 |
45 | start_mapping = {j[0]: i for i, j in enumerate(token_mapping) if j}
46 | end_mapping = {j[-1]: i for i, j in enumerate(token_mapping) if j}
47 |
48 | input_ids = bert_tokenizer.sequence_to_ids(prompt_tokens, tokens, truncation_method='last')
49 |
50 | input_ids, input_mask, segment_ids = input_ids
51 |
52 | start_label = torch.zeros((bert_tokenizer.max_seq_len))
53 | end_label = torch.zeros((bert_tokenizer.max_seq_len))
54 |
55 | label_ = set()
56 | for info_ in row_['label']:
57 | if info_['start_idx'] in start_mapping and info_['end_idx'] in end_mapping:
58 | start_idx = start_mapping[info_['start_idx']]
59 | end_idx = end_mapping[info_['end_idx']]
60 | if start_idx > end_idx or info_['entity'] == '':
61 | continue
62 |
63 | start_label[start_idx + 2 + len(prompt_tokens)] = 1
64 | end_label[end_idx + 2 + len(prompt_tokens)] = 1
65 |
66 | label_.add((start_idx + 2 + len(prompt_tokens),
67 | end_idx + 2 + len(prompt_tokens)))
68 |
69 | features.append({
70 | 'input_ids': input_ids,
71 | 'attention_mask': input_mask,
72 | 'token_type_ids': segment_ids,
73 | 'start_label_ids': start_label,
74 | 'end_label_ids': end_label,
75 | 'label_ids': list(label_)
76 | })
77 |
78 | return features
79 |
80 | @property
81 | def to_device_cols(self):
82 | _cols = list(self.dataset[0].keys())
83 | _cols.remove('label_ids')
84 | return _cols
85 |
--------------------------------------------------------------------------------
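The key detail in `_convert_to_transfomer_ids` above is the `+ 2 + len(prompt_tokens)` offset: the encoded sequence is `[CLS] prompt [SEP] text [SEP]`, so a text token at position `i` lands at `i + 2 + len(prompt_tokens)`. A tiny self-contained illustration with hypothetical token lists (no real tokenizer involved):

```python
# Hypothetical tokens, only to illustrate the label offset used above.
prompt_tokens = ['疾', '病']              # the prompt, e.g. the entity type "疾病"
text_tokens = ['患', '者', '头', '痛']     # the text to extract from

# [CLS] prompt [SEP] text [SEP] is how sequence_to_ids arranges the pair
encoded = ['[CLS]'] + prompt_tokens + ['[SEP]'] + text_tokens + ['[SEP]']

i = 2  # '头' is the third text token
assert encoded[i + 2 + len(prompt_tokens)] == text_tokens[i]
```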
/prompt_uie_information_extraction_predictor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 DataArk Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # Author: Xiang Wang, xiangking1995@163.com
16 | # Status: Active
17 |
18 | import torch
19 |
20 | from utils import get_span
21 | from utils import get_bool_ids_greater_than
22 |
23 |
24 | class PromptUIEPredictor(object):
25 | """
26 |     Predictor for the universal information extraction (UIE) model
27 |
28 |     Args:
29 |         module: the trained model
30 |         tokenizer: the tokenizer
31 |     """  # noqa
32 |
33 |     def __init__(self, module, tokenizer):
34 |         self.module = module
35 |         self.module.task = 'TokenLevel'
36 |
37 |         self.tokenizer = tokenizer
38 | self.device = list(self.module.parameters())[0].device
39 |
40 | def _convert_to_transfomer_ids(self, text, prompt):
41 | tokens = self.tokenizer.tokenize(text)
42 | token_mapping = self.tokenizer.get_token_mapping(text, tokens)
43 |
44 | prompt_tokens = self.tokenizer.tokenize(prompt)
45 |
46 | input_ids = self.tokenizer.sequence_to_ids(prompt_tokens, tokens, truncation_method='last')
47 | input_ids, input_mask, segment_ids = input_ids
48 |
49 | features = {
50 | 'input_ids': input_ids,
51 | 'attention_mask': input_mask,
52 | 'token_type_ids': segment_ids,
53 | }
54 |
55 | return features, token_mapping
56 |
57 | def _get_input_ids(self, text, prompt):
58 | if self.tokenizer.tokenizer_type == 'transfomer':
59 | return self._convert_to_transfomer_ids(text, prompt)
60 | else:
61 | raise ValueError("The tokenizer type does not exist")
62 |
63 | def _get_module_one_sample_inputs(self, features):
64 | return {
65 | col: torch.Tensor(features[col]).type(torch.long).unsqueeze(0).to(self.device)
66 | for col in features
67 | }
68 |
69 | def predict_one_sample(
70 | self,
71 | text,
72 | ):
73 | """
74 |         Single-sample prediction
75 |
76 |         Args:
77 |             text: a [text, prompt] pair, where prompt is the entity type to extract
78 |         """  # noqa
79 |
80 | text, prompt = text
81 | features, token_mapping = self._get_input_ids(text, prompt)
82 |
83 | self.module.eval()
84 |
85 | with torch.no_grad():
86 | inputs = self._get_module_one_sample_inputs(features)
87 | start_logits, end_logits = self.module(**inputs)
88 |
89 | start_scores = start_logits[0].cpu().numpy()[2 + len(self.tokenizer.tokenize(prompt)):]
90 | end_scores = end_logits[0].cpu().numpy()[2 + len(self.tokenizer.tokenize(prompt)):]
91 |
92 | start_scores = get_bool_ids_greater_than(start_scores)
93 | end_scores = get_bool_ids_greater_than(end_scores)
94 |
95 | entities = []
96 | for span in get_span(start_scores, end_scores):
97 |
98 | if span[0] >= len(token_mapping) or span[-1] >= len(token_mapping):
99 | continue
100 |
101 | entitie_ = {
102 | "start_idx": token_mapping[span[0]][0],
103 | "end_idx": token_mapping[span[-1]][-1],
104 | "type": prompt,
105 | "entity": text[token_mapping[span[0]][0]:token_mapping[span[-1]][-1] + 1]
106 | }
107 | entities.append(entitie_)
108 |
109 | return entities
110 |
--------------------------------------------------------------------------------
/prompt_uie_information_extraction_task.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 DataArk Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # Author: Xiang Wang, xiangking1995@163.com
16 | # Status: Active
17 |
18 | import torch
19 | import torch.nn.functional as F
20 |
21 | from collections import Counter
22 | from utils import get_span
23 | from utils import get_bool_ids_greater_than
24 | from ark_nlp.factory.task.base._token_classification import TokenClassificationTask
25 |
26 |
27 | class SpanMetrics(object):
28 |
29 | def __init__(self, id2label=None):
30 | self.id2label = id2label
31 | self.reset()
32 |
33 | def reset(self):
34 | self.origins = []
35 | self.founds = []
36 | self.rights = []
37 |
38 | def compute(self, origin, found, right):
39 | recall = 0 if origin == 0 else (right / origin)
40 | precision = 0 if found == 0 else (right / found)
41 | f1 = 0. if recall + precision == 0 else (2 * precision * recall) / (precision + recall)
42 | return recall, precision, f1
43 |
44 | def result(self):
45 | if self.id2label is not None:
46 | class_info = {}
47 | origin_counter = Counter([self.id2label[x[0]] for x in self.origins])
48 | found_counter = Counter([self.id2label[x[0]] for x in self.founds])
49 | right_counter = Counter([self.id2label[x[0]] for x in self.rights])
50 | for type_, count in origin_counter.items():
51 | origin = count
52 | found = found_counter.get(type_, 0)
53 | right = right_counter.get(type_, 0)
54 | recall, precision, f1 = self.compute(origin, found, right)
55 | class_info[type_] = {"acc": round(precision, 4), 'recall': round(recall, 4), 'f1': round(f1, 4)}
56 |
57 | origin = len(self.origins)
58 | found = len(self.founds)
59 | right = len(self.rights)
60 | recall, precision, f1 = self.compute(origin, found, right)
61 |
62 | if self.id2label is None:
63 | return {'acc': precision, 'recall': recall, 'f1': f1}
64 | else:
65 | return {'acc': precision, 'recall': recall, 'f1': f1}, class_info
66 |
67 | def update(self, true_subject, pred_subject):
68 | self.origins.extend(true_subject)
69 | self.founds.extend(pred_subject)
70 | self.rights.extend([pre_entity for pre_entity in pred_subject if pre_entity in true_subject])
71 |
72 |
73 | class PromptUIETask(TokenClassificationTask):
74 | """
75 |     Task for the universal information extraction (UIE) model
76 |
77 |     Args:
78 |         module: the model to train
79 |         optimizer: the optimizer name or optimizer object used for training
80 |         loss_function: the loss function name or loss function object used for training
81 |         class_num (:obj:`int` or :obj:`None`, optional, defaults to None): number of labels
82 |         scheduler (:obj:`class`, optional, defaults to None): scheduler object
83 |         n_gpu (:obj:`int`, optional, defaults to 1): number of GPUs
84 |         device (:obj:`class`, optional, defaults to None): torch.device object; if None, GPU availability is detected automatically
85 |         cuda_device (:obj:`int`, optional, defaults to 0): GPU index used to build device when device is None
86 |         ema_decay (:obj:`int` or :obj:`None`, optional, defaults to None): EMA decay coefficient
87 |         **kwargs (optional): other optional arguments
88 |     """  # noqa
89 |
90 | def _get_train_loss(self, inputs, outputs, **kwargs):
91 | loss = self._compute_loss(inputs, outputs, **kwargs)
92 |
93 | self._compute_loss_record(**kwargs)
94 |
95 | return outputs, loss
96 |
97 | def _get_evaluate_loss(self, inputs, outputs, **kwargs):
98 | loss = self._compute_loss(inputs, outputs, **kwargs)
99 | self._compute_loss_record(**kwargs)
100 |
101 | return outputs, loss
102 |
103 | def _compute_loss(self, inputs, logits, verbose=True, **kwargs):
104 | start_logits = logits[0]
105 | end_logits = logits[1]
106 |
107 | start_logits = start_logits.view(-1, 1)
108 | end_logits = end_logits.view(-1, 1)
109 |
110 | active_loss = inputs['attention_mask'].view(-1) == 1
111 |
112 | active_start_logits = start_logits[active_loss]
113 | active_end_logits = end_logits[active_loss]
114 |
115 | active_start_labels = inputs['start_label_ids'].long().view(-1, 1)[active_loss]
116 | active_end_labels = inputs['end_label_ids'].long().view(-1, 1)[active_loss]
117 |
118 | start_loss = F.binary_cross_entropy(active_start_logits,
119 | active_start_labels.to(torch.float),
120 | reduction='none')
121 |         start_loss = torch.sum(start_loss) / torch.sum(active_loss)  # mean over active (non-padding) tokens
122 |
123 | end_loss = F.binary_cross_entropy(active_end_logits,
124 | active_end_labels.to(torch.float),
125 | reduction='none')
126 |         end_loss = torch.sum(end_loss) / torch.sum(active_loss)  # mean over active (non-padding) tokens
127 |
128 | loss = (start_loss + end_loss) / 2.0
129 |
130 | return loss
131 |
132 | def _on_evaluate_epoch_begin(self, **kwargs):
133 |
134 | self.metric = SpanMetrics()
135 |
136 | if self.ema_decay:
137 | self.ema.store(self.module.parameters())
138 | self.ema.copy_to(self.module.parameters())
139 |
140 | self._on_epoch_begin_record(**kwargs)
141 |
142 | def _on_evaluate_step_end(self, inputs, logits, **kwargs):
143 |
144 | with torch.no_grad():
145 | # compute loss
146 | logits, loss = self._get_evaluate_loss(inputs, logits, **kwargs)
147 |
148 | S = []
149 | start_logits = logits[0]
150 | end_logits = logits[1]
151 |
152 | start_pred = start_logits.cpu().numpy().tolist()
153 | end_pred = end_logits.cpu().numpy().tolist()
154 |
155 | start_score_list = get_bool_ids_greater_than(start_pred)
156 | end_score_list = get_bool_ids_greater_than(end_pred)
157 |
158 | for index, (start_score, end_score) in enumerate(zip(start_score_list, end_score_list)):
159 | S = get_span(start_score, end_score)
160 | self.metric.update(true_subject=inputs['label_ids'][index], pred_subject=S)
161 |
162 | self.evaluate_logs['eval_example'] += len(inputs['label_ids'])
163 | self.evaluate_logs['eval_step'] += 1
164 | self.evaluate_logs['eval_loss'] += loss.item()
165 |
166 | def _on_evaluate_epoch_end(self,
167 | validation_data,
168 | epoch=1,
169 | is_evaluate_print=True,
170 | **kwargs):
171 |
172 | with torch.no_grad():
173 | eval_info = self.metric.result()
174 |
175 | if is_evaluate_print:
176 | print('eval_info: ', eval_info)
177 |
178 | def _train_collate_fn(self, batch):
179 |         """Convert a batch of InputFeatures into tensors"""
180 |
181 | input_ids = torch.tensor([f['input_ids'] for f in batch], dtype=torch.long)
182 | attention_mask = torch.tensor([f['attention_mask'] for f in batch],
183 | dtype=torch.long)
184 | token_type_ids = torch.tensor([f['token_type_ids'] for f in batch],
185 | dtype=torch.long)
186 | start_label_ids = torch.cat([f['start_label_ids'] for f in batch])
187 | end_label_ids = torch.cat([f['end_label_ids'] for f in batch])
188 | label_ids = [f['label_ids'] for f in batch]
189 |
190 | tensors = {
191 | 'input_ids': input_ids,
192 | 'attention_mask': attention_mask,
193 | 'token_type_ids': token_type_ids,
194 | 'start_label_ids': start_label_ids,
195 | 'end_label_ids': end_label_ids,
196 | 'label_ids': label_ids
197 | }
198 |
199 | return tensors
200 |
201 | def _evaluate_collate_fn(self, batch):
202 | return self._train_collate_fn(batch)
203 |
--------------------------------------------------------------------------------
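`SpanMetrics` above accumulates gold and predicted spans and scores exact matches. A small self-contained check (importing the module requires ark-nlp, since it pulls in `TokenClassificationTask`):

```python
# Tiny check of SpanMetrics: spans are (start, end) tuples and only exact
# matches count as correct.
from prompt_uie_information_extraction_task import SpanMetrics

metric = SpanMetrics()
metric.update(true_subject=[(2, 4), (10, 12)], pred_subject=[(2, 4), (7, 8)])
print(metric.result())  # {'acc': 0.5, 'recall': 0.5, 'f1': 0.5}
```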
/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 DataArk Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # Author: Xiang Wang, xiangking1995@163.com
16 | # Status: Active
17 |
18 |
19 | import warnings
20 | import unicodedata
21 | import transformers
22 | import numpy as np
23 |
24 | from typing import List
25 | from copy import deepcopy
26 | from ark_nlp.processor.tokenizer._tokenizer import BaseTokenizer
27 |
28 |
29 | class TransfomerTokenizer(BaseTokenizer):
30 | """
31 |     Transformer text tokenizer: tokenization, conversion to IDs, padding, etc.
32 |
33 |     Args:
34 |         vocab: a transformers tokenizer object, or the path/name of a pretrained vocabulary, used for tokenization and ID conversion
35 |         max_seq_len (:obj:`int`): maximum sequence length
36 |     """  # noqa
37 |
38 | def __init__(
39 | self,
40 | vocab,
41 | max_seq_len
42 | ):
43 | if isinstance(vocab, str):
44 |             # TODO: make this configurable via a custom vocabulary
45 | vocab = transformers.AutoTokenizer.from_pretrained(vocab)
46 |
47 | self.vocab = vocab
48 | self.max_seq_len = max_seq_len
49 | self.additional_special_tokens = set()
50 | self.tokenizer_type = 'transfomer'
51 |
52 | @staticmethod
53 | def _is_control(ch):
54 |         """Check whether a character is a control character
55 | """
56 | return unicodedata.category(ch) in ('Cc', 'Cf')
57 |
58 | @staticmethod
59 | def _is_special(ch):
60 |         """Check whether a token is a special symbol (wrapped in square brackets, e.g. [CLS])
61 | """
62 | return bool(ch) and (ch[0] == '[') and (ch[-1] == ']')
63 |
64 | @staticmethod
65 | def recover_bert_token(token):
66 |         """Recover the surface form of a WordPiece token (strip a leading ## if present)
67 | """
68 | if token[:2] == '##':
69 | return token[2:]
70 | else:
71 | return token
72 |
73 | def get_token_mapping(self, text, tokens, is_mapping_index=True):
74 |         """Return the mapping between the original text and the tokenized tokens"""
75 | raw_text = deepcopy(text)
76 | text = text.lower()
77 |
78 | normalized_text, char_mapping = '', []
79 | for i, ch in enumerate(text):
80 | ch = unicodedata.normalize('NFD', ch)
81 | ch = ''.join([c for c in ch if unicodedata.category(c) != 'Mn'])
82 | ch = ''.join([
83 | c for c in ch
84 | if not (ord(c) == 0 or ord(c) == 0xfffd or self._is_control(c))
85 | ])
86 | normalized_text += ch
87 | char_mapping.extend([i] * len(ch))
88 |
89 | text, token_mapping, offset = normalized_text, [], 0
90 | for token in tokens:
91 | token = token.lower()
92 | if token == '[unk]' or token in self.additional_special_tokens:
93 | if is_mapping_index:
94 | token_mapping.append(char_mapping[offset:offset+1])
95 | else:
96 | token_mapping.append(raw_text[offset:offset+1])
97 | offset = offset + 1
98 | elif self._is_special(token):
99 |                 # special tokens such as [CLS] and [SEP] have no mapping to the original text
100 | token_mapping.append([])
101 | else:
102 | token = self.recover_bert_token(token)
103 | start = text[offset:].index(token) + offset
104 | end = start + len(token)
105 | if is_mapping_index:
106 | token_mapping.append(char_mapping[start:end])
107 | else:
108 | token_mapping.append(raw_text[start:end])
109 | offset = end
110 |
111 | return token_mapping
112 |
113 | def sequence_to_ids(self, sequence_a, sequence_b=None, **kwargs):
114 | if sequence_b is None:
115 | return self.sentence_to_ids(sequence_a, **kwargs)
116 | else:
117 | return self.pair_to_ids(sequence_a, sequence_b, **kwargs)
118 |
119 | def sentence_to_ids(self, sequence, return_sequence_length=False):
120 | if type(sequence) == str:
121 | sequence = self.tokenize(sequence)
122 |
123 | if return_sequence_length:
124 | sequence_length = len(sequence)
125 |
126 |         # truncate overlong sequences
127 |         if len(sequence) > self.max_seq_len - 2:
128 |             sequence = sequence[0:(self.max_seq_len - 2)]
129 |         # add the special tokens at the beginning and end
130 |         sequence = ['[CLS]'] + sequence + ['[SEP]']
131 |         segment_ids = [0] * len(sequence)
132 |         # convert tokens to IDs
133 |         sequence = self.vocab.convert_tokens_to_ids(sequence)
134 |
135 |         # build the padding according to max_seq_len and the sequence length
136 |         padding = [0] * (self.max_seq_len - len(sequence))
137 |         # build the attention mask
138 |         sequence_mask = [1] * len(sequence) + padding
139 |         # build the segment ids
140 |         segment_ids = segment_ids + padding
141 |         # pad the sequence
142 |         sequence += padding
143 |
144 | sequence = np.asarray(sequence, dtype='int64')
145 | sequence_mask = np.asarray(sequence_mask, dtype='int64')
146 | segment_ids = np.asarray(segment_ids, dtype='int64')
147 |
148 | if return_sequence_length:
149 | return (sequence, sequence_mask, segment_ids, sequence_length)
150 |
151 | return (sequence, sequence_mask, segment_ids)
152 |
153 | def pair_to_ids(
154 | self,
155 | sequence_a,
156 | sequence_b,
157 | return_sequence_length=False,
158 | truncation_method='average'
159 | ):
160 | if type(sequence_a) == str:
161 | sequence_a = self.tokenize(sequence_a)
162 |
163 | if type(sequence_b) == str:
164 | sequence_b = self.tokenize(sequence_b)
165 |
166 | if return_sequence_length:
167 | sequence_length = (len(sequence_a), len(sequence_b))
168 |
169 |         # truncate overlong sequences
170 | if truncation_method == 'average':
171 | if len(sequence_a) > ((self.max_seq_len - 3)//2):
172 | sequence_a = sequence_a[0:(self.max_seq_len - 3)//2]
173 | if len(sequence_b) > ((self.max_seq_len - 3)//2):
174 | sequence_b = sequence_b[0:(self.max_seq_len - 3)//2]
175 | elif truncation_method == 'last':
176 | if len(sequence_b) > (self.max_seq_len - 3 - len(sequence_a)):
177 | sequence_b = sequence_b[0:(self.max_seq_len - 3 - len(sequence_a))]
178 | elif truncation_method == 'first':
179 | if len(sequence_a) > (self.max_seq_len - 3 - len(sequence_b)):
180 | sequence_a = sequence_a[0:(self.max_seq_len - 3 - len(sequence_b))]
181 | else:
182 | raise ValueError("The truncation method does not exist")
183 |
184 |         # add the special tokens around the two segments
185 |         sequence = ['[CLS]'] + sequence_a + ['[SEP]'] + sequence_b + ['[SEP]']
186 |         segment_ids = [0] * (len(sequence_a) + 2) + [1] * (len(sequence_b) + 1)
187 |
188 |         # convert tokens to IDs
189 |         sequence = self.vocab.convert_tokens_to_ids(sequence)
190 |
191 |         # build the padding according to max_seq_len and the sequence length
192 |         padding = [0] * (self.max_seq_len - len(sequence))
193 |         # build the attention mask
194 |         sequence_mask = [1] * len(sequence) + padding
195 |         # build the segment ids
196 |         segment_ids = segment_ids + padding
197 |         # pad the sequence
198 | sequence += padding
199 |
200 | sequence = np.asarray(sequence, dtype='int64')
201 | sequence_mask = np.asarray(sequence_mask, dtype='int64')
202 | segment_ids = np.asarray(segment_ids, dtype='int64')
203 |
204 | if return_sequence_length:
205 | return (sequence, sequence_mask, segment_ids, sequence_length)
206 |
207 | return (sequence, sequence_mask, segment_ids)
208 |
--------------------------------------------------------------------------------
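A short usage sketch for the tokenizer above, mirroring how the dataset and predictor call it. It assumes ark-nlp is installed and that the `freedomking/prompt-uie-base` vocabulary can be downloaded; any Chinese BERT vocabulary would behave the same for this illustration.

```python
# Sketch only: encode a (prompt, text) pair as [CLS] prompt [SEP] text [SEP],
# padded to max_seq_len, the way PromptUIEDataset and PromptUIEPredictor do.
from tokenizer import TransfomerTokenizer

tokenizer = TransfomerTokenizer(vocab='freedomking/prompt-uie-base', max_seq_len=32)

prompt_tokens = tokenizer.tokenize('疾病')            # the prompt (entity type)
text_tokens = tokenizer.tokenize('患者有高血压病史')    # the text
input_ids, attention_mask, segment_ids = tokenizer.sequence_to_ids(
    prompt_tokens, text_tokens, truncation_method='last')

print(input_ids.shape)             # (32,): padded to max_seq_len
print(int(attention_mask.sum()))   # number of non-padding positions
```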
/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 DataArk Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # Author: Xiang Wang, xiangking1995@163.com
16 | # Status: Active
17 |
18 | import math
19 | import random
20 | import numpy as np
21 | import pandas as pd
22 |
23 | from collections import defaultdict
24 |
25 |
26 | def convert_ner_task_uie_df(df, negative_ratio=5):
27 |
28 | negative_examples = []
29 | positive_examples = []
30 |
31 | label_type_set = set()
32 |
33 | for labels in df['label']:
34 | for label in labels:
35 | label_type_set.add(label['type'])
36 |
37 | for text, labels in zip(df['text'], df['label']):
38 | type2entities = defaultdict(list)
39 |
40 | for label in labels:
41 | type2entities[label['type']].append(label)
42 |
43 | positive_num = len(type2entities)
44 |
45 | for type_name, entities in type2entities.items():
46 | positive_examples.append({
47 | 'text': text,
48 | 'label': entities,
49 | 'condition': type_name
50 | })
51 |
52 | if negative_ratio == 0:
53 | continue
54 |
55 | redundant_label_type_list = list(label_type_set - set(type2entities.keys()))
56 | redundant_label_type_list.sort()
57 |
58 |         # negative sampling
59 | if positive_num != 0:
60 | actual_ratio = math.ceil(len(redundant_label_type_list) / positive_num)
61 | else:
62 | positive_num, actual_ratio = 1, 0
63 |
64 | if actual_ratio <= negative_ratio or negative_ratio == -1:
65 | idxs = [k for k in range(len(redundant_label_type_list))]
66 | else:
67 | idxs = random.sample(range(0, len(redundant_label_type_list)),
68 | negative_ratio * positive_num)
69 |
70 | for idx in idxs:
71 | negative_examples.append({
72 | 'text': text,
73 | 'label': [],
74 | 'condition': redundant_label_type_list[idx]
75 | })
76 |
77 | return pd.DataFrame(positive_examples + negative_examples)
78 |
79 |
80 | def get_bool_ids_greater_than(probs, limit=0.5, return_prob=False):
81 | """
82 |     Return the indices whose probability exceeds the threshold
83 |
84 |     Args:
85 |         probs (list):
86 |             a list of probability lists: [[prob]]
87 |             e.g. [[0.1, 0.1, 0.2, 0.5, 0.1, 0.3], [0.7, 0.6, 0.1, 0.1, 0.1, 0.1]]
88 |         limit (float, optional): threshold, defaults to 0.5
89 |         return_prob (bool, optional): whether to return the probability along with the index, defaults to False
90 |
91 |     Returns:
92 |         list: list of index lists, e.g. [[], [0, 1]]
93 |
94 | Reference:
95 | [1] https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/utils/tools.py
96 | """
97 | probs = np.array(probs)
98 | dim_len = len(probs.shape)
99 | if dim_len > 1:
100 | result = []
101 | for p in probs:
102 | result.append(get_bool_ids_greater_than(p, limit, return_prob))
103 | return result
104 | else:
105 | result = []
106 | for i, p in enumerate(probs):
107 | if p > limit:
108 | if return_prob:
109 | result.append((i, p))
110 | else:
111 | result.append(i)
112 | return result
113 |
114 |
115 | def get_span(start_ids, end_ids, with_prob=False):
116 | """
117 |     Build spans from the list of start indices and the list of end indices
118 |
119 |     Args:
120 |         start_ids (list): list of span start indices, e.g. [1, 2, 10]
121 |         end_ids (list): list of span end indices, e.g. [4, 12]
122 |         with_prob (bool, optional): whether the input indices come with probabilities, defaults to False
123 |
124 |     Returns:
125 |         set: set of spans, e.g. {(2, 4), (10, 12)}
126 |
127 | Reference:
128 | [1] https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/utils/tools.py
129 | """
130 | if with_prob:
131 | start_ids = sorted(start_ids, key=lambda x: x[0])
132 | end_ids = sorted(end_ids, key=lambda x: x[0])
133 | else:
134 | start_ids = sorted(start_ids)
135 | end_ids = sorted(end_ids)
136 |
137 | start_pointer = 0
138 | end_pointer = 0
139 | len_start = len(start_ids)
140 | len_end = len(end_ids)
141 | couple_dict = {}
142 | while start_pointer < len_start and end_pointer < len_end:
143 | if with_prob:
144 | if start_ids[start_pointer][0] == end_ids[end_pointer][0]:
145 | couple_dict[end_ids[end_pointer]] = start_ids[start_pointer]
146 | start_pointer += 1
147 | end_pointer += 1
148 | continue
149 | if start_ids[start_pointer][0] < end_ids[end_pointer][0]:
150 | couple_dict[end_ids[end_pointer]] = start_ids[start_pointer]
151 | start_pointer += 1
152 | continue
153 | if start_ids[start_pointer][0] > end_ids[end_pointer][0]:
154 | end_pointer += 1
155 | continue
156 | else:
157 | if start_ids[start_pointer] == end_ids[end_pointer]:
158 | couple_dict[end_ids[end_pointer]] = start_ids[start_pointer]
159 | start_pointer += 1
160 | end_pointer += 1
161 | continue
162 | if start_ids[start_pointer] < end_ids[end_pointer]:
163 | couple_dict[end_ids[end_pointer]] = start_ids[start_pointer]
164 | start_pointer += 1
165 | continue
166 | if start_ids[start_pointer] > end_ids[end_pointer]:
167 | end_pointer += 1
168 | continue
169 | result = [(couple_dict[end], end) for end in couple_dict]
170 | result = set(result)
171 |
172 | return result
173 |
--------------------------------------------------------------------------------
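The two span helpers above are small enough to check directly; the example below reuses the values from their docstrings (only `numpy` and `pandas` are needed, as for the rest of the module):

```python
# Runnable illustration of the span helpers, using the docstring examples.
from utils import get_bool_ids_greater_than, get_span

probs = [[0.1, 0.1, 0.2, 0.5, 0.1, 0.3], [0.7, 0.6, 0.1, 0.1, 0.1, 0.1]]
print(get_bool_ids_greater_than(probs))  # [[], [0, 1]]

# starts [1, 2, 10] and ends [4, 12] pair up into the spans {(2, 4), (10, 12)}
print(get_span([1, 2, 10], [4, 12]))
```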