├── .gitignore ├── LICENSE ├── README.md ├── example.ipynb ├── prompt_uie.py ├── prompt_uie_information_extraction_dataset.py ├── prompt_uie_information_extraction_predictor.py ├── prompt_uie_information_extraction_task.py ├── tokenizer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # prompt_uie_torch 2 | 3 | Medical named entity recognition based on the extractive UIE open-sourced by PaddleNLP 4 | 5 | ## Introduction 6 | 7 | [UIE (Universal Information Extraction)](https://arxiv.org/pdf/2203.12277.pdf) is a unified framework for universal information extraction proposed by Yaojie Lu et al. at ACL 2022. Following the approach of that paper, PaddleNLP open-sourced a [prompt-based extractive UIE](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/model_zoo/uie) built on the knowledge-enhanced ERNIE 3.0 pre-trained model. 8 | 9 | ![](https://user-images.githubusercontent.com/40840292/167236006-66ed845d-21b8-4647-908b-e1c6e7613eb1.png) 10 | 11 | This project reproduces the fine-tuning in PyTorch and evaluates it on the CMeEE dataset. Only named entity recognition is covered here; relation extraction, event extraction, and other tasks will be added to the ark-nlp project later. 12 | 13 | **Data download** 14 | 15 | * CMeEE: https://tianchi.aliyun.com/dataset/dataDetail?dataId=95414 16 | 17 | 18 | ## Environment 19 | 20 | ``` 21 | pip install ark-nlp 22 | pip install pandas 23 | ``` 24 | 25 | ## Usage 26 | 27 | Set up the project directory as follows: 28 | 29 | ```shell 30 | │ 31 | ├── data # data folder 32 | │ ├── source_datasets 33 | │ ├── task_datasets 34 | │ └── output_datasets 35 | │ 36 | ├── checkpoint # trained model checkpoints 37 | │ ├── ... 38 | │ └── ...
39 | │ 40 | └── example.ipynb # code 41 | ``` 42 | Download the data, extract it into `data/source_datasets`, and run `example.ipynb`. 43 | 44 | 45 | ## Weights 46 | For ease of use, the Paddle model weights have been converted to the Hugging Face format and uploaded to the Hugging Face Hub: https://huggingface.co/freedomking/prompt-uie-base 47 | 48 | 49 | ## Results 50 | 51 | After one or two training epochs, a submission to CBLUE scores roughly 65-66, already higher than most baseline models. 52 | 53 | 54 | ## Acknowledgements 55 | 56 | Thanks to PaddleNLP for their open-source work. 57 | -------------------------------------------------------------------------------- /example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import warnings\n", 10 | "warnings.filterwarnings(\"ignore\")\n", 11 | "\n", 12 | "import os\n", 13 | "import jieba\n", 14 | "import torch\n", 15 | "import pickle\n", 16 | "import torch.nn as nn\n", 17 | "import torch.optim as optim\n", 18 | "import pandas as pd\n", 19 | "\n", 20 | "from tokenizer import TransfomerTokenizer as Tokenizer\n", 21 | "from utils import convert_ner_task_uie_df\n", 22 | "from prompt_uie import PromptUIE as Module\n", 23 | "from prompt_uie_information_extraction_dataset import PromptUIEDataset as Dataset\n", 24 | "from prompt_uie_information_extraction_task import PromptUIETask as Task\n", 25 | "from ark_nlp.nn import BertConfig as ModuleConfig\n", 26 | "from ark_nlp.factory.optimizer import get_default_bert_optimizer as get_default_model_optimizer" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "train_data_path = './data/source_datasets/CMeEE/CMeEE_train.json'\n", 36 | "dev_data_path = './data/source_datasets/CMeEE/CMeEE_dev.json'" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "model_path = 'freedomking/prompt-uie-base'" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 |
"metadata": {}, 51 | "source": [ 52 | "### I. Data loading and preprocessing" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "#### 1. Loading the data" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "train_data_df = pd.read_json(train_data_path)\n", 69 | "dev_data_df = pd.read_json(dev_data_path)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "train_data_df = train_data_df.rename(columns={'entities': 'label'})\n", 79 | "dev_data_df = dev_data_df.rename(columns={'entities': 'label'})" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "# the Chinese values are the prompt labels the UIE model expects\n", 89 | "type2name = {\n", 90 | " 'dis': '疾病',\n", 91 | " 'sym': '临床表现',\n", 92 | " 'pro': '医疗程序',\n", 93 | " 'equ': '医疗设备',\n", 94 | " 'dru': '药物',\n", 95 | " 'ite': '医学检验项目',\n", 96 | " 'bod': '身体',\n", 97 | " 'dep': '科室',\n", 98 | " 'mic': '微生物类'\n", 99 | "}" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "def convert_entity_type(labels):\n", 109 | " \n", 110 | " converted_labels = []\n", 111 | " for label in labels:\n", 112 | " converted_labels.append({\n", 113 | " 'start_idx': label['start_idx'],\n", 114 | " 'end_idx': label['end_idx'],\n", 115 | " 'type': type2name[label['type']],\n", 116 | " 'entity': label['entity']\n", 117 | " })\n", 118 | " \n", 119 | " return converted_labels" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "train_data_df['label'] = train_data_df['label'].apply(lambda x: convert_entity_type(x))\n", 129 | "dev_data_df['label'] = dev_data_df['label'].apply(lambda x: convert_entity_type(x))" 130 | ] 131 | }, 132 | {
"cell_type": "code", 133 | "execution_count": null, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "train_data_df = convert_ner_task_uie_df(train_data_df, negative_ratio=2)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "dev_data_df = convert_ner_task_uie_df(dev_data_df, negative_ratio=0)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "ner_train_dataset = Dataset(train_data_df)\n", 156 | "ner_dev_dataset = Dataset(dev_data_df)" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": {}, 162 | "source": [ 163 | "#### 2. Building the vocabulary and tokenizer" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "tokenizer = Tokenizer(vocab=model_path, max_seq_len=100)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "metadata": {}, 178 | "source": [ 179 | "#### 3. Converting to ids" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "ner_train_dataset.convert_to_ids(tokenizer)\n", 189 | "ner_dev_dataset.convert_to_ids(tokenizer)" 190 | ] 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "metadata": {}, 195 | "source": [ 196 | "
\n", 197 | "\n", 198 | "### II. Model construction" 199 | ] 200 | }, 201 | { 202 | "cell_type": "markdown", 203 | "metadata": {}, 204 | "source": [ 205 | "#### 1. Model configuration" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "config = ModuleConfig.from_pretrained(model_path)" 215 | ] 216 | }, 217 | { 218 | "cell_type": "markdown", 219 | "metadata": {}, 220 | "source": [ 221 | "#### 2. Model creation" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": null, 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [ 230 | "torch.cuda.empty_cache()" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": null, 236 | "metadata": {}, 237 | "outputs": [], 238 | "source": [ 239 | "dl_module = Module.from_pretrained(model_path, config=config)" 240 | ] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "metadata": {}, 245 | "source": [ 246 | "
\n", 247 | "\n", 248 | "### III. Task construction" 249 | ] 250 | }, 251 | { 252 | "cell_type": "markdown", 253 | "metadata": {}, 254 | "source": [ 255 | "#### 1. Task parameters and required components" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": null, 261 | "metadata": {}, 262 | "outputs": [], 263 | "source": [ 264 | "# number of training epochs\n", 265 | "num_epochs = 5\n", 266 | "batch_size = 32" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": null, 272 | "metadata": {}, 273 | "outputs": [], 274 | "source": [ 275 | "optimizer = get_default_model_optimizer(dl_module)" 276 | ] 277 | }, 278 | { 279 | "cell_type": "markdown", 280 | "metadata": {}, 281 | "source": [ 282 | "#### 2. Task creation" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": null, 288 | "metadata": {}, 289 | "outputs": [], 290 | "source": [ 291 | "model = Task(dl_module, optimizer, None, cuda_device=0)" 292 | ] 293 | }, 294 | { 295 | "cell_type": "markdown", 296 | "metadata": {}, 297 | "source": [ 298 | "#### 3. Training" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": null, 304 | "metadata": {}, 305 | "outputs": [], 306 | "source": [ 307 | "model.fit(ner_train_dataset, \n", 308 | " ner_dev_dataset,\n", 309 | " lr=1e-5,\n", 310 | " epochs=num_epochs, \n", 311 | " batch_size=batch_size\n", 312 | " )" 313 | ] 314 | }, 315 | { 316 | "cell_type": "markdown", 317 | "metadata": {}, 318 | "source": [ 319 | "
\n", 320 | "\n", 321 | "### IV. Generating the submission file" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": null, 327 | "metadata": {}, 328 | "outputs": [], 329 | "source": [ 330 | "import json\n", 331 | "\n", 332 | "from tqdm import tqdm\n", 333 | "from prompt_uie_information_extraction_predictor import PromptUIEPredictor as Predictor" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": null, 339 | "metadata": {}, 340 | "outputs": [], 341 | "source": [ 342 | "ner_predictor_instance = Predictor(model.module, tokenizer)" 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": null, 348 | "metadata": {}, 349 | "outputs": [], 350 | "source": [ 351 | "test_df = pd.read_json('./data/source_datasets/CMeEE/CMeEE_test.json')\n", 352 | "\n", 353 | "submit = []\n", 354 | "for _text in tqdm(test_df['text'].to_list()):\n", 355 | " \n", 356 | " entities = []\n", 357 | " for source_type, prompt_type in type2name.items():\n", 358 | " \n", 359 | " for entity in ner_predictor_instance.predict_one_sample([_text, prompt_type]):\n", 360 | " \n", 361 | " entities.append({\n", 362 | " 'start_idx': entity['start_idx'],\n", 363 | " 'end_idx': entity['end_idx'],\n", 364 | " 'type': source_type,\n", 365 | " 'entity': entity['entity'],\n", 366 | " })\n", 367 | " \n", 368 | " submit.append({\n", 369 | " 'text': _text,\n", 370 | " 'entities': entities\n", 371 | " })" 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": null, 377 | "metadata": {}, 378 | "outputs": [], 379 | "source": [ 380 | "output_path = './submit_CMeEE_test.json'\n", 381 | "\n", 382 | "with open(output_path, 'w', encoding='utf-8') as f:\n", 383 | " f.write(json.dumps(submit, ensure_ascii=False))" 384 | ] 385 | } 386 | ], 387 | "metadata": { 388 | "kernelspec": { 389 | "display_name": "Python 3 (ipykernel)", 390 | "language": "python", 391 | "name": "python3" 392 | }, 393 | "language_info": { 394 | "codemirror_mode": { 395 | "name": "ipython", 396 |
"version": 3 397 | }, 398 | "file_extension": ".py", 399 | "mimetype": "text/x-python", 400 | "name": "python", 401 | "nbconvert_exporter": "python", 402 | "pygments_lexer": "ipython3", 403 | "version": "3.8.10" 404 | } 405 | }, 406 | "nbformat": 4, 407 | "nbformat_minor": 4 408 | } 409 | -------------------------------------------------------------------------------- /prompt_uie.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 DataArk Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # Author: Xiang Wang, xiangking1995@163.com 16 | # Status: Active 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from transformers import BertModel 22 | from transformers import BertPreTrainedModel 23 | 24 | 25 | class BertEmbeddings(torch.nn.Module): 26 | """ 27 | BERT embedding layer: the sum of word, position, and token-type embeddings 28 | """ # noqa 29 | 30 | def __init__(self, config): 31 | super().__init__() 32 | 33 | self.use_task_id = config.use_task_id 34 | 35 | # The BERT input consists of three parts: word, position, and token-type embeddings 36 | # (token-type embeddings mark which sentence a token belongs to, mainly when several sentences are concatenated into one input) 37 | self.word_embeddings = nn.Embedding(config.vocab_size, 38 | config.hidden_size, 39 | padding_idx=config.pad_token_id) 40 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, 41 | config.hidden_size) 42 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, 43 | config.hidden_size) 44 | 45 | if self.use_task_id: 46 | self.task_type_embeddings = nn.Embedding( 47 | config.task_type_vocab_size, config.hidden_size) 48 | 49 | self.LayerNorm = nn.LayerNorm(config.hidden_size, 50 | eps=config.layer_norm_eps) 51 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 52 | 53 | # BERT uses absolute position embeddings: positions are numbered from the start of the sequence 54 | self.position_embedding_type = getattr(config, 55 | "position_embedding_type", 56 | "absolute") 57 | # position_ids is created with length max_position_embeddings and truncated to the input_ids length in forward 58 | self.register_buffer( 59 | "position_ids", 60 | torch.arange(config.max_position_embeddings).expand((1, -1))) 61 | # token_type_ids is created with the size of position_ids and truncated to the input_ids length in forward 62 | self.register_buffer( 63 | "token_type_ids", 64 | torch.zeros(self.position_ids.size(), dtype=torch.long), 65 | persistent=False, 66 | ) 67 | 68 | if self.use_task_id: 69 | self.register_buffer( 70 | "task_type_ids", 71 | torch.zeros(self.position_ids.size(), dtype=torch.long), 72 | persistent=False, 73 | ) 74 | 75 | def forward(self, 76 | input_ids=None, 77 | token_type_ids=None, 78 | position_ids=None, 79 |
task_type_ids=None, 80 | **kwargs): 81 | # The transformers library also accepts precomputed embedding vectors instead of input_ids 82 | # That is not supported here: ark-nlp expects users to define their own model if they need that feature 83 | input_shape = input_ids.size() 84 | 85 | seq_length = input_shape[1] 86 | 87 | if position_ids is None: 88 | position_ids = self.position_ids[:, :seq_length] 89 | 90 | if token_type_ids is None: 91 | if hasattr(self, "token_type_ids"): 92 | buffered_token_type_ids = self.token_type_ids[:, :seq_length] 93 | buffered_token_type_ids_expanded = buffered_token_type_ids.expand( 94 | input_shape[0], seq_length) 95 | token_type_ids = buffered_token_type_ids_expanded 96 | else: 97 | token_type_ids = torch.zeros(input_shape, 98 | dtype=torch.long, 99 | device=self.position_ids.device) 100 | 101 | if task_type_ids is None: 102 | if hasattr(self, "task_type_ids"): 103 | buffered_task_type_ids = self.task_type_ids[:, :seq_length] 104 | buffered_task_type_ids_expanded = buffered_task_type_ids.expand( 105 | input_shape[0], seq_length) 106 | task_type_ids = buffered_task_type_ids_expanded 107 | else: 108 | task_type_ids = torch.zeros(input_shape, 109 | dtype=torch.long, 110 | device=self.position_ids.device) 111 | 112 | # word embeddings 113 | input_embeddings = self.word_embeddings(input_ids) 114 | # token-type embeddings 115 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 116 | # position embeddings 117 | # The check below mirrors the transformers code but has no real effect here, 118 | # since only absolute position encoding is used 119 | if self.position_embedding_type == "absolute": 120 | position_embeddings = self.position_embeddings(position_ids) 121 | 122 | # sum the three embeddings 123 | embeddings = input_embeddings + position_embeddings + token_type_embeddings 124 | 125 | # task-type embeddings 126 | if self.use_task_id: 127 | task_type_embeddings = self.task_type_embeddings(task_type_ids) 128 | embeddings = embeddings + task_type_embeddings 129 | 130 | embeddings = self.LayerNorm(embeddings) 131 | embeddings = self.dropout(embeddings) 132 | 133 | return embeddings 134 | 135 | 136 | class PromptUIE(BertPreTrainedModel): 137 | """
138 | Universal Information Extraction (UIE), implemented as an MRC-style (machine reading comprehension) model 139 | 140 | Args: 141 | config: model configuration object 142 | encoder_trained (bool, optional): whether the BERT parameters are trainable, defaults to True 143 | 144 | Reference: 145 | [1] https://github.com/PaddlePaddle/PaddleNLP/tree/develop/model_zoo/uie 146 | """ # noqa 147 | 148 | def __init__(self, 149 | config, 150 | encoder_trained=True): 151 | super(PromptUIE, self).__init__(config) 152 | self.bert = BertModel(config) 153 | 154 | self.bert.embeddings = BertEmbeddings(config) 155 | 156 | for param in self.bert.parameters(): 157 | param.requires_grad = encoder_trained 158 | 159 | self.num_labels = config.num_labels 160 | 161 | self.start_linear = torch.nn.Linear(config.hidden_size, 1) 162 | self.end_linear = torch.nn.Linear(config.hidden_size, 1) 163 | 164 | self.sigmoid = nn.Sigmoid() 165 | 166 | def forward(self, 167 | input_ids, 168 | token_type_ids=None, 169 | pos_ids=None, 170 | attention_mask=None, 171 | **kwargs): 172 | sequence_feature = self.bert(input_ids, 173 | attention_mask=attention_mask, 174 | token_type_ids=token_type_ids, 175 | return_dict=True, 176 | output_hidden_states=True).hidden_states 177 | 178 | sequence_feature = sequence_feature[-1] 179 | 180 | start_logits = self.start_linear(sequence_feature) 181 | start_logits = torch.squeeze(start_logits, -1) 182 | start_prob = self.sigmoid(start_logits) 183 | 184 | end_logits = self.end_linear(sequence_feature) 185 | 186 | end_logits = torch.squeeze(end_logits, -1) 187 | end_prob = self.sigmoid(end_logits) 188 | 189 | return start_prob, end_prob 190 | -------------------------------------------------------------------------------- /prompt_uie_information_extraction_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 DataArk Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Author: Xiang Wang, xiangking1995@163.com
# Status: Active

import torch

from ark_nlp.dataset import TokenClassificationDataset


class PromptUIEDataset(TokenClassificationDataset):
    """
    Dataset for the UIE (Universal Information Extraction) task

    Args:
        data (:obj:`DataFrame` or :obj:`string`): the data, or a path to it
        categories (:obj:`list`, optional, defaults to `None`): the label categories
        is_retain_df (:obj:`bool`, optional, defaults to False): whether to keep a copy of the raw DataFrame in the retain_df attribute
        is_retain_dataset (:obj:`bool`, optional, defaults to False): whether to keep a copy of the processed dataset in the retain_dataset attribute
        is_train (:obj:`bool`, optional, defaults to True): whether this is the training split
        is_test (:obj:`bool`, optional, defaults to False): whether this is the test split
    """  # noqa

    def _convert_to_transfomer_ids(self, bert_tokenizer):

        features = []
        for (index_, row_) in enumerate(self.dataset):

            prompt_tokens = bert_tokenizer.tokenize(row_['condition'])
            tokens = bert_tokenizer.tokenize(row_['text'])[:bert_tokenizer.max_seq_len - 3 - len(prompt_tokens)]
            token_mapping = bert_tokenizer.get_token_mapping(row_['text'], tokens)

            start_mapping = {j[0]: i for i, j in enumerate(token_mapping) if j}
            end_mapping = {j[-1]: i for i, j in enumerate(token_mapping) if j}

            input_ids = bert_tokenizer.sequence_to_ids(prompt_tokens, tokens, truncation_method='last')

            input_ids, input_mask, segment_ids = input_ids

            start_label = torch.zeros(bert_tokenizer.max_seq_len)
            end_label = torch.zeros(bert_tokenizer.max_seq_len)

            label_ = set()
            for info_ in row_['label']:
                if info_['start_idx'] in start_mapping and info_['end_idx'] in end_mapping:
                    start_idx = start_mapping[info_['start_idx']]
                    end_idx = end_mapping[info_['end_idx']]
                    if start_idx > end_idx or info_['entity'] == '':
                        continue

                    # Offset by the prompt tokens plus the [CLS] and [SEP] in front of the text
                    start_label[start_idx + 2 + len(prompt_tokens)] = 1
                    end_label[end_idx + 2 + len(prompt_tokens)] = 1

                    label_.add((start_idx + 2 + len(prompt_tokens),
                                end_idx + 2 + len(prompt_tokens)))

            features.append({
                'input_ids': input_ids,
                'attention_mask': input_mask,
                'token_type_ids': segment_ids,
                'start_label_ids': start_label,
                'end_label_ids': end_label,
                'label_ids': list(label_)
            })

        return features

    @property
    def to_device_cols(self):
        _cols = list(self.dataset[0].keys())
        _cols.remove('label_ids')
        return _cols
--------------------------------------------------------------------------------
/prompt_uie_information_extraction_predictor.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021 DataArk Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Author: Xiang Wang, xiangking1995@163.com
# Status: Active

import torch

from utils import get_span
from utils import get_bool_ids_greater_than


class PromptUIEPredictor(object):
    """
    Predictor for UIE (Universal Information Extraction)

    Args:
        module: the deep learning model
        tokenizer: the tokenizer
    """  # noqa

    def __init__(self, module, tokenizer):
        self.module = module
        self.module.task = 'TokenLevel'

        self.tokenizer = tokenizer
        self.device = list(self.module.parameters())[0].device

    def _convert_to_transfomer_ids(self, text, prompt):
        tokens = self.tokenizer.tokenize(text)
        token_mapping = self.tokenizer.get_token_mapping(text, tokens)

        prompt_tokens = self.tokenizer.tokenize(prompt)

        input_ids = self.tokenizer.sequence_to_ids(prompt_tokens, tokens, truncation_method='last')
        input_ids, input_mask, segment_ids = input_ids

        features = {
            'input_ids': input_ids,
            'attention_mask': input_mask,
            'token_type_ids': segment_ids,
        }

        return features, token_mapping

    def _get_input_ids(self, text, prompt):
        if self.tokenizer.tokenizer_type == 'transfomer':
            return self._convert_to_transfomer_ids(text, prompt)
        else:
            raise ValueError("The tokenizer type does not exist")

    def _get_module_one_sample_inputs(self, features):
        return {
            col: torch.Tensor(features[col]).type(torch.long).unsqueeze(0).to(self.device)
            for col in features
        }

    def predict_one_sample(
        self,
        text,
    ):
        """
        Predict a single sample

        Args:
            text (:obj:`tuple`): a (text, prompt) pair
        """  # noqa

        text, prompt = text
        features, token_mapping = self._get_input_ids(text, prompt)

        self.module.eval()

        with torch.no_grad():
            inputs = self._get_module_one_sample_inputs(features)
            start_logits, end_logits = self.module(**inputs)

        # Drop the [CLS], prompt and [SEP] positions so indices align with the text
        start_scores = start_logits[0].cpu().numpy()[2 + len(self.tokenizer.tokenize(prompt)):]
        end_scores = end_logits[0].cpu().numpy()[2 + len(self.tokenizer.tokenize(prompt)):]

        start_scores = get_bool_ids_greater_than(start_scores)
        end_scores = get_bool_ids_greater_than(end_scores)

        entities = []
        for span in get_span(start_scores, end_scores):

            if span[0] >= len(token_mapping) or span[-1] >= len(token_mapping):
                continue

            entity_ = {
                "start_idx": token_mapping[span[0]][0],
                "end_idx": token_mapping[span[-1]][-1],
                "type": prompt,
                "entity": text[token_mapping[span[0]][0]:token_mapping[span[-1]][-1] + 1]
            }
            entities.append(entity_)

        return entities
--------------------------------------------------------------------------------
/prompt_uie_information_extraction_task.py:
--------------------------------------------------------------------------------
# Copyright (c) 2020 DataArk Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Author: Xiang Wang, xiangking1995@163.com
# Status: Active

import torch
import torch.nn.functional as F

from collections import Counter
from utils import get_span
from utils import get_bool_ids_greater_than
from ark_nlp.factory.task.base._token_classification import TokenClassificationTask


class SpanMetrics(object):

    def __init__(self, id2label=None):
        self.id2label = id2label
        self.reset()

    def reset(self):
        self.origins = []
        self.founds = []
        self.rights = []

    def compute(self, origin, found, right):
        recall = 0 if origin == 0 else (right / origin)
        precision = 0 if found == 0 else (right / found)
        f1 = 0. if recall + precision == 0 else (2 * precision * recall) / (precision + recall)
        return recall, precision, f1

    def result(self):
        if self.id2label is not None:
            class_info = {}
            origin_counter = Counter([self.id2label[x[0]] for x in self.origins])
            found_counter = Counter([self.id2label[x[0]] for x in self.founds])
            right_counter = Counter([self.id2label[x[0]] for x in self.rights])
            for type_, count in origin_counter.items():
                origin = count
                found = found_counter.get(type_, 0)
                right = right_counter.get(type_, 0)
                recall, precision, f1 = self.compute(origin, found, right)
                class_info[type_] = {"acc": round(precision, 4), 'recall': round(recall, 4), 'f1': round(f1, 4)}

        origin = len(self.origins)
        found = len(self.founds)
        right = len(self.rights)
        recall, precision, f1 = self.compute(origin, found, right)

        if self.id2label is None:
            return {'acc': precision, 'recall': recall, 'f1': f1}
        else:
            return {'acc': precision, 'recall': recall, 'f1': f1}, class_info

    def update(self, true_subject, pred_subject):
        self.origins.extend(true_subject)
        self.founds.extend(pred_subject)
        self.rights.extend([pre_entity for pre_entity in pred_subject if pre_entity in true_subject])


class PromptUIETask(TokenClassificationTask):
    """
    Task for UIE (Universal Information Extraction)

    Args:
        module: the deep learning model
        optimizer: the name of the optimizer used for training, or an optimizer object
        loss_function: the name of the loss function used for training, or a loss function object
        class_num (:obj:`int` or :obj:`None`, optional, defaults to None): the number of labels
        scheduler (:obj:`class`, optional, defaults to None): a scheduler object
        n_gpu (:obj:`int`, optional, defaults to 1): the number of GPUs
        device (:obj:`class`, optional, defaults to None): a torch.device object; when None, GPU availability is detected automatically
        cuda_device (:obj:`int`, optional, defaults to 0): the GPU index; when device is None, device is set from cuda_device
        ema_decay (:obj:`int` or :obj:`None`, optional, defaults to None): the EMA decay coefficient
        **kwargs (optional): other optional arguments
    """  # noqa

    def _get_train_loss(self, inputs, outputs, **kwargs):
        loss = self._compute_loss(inputs, outputs, **kwargs)

        self._compute_loss_record(**kwargs)

        return outputs, loss

    def _get_evaluate_loss(self, inputs, outputs, **kwargs):
        loss = self._compute_loss(inputs, outputs, **kwargs)
        self._compute_loss_record(**kwargs)

        return outputs, loss

    def _compute_loss(self, inputs, logits, verbose=True, **kwargs):
        start_logits = logits[0]
        end_logits = logits[1]

        start_logits = start_logits.view(-1, 1)
        end_logits = end_logits.view(-1, 1)

        active_loss = inputs['attention_mask'].view(-1) == 1

        active_start_logits = start_logits[active_loss]
        active_end_logits = end_logits[active_loss]

        active_start_labels = inputs['start_label_ids'].long().view(-1, 1)[active_loss]
        active_end_labels = inputs['end_label_ids'].long().view(-1, 1)[active_loss]

        start_loss = F.binary_cross_entropy(active_start_logits,
                                            active_start_labels.to(torch.float),
                                            reduction='none')
        # The logits were already restricted to active positions, so the masked
        # mean is their sum divided by the number of active tokens
        start_loss = torch.sum(start_loss) / torch.sum(active_loss)
        end_loss = F.binary_cross_entropy(active_end_logits,
                                          active_end_labels.to(torch.float),
                                          reduction='none')
        end_loss = torch.sum(end_loss) / torch.sum(active_loss)

        loss = (start_loss + end_loss) / 2.0

        return loss

    def _on_evaluate_epoch_begin(self, **kwargs):

        self.metric = SpanMetrics()

        if self.ema_decay:
            self.ema.store(self.module.parameters())
            self.ema.copy_to(self.module.parameters())

        self._on_epoch_begin_record(**kwargs)

    def _on_evaluate_step_end(self, inputs, logits, **kwargs):

        with torch.no_grad():
            # compute loss
            logits, loss = self._get_evaluate_loss(inputs, logits, **kwargs)

            start_logits = logits[0]
            end_logits = logits[1]

            start_pred = start_logits.cpu().numpy().tolist()
            end_pred = end_logits.cpu().numpy().tolist()

            start_score_list = get_bool_ids_greater_than(start_pred)
            end_score_list = get_bool_ids_greater_than(end_pred)

            for index, (start_score, end_score) in enumerate(zip(start_score_list, end_score_list)):
                S = get_span(start_score, end_score)
                self.metric.update(true_subject=inputs['label_ids'][index], pred_subject=S)

            self.evaluate_logs['eval_example'] += len(inputs['label_ids'])
            self.evaluate_logs['eval_step'] += 1
            self.evaluate_logs['eval_loss'] += loss.item()

    def _on_evaluate_epoch_end(self,
                               validation_data,
                               epoch=1,
                               is_evaluate_print=True,
                               **kwargs):

        with torch.no_grad():
            eval_info = self.metric.result()

        if is_evaluate_print:
            print('eval_info: ', eval_info)

    def _train_collate_fn(self, batch):
        """Convert a list of InputFeatures into tensors"""

        input_ids = torch.tensor([f['input_ids'] for f in batch], dtype=torch.long)
        attention_mask = torch.tensor([f['attention_mask'] for f in batch],
                                      dtype=torch.long)
        token_type_ids = torch.tensor([f['token_type_ids'] for f in batch],
                                      dtype=torch.long)
        # stack keeps the (batch, seq_len) shape in line with the other tensors
        start_label_ids = torch.stack([f['start_label_ids'] for f in batch])
        end_label_ids = torch.stack([f['end_label_ids'] for f in batch])
        label_ids = [f['label_ids'] for f in batch]

        tensors = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_type_ids,
            'start_label_ids': start_label_ids,
            'end_label_ids': end_label_ids,
            'label_ids': label_ids
        }

        return tensors

    def _evaluate_collate_fn(self, batch):
        return self._train_collate_fn(batch)
--------------------------------------------------------------------------------
/tokenizer.py:
--------------------------------------------------------------------------------
# Copyright (c) 2020 DataArk Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Author: Xiang Wang, xiangking1995@163.com
# Status: Active


import unicodedata
import transformers
import numpy as np

from copy import deepcopy
from ark_nlp.processor.tokenizer._tokenizer import BaseTokenizer


class TransfomerTokenizer(BaseTokenizer):
    """
    Transformer text encoder that handles tokenization, id conversion, padding, etc.

    Args:
        vocab: a transformers tokenizer object, or a vocabulary path or name, used for tokenization and id conversion
        max_seq_len (:obj:`int`): the preset maximum text length
    """  # noqa

    def __init__(
        self,
        vocab,
        max_seq_len
    ):
        if isinstance(vocab, str):
            # TODO: let a user-defined vocabulary decide this
            vocab = transformers.AutoTokenizer.from_pretrained(vocab)

        self.vocab = vocab
        self.max_seq_len = max_seq_len
        self.additional_special_tokens = set()
        self.tokenizer_type = 'transfomer'

    @staticmethod
    def _is_control(ch):
        """Check whether ch is a control character"""
        return unicodedata.category(ch) in ('Cc', 'Cf')

    @staticmethod
    def _is_special(ch):
        """Check whether ch is a special symbol such as [CLS] or [SEP]"""
        return bool(ch) and (ch[0] == '[') and (ch[-1] == ']')

    @staticmethod
    def recover_bert_token(token):
        """Return the token's surface form, stripping a leading '##' if present"""
        if token[:2] == '##':
            return token[2:]
        else:
            return token

    def get_token_mapping(self, text, tokens, is_mapping_index=True):
        """Build the mapping between the raw text and the tokenized tokens"""
        raw_text = deepcopy(text)
        text = text.lower()

        normalized_text, char_mapping = '', []
        for i, ch in enumerate(text):
            ch = unicodedata.normalize('NFD', ch)
            ch = ''.join([c for c in ch if unicodedata.category(c) != 'Mn'])
            ch = ''.join([
                c for c in ch
                if not (ord(c) == 0 or ord(c) == 0xfffd or self._is_control(c))
            ])
            normalized_text += ch
            char_mapping.extend([i] * len(ch))

        text, token_mapping, offset = normalized_text, [], 0
        for token in tokens:
            token = token.lower()
            if token == '[unk]' or token in self.additional_special_tokens:
                if is_mapping_index:
                    token_mapping.append(char_mapping[offset:offset+1])
                else:
                    token_mapping.append(raw_text[offset:offset+1])
                offset = offset + 1
            elif self._is_special(token):
                # Special tokens such as [CLS] and [SEP] have no mapping back to the text
                token_mapping.append([])
            else:
                token = self.recover_bert_token(token)
                start = text[offset:].index(token) + offset
                end = start + len(token)
                if is_mapping_index:
                    token_mapping.append(char_mapping[start:end])
                else:
                    token_mapping.append(raw_text[start:end])
                offset = end

        return token_mapping

    def sequence_to_ids(self, sequence_a, sequence_b=None, **kwargs):
        if sequence_b is None:
            return self.sentence_to_ids(sequence_a, **kwargs)
        else:
            return self.pair_to_ids(sequence_a, sequence_b, **kwargs)

    def sentence_to_ids(self, sequence, return_sequence_length=False):
        if type(sequence) == str:
            sequence = self.tokenize(sequence)

        if return_sequence_length:
            sequence_length = len(sequence)

        # Truncate over-long sequences
        if len(sequence) > self.max_seq_len - 2:
            sequence = sequence[0:(self.max_seq_len - 2)]
        # Add the special tokens at the start and end
        sequence = ['[CLS]'] + sequence + ['[SEP]']
        segment_ids = [0] * len(sequence)
        # Convert tokens to ids
        sequence = self.vocab.convert_tokens_to_ids(sequence)

        # Build the padding from max_seq_len and the sequence length
        padding = [0] * (self.max_seq_len - len(sequence))
        # Build the attention mask
        sequence_mask = [1] * len(sequence) + padding
        # Build the segment ids
        segment_ids = segment_ids + padding
        # Append the padding to the sequence
        sequence += padding

        sequence = np.asarray(sequence, dtype='int64')
        sequence_mask = np.asarray(sequence_mask, dtype='int64')
        segment_ids = np.asarray(segment_ids, dtype='int64')

        if return_sequence_length:
            return (sequence, sequence_mask, segment_ids, sequence_length)

        return (sequence, sequence_mask, segment_ids)

    def pair_to_ids(
        self,
        sequence_a,
        sequence_b,
        return_sequence_length=False,
        truncation_method='average'
    ):
        if type(sequence_a) == str:
            sequence_a = self.tokenize(sequence_a)

        if type(sequence_b) == str:
            sequence_b = self.tokenize(sequence_b)

        if return_sequence_length:
            sequence_length = (len(sequence_a), len(sequence_b))

        # Truncate over-long sequences
        if truncation_method == 'average':
            if len(sequence_a) > ((self.max_seq_len - 3)//2):
                sequence_a = sequence_a[0:(self.max_seq_len - 3)//2]
            if len(sequence_b) > ((self.max_seq_len - 3)//2):
                sequence_b = sequence_b[0:(self.max_seq_len - 3)//2]
        elif truncation_method == 'last':
            if len(sequence_b) > (self.max_seq_len - 3 - len(sequence_a)):
                sequence_b = sequence_b[0:(self.max_seq_len - 3 - len(sequence_a))]
        elif truncation_method == 'first':
            if len(sequence_a) > (self.max_seq_len - 3 - len(sequence_b)):
                sequence_a = sequence_a[0:(self.max_seq_len - 3 - len(sequence_b))]
        else:
            raise ValueError("The truncation method does not exist")

        # Add the special tokens at the start, middle and end
        sequence = ['[CLS]'] + sequence_a + ['[SEP]'] + sequence_b + ['[SEP]']
        segment_ids = [0] * (len(sequence_a) + 2) + [1] * (len(sequence_b) + 1)

        # Convert tokens to ids
        sequence = self.vocab.convert_tokens_to_ids(sequence)

        # Build the padding from max_seq_len and the sequence length
        padding = [0] * (self.max_seq_len - len(sequence))
        # Build the attention mask
        sequence_mask = [1] * len(sequence) + padding
        # Build the segment ids
        segment_ids = segment_ids + padding
        # Append the padding to the sequence
        sequence += padding

        sequence = np.asarray(sequence, dtype='int64')
        sequence_mask = np.asarray(sequence_mask, dtype='int64')
        segment_ids = np.asarray(segment_ids, dtype='int64')
        if return_sequence_length:
            return (sequence, sequence_mask, segment_ids, sequence_length)

        return (sequence, sequence_mask, segment_ids)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
# Copyright (c) 2022 DataArk Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Author: Xiang Wang, xiangking1995@163.com
# Status: Active

import math
import random
import numpy as np
import pandas as pd

from collections import defaultdict


def convert_ner_task_uie_df(df, negative_ratio=5):

    negative_examples = []
    positive_examples = []

    label_type_set = set()

    for labels in df['label']:
        for label in labels:
            label_type_set.add(label['type'])

    for text, labels in zip(df['text'], df['label']):
        type2entities = defaultdict(list)

        for label in labels:
            type2entities[label['type']].append(label)

        positive_num = len(type2entities)

        for type_name, entities in type2entities.items():
            positive_examples.append({
                'text': text,
                'label': entities,
                'condition': type_name
            })

        if negative_ratio == 0:
            continue

        redundant_label_type_list = list(label_type_set - set(type2entities.keys()))
        redundant_label_type_list.sort()

        # Negative sampling
        if positive_num != 0:
            actual_ratio = math.ceil(len(redundant_label_type_list) / positive_num)
        else:
            positive_num, actual_ratio = 1, 0

        if actual_ratio <= negative_ratio or negative_ratio == -1:
            idxs = [k for k in range(len(redundant_label_type_list))]
        else:
            idxs = random.sample(range(0, len(redundant_label_type_list)),
                                 negative_ratio * positive_num)

        for idx in idxs:
            negative_examples.append({
                'text': text,
                'label': [],
                'condition': redundant_label_type_list[idx]
            })

    return pd.DataFrame(positive_examples + negative_examples)


def get_bool_ids_greater_than(probs, limit=0.5, return_prob=False):
    """
    Return the span index lists derived from a list of probabilities

    Args:
        probs (list):
            a list of probability lists, one per sequence,
            e.g. [[0.1, 0.1, 0.2, 0.5, 0.1, 0.3], [0.7, 0.6, 0.1, 0.1, 0.1, 0.1]]
        limit (float, optional): the threshold, defaults to 0.5
        return_prob (bool, optional): whether to return each index together with its probability, defaults to False

    Returns:
        list: the lists of indices above the threshold, e.g. [[], [0, 1]]

    Reference:
        [1] https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/utils/tools.py
    """
    probs = np.array(probs)
    dim_len = len(probs.shape)
    if dim_len > 1:
        result = []
        for p in probs:
            result.append(get_bool_ids_greater_than(p, limit, return_prob))
        return result
    else:
        result = []
        for i, p in enumerate(probs):
            if p > limit:
                if return_prob:
                    result.append((i, p))
                else:
                    result.append(i)
        return result


def get_span(start_ids, end_ids, with_prob=False):
    """
    Pair span start indices with span end indices to build spans

    Args:
        start_ids (list): the span start indices, e.g. [1, 2, 10]
        end_ids (list): the span end indices, e.g. [4, 12]
        with_prob (bool, optional): whether the input indices carry probabilities, defaults to False

    Returns:
        set: the set of spans, e.g. {(2, 4), (10, 12)}

    Reference:
        [1] https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/utils/tools.py
    """
    if with_prob:
        start_ids = sorted(start_ids, key=lambda x: x[0])
        end_ids = sorted(end_ids, key=lambda x: x[0])
    else:
        start_ids = sorted(start_ids)
        end_ids = sorted(end_ids)

    start_pointer = 0
    end_pointer = 0
    len_start = len(start_ids)
    len_end = len(end_ids)
    couple_dict = {}
    while start_pointer < len_start and end_pointer < len_end:
        if with_prob:
            if start_ids[start_pointer][0] == end_ids[end_pointer][0]:
                couple_dict[end_ids[end_pointer]] = start_ids[start_pointer]
                start_pointer += 1
                end_pointer += 1
                continue
            if start_ids[start_pointer][0] < end_ids[end_pointer][0]:
                couple_dict[end_ids[end_pointer]] = start_ids[start_pointer]
                start_pointer += 1
                continue
            if start_ids[start_pointer][0] > end_ids[end_pointer][0]:
                end_pointer += 1
                continue
        else:
            if start_ids[start_pointer] == end_ids[end_pointer]:
                couple_dict[end_ids[end_pointer]] = start_ids[start_pointer]
                start_pointer += 1
                end_pointer += 1
                continue
            if start_ids[start_pointer] < end_ids[end_pointer]:
                couple_dict[end_ids[end_pointer]] = start_ids[start_pointer]
                start_pointer += 1
                continue
            if start_ids[start_pointer] > end_ids[end_pointer]:
                end_pointer += 1
                continue
    result = [(couple_dict[end], end) for end in couple_dict]
    result = set(result)

    return result
--------------------------------------------------------------------------------
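To make the span decoding in `utils.py` concrete, here is a small self-contained sketch, independent of this repository: `threshold_ids` and `pair_spans` are illustrative stand-ins for `get_bool_ids_greater_than` and `get_span` (no-probability case only), reproducing the same two-pointer pairing sweep.

```python
# Standalone sketch: threshold the per-token start/end probabilities, then
# pair starts with ends using the same two-pointer sweep as get_span.

def threshold_ids(probs, limit=0.5):
    # Indices whose probability strictly exceeds the threshold.
    return [i for i, p in enumerate(probs) if p > limit]

def pair_spans(start_ids, end_ids):
    start_ids, end_ids = sorted(start_ids), sorted(end_ids)
    couple, i, j = {}, 0, 0
    while i < len(start_ids) and j < len(end_ids):
        if start_ids[i] == end_ids[j]:
            couple[end_ids[j]] = start_ids[i]
            i += 1
            j += 1
        elif start_ids[i] < end_ids[j]:
            # A later start before the same end overwrites the earlier one.
            couple[end_ids[j]] = start_ids[i]
            i += 1
        else:
            j += 1
    return {(start, end) for end, start in couple.items()}

start_probs = [0.1, 0.9, 0.8, 0.1, 0.1, 0.7]
end_probs = [0.1, 0.1, 0.1, 0.95, 0.1, 0.9]
spans = pair_spans(threshold_ids(start_probs), threshold_ids(end_probs))
# spans == {(2, 3), (5, 5)}
```

Note that 0.5 itself does not pass the strict `>` comparison, which is why the docstring example `[[0.1, 0.1, 0.2, 0.5, 0.1, 0.3], ...]` yields an empty list for its first sequence.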
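The span-level bookkeeping done by `SpanMetrics` in the task file reduces to exact-match counting; a worked toy example (the numbers are invented, not from the repository): with two gold spans and two predictions of which one matches exactly, precision, recall, and F1 all come out to 0.5.

```python
# Toy illustration of SpanMetrics-style counting: a prediction is "right"
# only if it matches a gold (start, end) span exactly.
true_spans = [(2, 4), (10, 12)]
pred_spans = [(2, 4), (7, 8)]

right = [span for span in pred_spans if span in true_spans]

precision = len(right) / len(pred_spans)            # 1 / 2 = 0.5
recall = len(right) / len(true_spans)               # 1 / 2 = 0.5
f1 = 2 * precision * recall / (precision + recall)  # 0.5
```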
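A minimal sketch of how `convert_ner_task_uie_df` turns one labelled sentence into prompt-conditioned examples. The sentence and schema below are invented for illustration, and plain dicts stand in for the DataFrame rows; the sampling logic mirrors the function above.

```python
import math
import random
from collections import defaultdict

# Toy input: one sentence with a single "person" entity; "org" and "loc" are
# the remaining schema types (all names here are illustrative).
label_type_set = {"person", "org", "loc"}
text = "Alice joined the team"
labels = [{"type": "person", "start_idx": 0, "end_idx": 4, "entity": "Alice"}]

type2entities = defaultdict(list)
for label in labels:
    type2entities[label["type"]].append(label)

# One positive example per entity type that occurs in the sentence...
positives = [{"text": text, "label": ents, "condition": t}
             for t, ents in type2entities.items()]

# ...and negatives drawn from schema types absent from the sentence.
negative_ratio = 5
redundant = sorted(label_type_set - set(type2entities))
actual_ratio = math.ceil(len(redundant) / len(type2entities))
if actual_ratio <= negative_ratio:
    idxs = list(range(len(redundant)))
else:
    idxs = random.sample(range(len(redundant)),
                         negative_ratio * len(type2entities))
negatives = [{"text": text, "label": [], "condition": redundant[i]} for i in idxs]
# -> 1 positive ("person") and 2 negatives ("loc", "org") with empty labels
```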