├── .gitignore ├── LICENSE ├── README.md ├── convert.py ├── doccano.md ├── doccano.py ├── ernie.py ├── ernie_m.py ├── evaluate.py ├── export_model.py ├── finetune.py ├── labelstudio2doccano.py ├── model.py ├── requirements.txt ├── tokenizer.py ├── uie_predictor.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | build* 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | *.doctree 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # celery beat schedule file 87 | celerybeat-schedule 88 | 89 | # SageMath parsed files 90 | *.sage.py 91 | 92 | # Environments 93 | .env 94 | .venv 95 | env/ 96 | venv/ 97 | ENV/ 98 | env.bak/ 99 | venv.bak/ 100 | 101 | # Spyder project settings 102 | .spyderproject 103 | .spyproject 104 | 105 | # Rope project settings 106 | .ropeproject 107 | 108 | # mkdocs documentation 109 | /site 110 | 111 | # mypy 112 | .mypy_cache/ 113 | .dmypy.json 114 | dmypy.json 115 | 116 | # Pyre type checker 117 | .pyre/ 118 | 119 | # pycharm 120 | .DS_Store 121 | .idea/ 122 | FETCH_HEAD 123 | 124 | # vscode 125 | .vscode 126 | 127 | checkpoint 128 | data/* 129 | export 130 | model_best 131 | uie-base 132 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. 10 | 11 | "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 12 | 13 | "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 14 | 15 | "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 16 | 17 | "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 18 | 19 | "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 20 | 21 | "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 22 | 23 | "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 24 | 25 | "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 26 | 27 | "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 28 | 29 | 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 30 | 31 | 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 32 | 33 | 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: 34 | 35 | You must give any other recipients of the Work or Derivative Works a copy of this License; and 36 | You must cause any modified files to carry prominent notices stating that You changed the files; and 37 | You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and 38 | If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. 39 | 40 | You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 41 | 42 | 6. Trademarks. 
This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 43 | 44 | 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 45 | 46 | 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 47 | 48 | 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 49 | 50 | END OF TERMS AND CONDITIONS 51 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 通用信息抽取 UIE(Universal Information Extraction) PyTorch版 2 | 3 | **迁移[PaddleNLP](https://github.com/PaddlePaddle/PaddleNLP)中的UIE模型到PyTorch上** 4 | 5 | * 2022-10-3: 新增对UIE-M系列模型的支持,增加了ErnieM的Tokenizer。ErnieMTokenizer使用C++实现的高性能分词算子FasterTokenizer进行文本预处理加速。需要通过`pip install faster_tokenizer`安装FasterTokenizer库后方可使用。 6 | 7 | PyTorch版功能介绍 8 | - `convert.py`: 自动下载并转换模型,详见[开箱即用](#开箱即用)。 9 | - `doccano.py`: 转换标注数据,详见[数据标注](#数据标注)。 10 | - `evaluate.py`: 评估模型,详见[模型评估](#模型评估)。 11 | - `export_model.py`: 导出ONNX推理模型,详见[模型部署](#模型部署)。 12 | - `finetune.py`: 微调训练,详见[模型微调](#模型微调)。 13 | - `model.py`: 模型定义。 14 | - `uie_predictor.py`: 推理类。 15 | 16 | 17 | **目录** 18 | 19 | - [1. 模型简介](#模型简介) 20 | - [2. 应用示例](#应用示例) 21 | - [3. 开箱即用](#开箱即用) 22 | - [3.1 实体抽取](#实体抽取) 23 | - [3.2 关系抽取](#关系抽取) 24 | - [3.3 事件抽取](#事件抽取) 25 | - [3.4 评论观点抽取](#评论观点抽取) 26 | - [3.5 情感分类](#情感分类) 27 | - [3.6 跨任务抽取](#跨任务抽取) 28 | - [3.7 模型选择](#模型选择) 29 | - [3.8 更多配置](#更多配置) 30 | - [4. 
训练定制](#训练定制) 31 | - [4.1 代码结构](#代码结构) 32 | - [4.2 数据标注](#数据标注) 33 | - [4.3 模型微调](#模型微调) 34 | - [4.4 模型评估](#模型评估) 35 | - [4.5 定制模型一键预测](#定制模型一键预测) 36 | - [4.6 实验指标](#实验指标) 37 | - [4.7 模型部署](#模型部署) 38 | 39 | 40 | 41 | ## 1. 模型简介 42 | 43 | [UIE(Universal Information Extraction)](https://arxiv.org/pdf/2203.12277.pdf):Yaojie Lu等人在ACL-2022中提出了通用信息抽取统一框架UIE。该框架实现了实体抽取、关系抽取、事件抽取、情感分析等任务的统一建模,并使得不同任务间具备良好的迁移和泛化能力。为了方便大家使用UIE的强大能力,PaddleNLP借鉴该论文的方法,基于ERNIE 3.0知识增强预训练模型,训练并开源了首个中文通用信息抽取模型UIE。该模型可以支持不限定行业领域和抽取目标的关键信息抽取,实现零样本快速冷启动,并具备优秀的小样本微调能力,快速适配特定的抽取目标。 44 | 45 |
46 | 47 |
48 | 49 | #### UIE的优势 50 | 51 | - **使用简单**:用户可以使用自然语言自定义抽取目标,无需训练即可统一抽取输入文本中的对应信息。**实现开箱即用,并满足各类信息抽取需求**。 52 | 53 | - **降本增效**:以往的信息抽取技术需要大量标注数据才能保证信息抽取的效果,为了提高开发过程中的开发效率,减少不必要的重复工作时间,开放域信息抽取可以实现零样本(zero-shot)或者少样本(few-shot)抽取,**大幅度降低标注数据依赖,在降低成本的同时,还提升了效果**。 54 | 55 | - **效果领先**:开放域信息抽取在多种场景,多种任务上,均有不俗的表现。 56 | 57 | 58 | 59 | ## 2. 应用示例 60 | 61 | UIE不限定行业领域和抽取目标,以下是一些零样本行业示例: 62 | 63 | - 医疗场景-专病结构化 64 | 65 | ![image](https://user-images.githubusercontent.com/40840292/169017581-93c8ee44-856d-4d17-970c-b6138d10f8bc.png) 66 | 67 | - 法律场景-判决书抽取 68 | 69 | ![image](https://user-images.githubusercontent.com/40840292/169017863-442c50f1-bfd4-47d0-8d95-8b1d53cfba3c.png) 70 | 71 | - 金融场景-收入证明、招股书抽取 72 | 73 | ![image](https://user-images.githubusercontent.com/40840292/169017982-e521ddf6-d233-41f3-974e-6f40f8f2edbc.png) 74 | 75 | - 公安场景-事故报告抽取 76 | 77 | ![image](https://user-images.githubusercontent.com/40840292/169018340-31efc1bf-f54d-43f7-b62a-8f7ce9bf0536.png) 78 | 79 | - 旅游场景-宣传册、手册抽取 80 | 81 | ![image](https://user-images.githubusercontent.com/40840292/169018113-c937eb0b-9fd7-4ecc-8615-bcdde2dac81d.png) 82 | 83 | 84 | 85 | ## 3. 开箱即用 86 | 87 | ```uie_predictor```提供通用信息抽取、评价观点抽取等能力,可抽取多种类型的信息,包括但不限于命名实体识别(如人名、地名、机构名等)、关系(如电影的导演、歌曲的发行时间等)、事件(如某路口发生车祸、某地发生地震等)、以及评价维度、观点词、情感倾向等信息。用户可以使用自然语言自定义抽取目标,无需训练即可统一抽取输入文本中的对应信息。**实现开箱即用,并满足各类信息抽取需求** 88 | 89 | ```uie_predictor```现在可以自动下载模型了,**无需手动convert**,如果想手动转换模型,可以参照以下方法。 90 | 91 | **下载并转换模型**,将下载Paddle版的`uie-base`模型到当前目录中,并生成PyTorch版模型`uie_base_pytorch`。 92 | 93 | ```shell 94 | python convert.py 95 | ``` 96 | 97 | 如果没有安装paddlenlp,则使用以下命令。这将不会导入paddlenlp,以及不会验证转换结果正确性。 98 | 99 | ```shell 100 | python convert.py --no_validate_output 101 | ``` 102 | 103 | 可配置参数说明: 104 | 105 | - `input_model`: 输入的模型所在的文件夹,例如存在模型`./model_path/model_state.pdparams`,则传入`./model_path`。如果传入`uie-base`或`uie-tiny`等在模型列表中的模型,且当前目录不存在此文件夹时,将自动下载模型。默认值为`uie-base`。 106 | 107 | 支持自动下载的模型 108 | - `uie-base` 109 | - `uie-medium` 110 | - `uie-mini` 111 | - `uie-micro` 112 | - `uie-nano` 113 | - `uie-medical-base` 114 | - `uie-tiny` (弃用,已改为`uie-medium`) 115 | - `uie-base-en` 116 | - `uie-m-base` 117 | - `uie-m-large` 118 | - `ernie-3.0-base-zh`* 119 | 120 | - `output_model`: 输出的模型的文件夹,默认为`uie_base_pytorch`。 121 | - `no_validate_output`:是否关闭对输出模型的验证,默认打开。 122 | 123 | \* : 使用`ernie-3.0-base-zh`时不会验证模型,需要微调后才能用于预测 124 | 125 | 126 | 127 | 128 | #### 3.1 实体抽取 129 | 130 | 命名实体识别(Named Entity Recognition,简称NER),是指识别文本中具有特定意义的实体。在开放域信息抽取中,抽取的类别没有限制,用户可以自己定义。 131 | 132 | - 例如抽取的目标实体类型是"时间"、"选手"和"赛事名称", schema构造如下: 133 | 134 | ```text 135 | ['时间', '选手', '赛事名称'] 136 | ``` 137 | 138 | 调用示例: 139 | 140 | ```python 141 | >>> from uie_predictor import UIEPredictor 142 | >>> from pprint import pprint 143 | 144 | >>> schema = ['时间', '选手', '赛事名称'] # Define the schema for entity extraction 145 | >>> ie = UIEPredictor(model='uie-base', schema=schema) 146 | >>> pprint(ie("2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌!")) # Better print results using pprint 147 | [{'时间': [{'end': 6, 148 | 'probability': 0.9857378532924486, 149 | 'start': 0, 150 | 'text': '2月8日上午'}], 151 | '赛事名称': [{'end': 23, 152 | 'probability': 0.8503089953268272, 153 | 'start': 6, 154 | 'text': '北京冬奥会自由式滑雪女子大跳台决赛'}], 155 | '选手': [{'end': 31, 156 | 'probability': 0.8981548639781138, 157 | 'start': 28, 158 | 'text': '谷爱凌'}]}] 159 | ``` 160 | 161 | - 例如抽取的目标实体类型是"肿瘤的大小"、"肿瘤的个数"、"肝癌级别"和"脉管内癌栓分级", schema构造如下: 162 | 163 | ```text 164 | ['肿瘤的大小', '肿瘤的个数', '肝癌级别', '脉管内癌栓分级'] 165 | ``` 166 | 167 | 在上例中我们已经实例化了一个`UIEPredictor`对象,这里可以通过`set_schema`方法重置抽取目标。 
168 | 169 | 调用示例: 170 | 171 | ```python 172 | >>> schema = ['肿瘤的大小', '肿瘤的个数', '肝癌级别', '脉管内癌栓分级'] 173 | >>> ie.set_schema(schema) 174 | >>> pprint(ie("(右肝肿瘤)肝细胞性肝癌(II-III级,梁索型和假腺管型),肿瘤包膜不完整,紧邻肝被膜,侵及周围肝组织,未见脉管内癌栓(MVI分级:M0级)及卫星子灶形成。(肿物1个,大小4.2×4.0×2.8cm)。")) 175 | [{'肝癌级别': [{'end': 20, 176 | 'probability': 0.9243267447402701, 177 | 'start': 13, 178 | 'text': 'II-III级'}], 179 | '肿瘤的个数': [{'end': 84, 180 | 'probability': 0.7538413804059623, 181 | 'start': 82, 182 | 'text': '1个'}], 183 | '肿瘤的大小': [{'end': 100, 184 | 'probability': 0.8341128043459491, 185 | 'start': 87, 186 | 'text': '4.2×4.0×2.8cm'}], 187 | '脉管内癌栓分级': [{'end': 70, 188 | 'probability': 0.9083292325934664, 189 | 'start': 67, 190 | 'text': 'M0级'}]}] 191 | ``` 192 | 193 | - 例如抽取的目标实体类型是"person"和"organization",schema构造如下: 194 | 195 | ```text 196 | ['person', 'organization'] 197 | ``` 198 | 199 | 英文模型调用示例: 200 | 201 | ```python 202 | >>> from uie_predictor import UIEPredictor 203 | >>> from pprint import pprint 204 | >>> schema = ['Person', 'Organization'] 205 | >>> ie_en = UIEPredictor(model='uie-base-en', schema=schema) 206 | >>> pprint(ie_en('In 1997, Steve was excited to become the CEO of Apple.')) 207 | [{'Organization': [{'end': 53, 208 | 'probability': 0.9985840259877357, 209 | 'start': 48, 210 | 'text': 'Apple'}], 211 | 'Person': [{'end': 14, 212 | 'probability': 0.999631971804547, 213 | 'start': 9, 214 | 'text': 'Steve'}]}] 215 | ``` 216 | 217 | 218 | 219 | #### 3.2 关系抽取 220 | 221 | 关系抽取(Relation Extraction,简称RE),是指从文本中识别实体并抽取实体之间的语义关系,进而获取三元组信息,即<主体,谓语,客体>。 222 | 223 | - 例如以"竞赛名称"作为抽取主体,抽取关系类型为"主办方"、"承办方"和"已举办次数", schema构造如下: 224 | 225 | ```text 226 | { 227 | '竞赛名称': [ 228 | '主办方', 229 | '承办方', 230 | '已举办次数' 231 | ] 232 | } 233 | ``` 234 | 235 | 调用示例: 236 | 237 | ```python 238 | >>> schema = {'竞赛名称': ['主办方', '承办方', '已举办次数']} # Define the schema for relation extraction 239 | >>> ie.set_schema(schema) # Reset schema 240 | >>> pprint(ie('2022语言与智能技术竞赛由中国中文信息学会和中国计算机学会联合主办,百度公司、中国中文信息学会评测工作委员会和中国计算机学会自然语言处理专委会承办,已连续举办4届,成为全球最热门的中文NLP赛事之一。')) 241 | [{'竞赛名称': [{'end': 13, 242 | 'probability': 0.7825402622754041, 243 | 'relations': {'主办方': [{'end': 22, 244 | 'probability': 0.8421710521379353, 245 | 'start': 14, 246 | 'text': '中国中文信息学会'}, 247 | {'end': 30, 248 | 'probability': 0.7580801847701935, 249 | 'start': 23, 250 | 'text': '中国计算机学会'}], 251 | '已举办次数': [{'end': 82, 252 | 'probability': 0.4671295049136148, 253 | 'start': 80, 254 | 'text': '4届'}], 255 | '承办方': [{'end': 39, 256 | 'probability': 0.8292706618236352, 257 | 'start': 35, 258 | 'text': '百度公司'}, 259 | {'end': 72, 260 | 'probability': 0.6193477885474685, 261 | 'start': 56, 262 | 'text': '中国计算机学会自然语言处理专委会'}, 263 | {'end': 55, 264 | 'probability': 0.7000497331473241, 265 | 'start': 40, 266 | 'text': '中国中文信息学会评测工作委员会'}]}, 267 | 'start': 0, 268 | 'text': '2022语言与智能技术竞赛'}]}] 269 | ``` 270 | 271 | - 例如以"person"作为抽取主体,抽取关系类型为"Company"和"Position", schema构造如下: 272 | 273 | ```text 274 | { 275 | 'Person': [ 276 | 'Company', 277 | 'Position' 278 | ] 279 | } 280 | ``` 281 | 282 | 英文模型调用示例: 283 | 284 | ```python 285 | >>> schema = [{'Person': ['Company', 'Position']}] 286 | >>> ie_en.set_schema(schema) 287 | >>> pprint(ie_en('In 1997, Steve was excited to become the CEO of Apple.')) 288 | [{'Person': [{'end': 14, 289 | 'probability': 0.999631971804547, 290 | 'relations': {'Company': [{'end': 53, 291 | 'probability': 0.9960158209451642, 292 | 'start': 48, 293 | 'text': 'Apple'}], 294 | 'Position': [{'end': 44, 295 | 'probability': 0.8871063806420736, 296 | 'start': 41, 297 | 'text': 
'CEO'}]}, 298 | 'start': 9, 299 | 'text': 'Steve'}]}] 300 | ``` 301 | 302 | 303 | 304 | #### 3.3 事件抽取 305 | 306 | 事件抽取 (Event Extraction, 简称EE),是指从自然语言文本中抽取预定义的事件触发词(Trigger)和事件论元(Argument),组合为相应的事件结构化信息。 307 | 308 | - 例如抽取的目标是"地震"事件的"地震强度"、"时间"、"震中位置"和"震源深度"这些信息,schema构造如下: 309 | 310 | ```text 311 | { 312 | '地震触发词': [ 313 | '地震强度', 314 | '时间', 315 | '震中位置', 316 | '震源深度' 317 | ] 318 | } 319 | ``` 320 | 321 | 触发词的格式统一为`触发词`或``XX触发词`,`XX`表示具体事件类型,上例中的事件类型是`地震`,则对应触发词为`地震触发词`。 322 | 323 | 调用示例: 324 | 325 | ```python 326 | >>> schema = {'地震触发词': ['地震强度', '时间', '震中位置', '震源深度']} # Define the schema for event extraction 327 | >>> ie.set_schema(schema) # Reset schema 328 | >>> ie('中国地震台网正式测定:5月16日06时08分在云南临沧市凤庆县(北纬24.34度,东经99.98度)发生3.5级地震,震源深度10千米。') 329 | [{'地震触发词': [{'text': '地震', 'start': 56, 'end': 58, 'probability': 0.9987181623528585, 'relations': {'地震强度': [{'text': '3.5级', 'start': 52, 'end': 56, 'probability': 0.9962985320905915}], '时间': [{'text': '5月16日06时08分', 'start': 11, 'end': 22, 'probability': 0.9882578028575182}], '震中位置': [{'text': '云南临沧市凤庆县(北纬24.34度,东经99.98度)', 'start': 23, 'end': 50, 'probability': 0.8551415716584501}], '震源深度': [{'text': '10千米', 'start': 63, 'end': 67, 'probability': 0.999158304648045}]}}]}] 330 | ``` 331 | 332 | - 英文模型**暂不支持事件抽取** 333 | 334 | 335 | 336 | #### 3.4 评论观点抽取 337 | 338 | 评论观点抽取,是指抽取文本中包含的评价维度、观点词。 339 | 340 | - 例如抽取的目标是文本中包含的评价维度及其对应的观点词和情感倾向,schema构造如下: 341 | 342 | ```text 343 | { 344 | '评价维度': [ 345 | '观点词', 346 | '情感倾向[正向,负向]' 347 | ] 348 | } 349 | ``` 350 | 351 | 调用示例: 352 | 353 | ```python 354 | >>> schema = {'评价维度': ['观点词', '情感倾向[正向,负向]']} # Define the schema for opinion extraction 355 | >>> ie.set_schema(schema) # Reset schema 356 | >>> pprint(ie("店面干净,很清静,服务员服务热情,性价比很高,发现收银台有排队")) # Better print results using pprint 357 | [{'评价维度': [{'end': 20, 358 | 'probability': 0.9817040258681473, 359 | 'relations': {'情感倾向[正向,负向]': [{'probability': 0.9966142505350533, 360 | 'text': '正向'}], 361 | '观点词': [{'end': 22, 362 | 'probability': 0.957396472711558, 363 | 'start': 21, 364 | 'text': '高'}]}, 365 | 'start': 17, 366 | 'text': '性价比'}, 367 | {'end': 2, 368 | 'probability': 0.9696849569741168, 369 | 'relations': {'情感倾向[正向,负向]': [{'probability': 0.9982153274927796, 370 | 'text': '正向'}], 371 | '观点词': [{'end': 4, 372 | 'probability': 0.9945318044652538, 373 | 'start': 2, 374 | 'text': '干净'}]}, 375 | 'start': 0, 376 | 'text': '店面'}]}] 377 | ``` 378 | 379 | - 英文模型schema构造如下: 380 | 381 | ```text 382 | { 383 | 'Aspect': [ 384 | 'Opinion', 385 | 'Sentiment classification [negative, positive]' 386 | ] 387 | } 388 | ``` 389 | 390 | 调用示例: 391 | 392 | ```python 393 | >>> schema = [{'Aspect': ['Opinion', 'Sentiment classification [negative, positive]']}] 394 | >>> ie_en.set_schema(schema) 395 | >>> pprint(ie_en("The teacher is very nice.")) 396 | [{'Aspect': [{'end': 11, 397 | 'probability': 0.4301476415932193, 398 | 'relations': {'Opinion': [{'end': 24, 399 | 'probability': 0.9072940447883724, 400 | 'start': 15, 401 | 'text': 'very nice'}], 402 | 'Sentiment classification [negative, positive]': [{'probability': 0.9998571920670685, 403 | 'text': 'positive'}]}, 404 | 'start': 4, 405 | 'text': 'teacher'}]}] 406 | ``` 407 | 408 | 409 | 410 | #### 3.5 情感分类 411 | 412 | - 句子级情感倾向分类,即判断句子的情感倾向是“正向”还是“负向”,schema构造如下: 413 | 414 | ```text 415 | '情感倾向[正向,负向]' 416 | ``` 417 | 418 | 调用示例: 419 | 420 | ```python 421 | >>> schema = '情感倾向[正向,负向]' # Define the schema for sentence-level sentiment classification 422 | >>> ie.set_schema(schema) # Reset schema 423 | >>> 
ie('这个产品用起来真的很流畅,我非常喜欢') 424 | [{'情感倾向[正向,负向]': [{'text': '正向', 'probability': 0.9988661643929895}]}] 425 | ``` 426 | 427 | 英文模型schema构造如下: 428 | 429 | ```text 430 | '情感倾向[正向,负向]' 431 | ``` 432 | 433 | 英文模型调用示例: 434 | 435 | ```python 436 | >>> schema = 'Sentiment classification [negative, positive]' 437 | >>> ie_en.set_schema(schema) 438 | >>> ie_en('I am sorry but this is the worst film I have ever seen in my life.') 439 | [{'Sentiment classification [negative, positive]': [{'text': 'negative', 'probability': 0.9998415771287057}]}] 440 | ``` 441 | 442 | 443 | 444 | #### 3.6 跨任务抽取 445 | 446 | - 例如在法律场景同时对文本进行实体抽取和关系抽取,schema可按照如下方式进行构造: 447 | 448 | ```text 449 | [ 450 | "法院", 451 | { 452 | "原告": "委托代理人" 453 | }, 454 | { 455 | "被告": "委托代理人" 456 | } 457 | ] 458 | ``` 459 | 460 | 调用示例: 461 | 462 | ```python 463 | >>> schema = ['法院', {'原告': '委托代理人'}, {'被告': '委托代理人'}] 464 | >>> ie.set_schema(schema) 465 | >>> pprint(ie("北京市海淀区人民法院\n民事判决书\n(199x)建初字第xxx号\n原告:张三。\n委托代理人李四,北京市 A律师事务所律师。\n被告:B公司,法定代表人王五,开发公司总经理。\n委托代理人赵六,北京市 C律师事务所律师。")) # Better print results using pprint 466 | [{'原告': [{'end': 37, 467 | 'probability': 0.9949814024296764, 468 | 'relations': {'委托代理人': [{'end': 46, 469 | 'probability': 0.7956844697990384, 470 | 'start': 44, 471 | 'text': '李四'}]}, 472 | 'start': 35, 473 | 'text': '张三'}], 474 | '法院': [{'end': 10, 475 | 'probability': 0.9221074192336651, 476 | 'start': 0, 477 | 'text': '北京市海淀区人民法院'}], 478 | '被告': [{'end': 67, 479 | 'probability': 0.8437349536631089, 480 | 'relations': {'委托代理人': [{'end': 92, 481 | 'probability': 0.7267121388225029, 482 | 'start': 90, 483 | 'text': '赵六'}]}, 484 | 'start': 64, 485 | 'text': 'B公司'}]}] 486 | ``` 487 | 488 | 489 | 490 | #### 3.7 模型选择 491 | 492 | - 多模型选择,满足精度、速度要求 493 | 494 | | 模型 | 结构 | 语言 | 495 | | :---: | :--------: | :--------: | 496 | | `uie-base` (默认)| 12-layers, 768-hidden, 12-heads | 中文 | 497 | | `uie-base-en` | 12-layers, 768-hidden, 12-heads | 英文 | 498 | | `uie-medical-base` | 12-layers, 768-hidden, 12-heads | 中文 | 499 | | `uie-medium`| 6-layers, 768-hidden, 12-heads | 中文 | 500 | | `uie-mini`| 6-layers, 384-hidden, 12-heads | 中文 | 501 | | `uie-micro`| 4-layers, 384-hidden, 12-heads | 中文 | 502 | | `uie-nano`| 4-layers, 312-hidden, 12-heads | 中文 | 503 | | `uie-m-large`| 24-layers, 1024-hidden, 16-heads | 中、英文 | 504 | | `uie-m-base`| 12-layers, 768-hidden, 12-heads | 中、英文 | 505 | 506 | 507 | - `uie-nano`调用示例: 508 | 509 | ```python 510 | >>> from uie_predictor import UIEPredictor 511 | 512 | >>> schema = ['时间', '选手', '赛事名称'] 513 | >>> ie = UIEPredictor('uie-nano', schema=schema) 514 | >>> ie("2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌!") 515 | [{'时间': [{'text': '2月8日上午', 'start': 0, 'end': 6, 'probability': 0.6513581678349247}], '选手': [{'text': '谷爱凌', 'start': 28, 'end': 31, 'probability': 0.9819330659468051}], '赛事名称': [{'text': '北京冬奥会自由式滑雪女子大跳台决赛', 'start': 6, 'end': 23, 'probability': 0.4908131110420939}]}] 516 | ``` 517 | 518 | - `uie-m-base`和`uie-m-large`支持中英文混合抽取,调用示例: 519 | 520 | ```python 521 | >>> from pprint import pprint 522 | >>> from uie_predictor import UIEPredictor 523 | 524 | >>> schema = ['Time', 'Player', 'Competition', 'Score'] 525 | >>> ie = UIEPredictor(schema=schema, model="uie-m-base", schema_lang="en") 526 | >>> pprint(ie(["2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌!", "Rafael Nadal wins French Open Final!"])) 527 | [{'Competition': [{'end': 23, 528 | 'probability': 0.9373889907291257, 529 | 'start': 6, 530 | 'text': '北京冬奥会自由式滑雪女子大跳台决赛'}], 531 | 'Player': [{'end': 31, 532 | 'probability': 0.6981119555336441, 533 
| 'start': 28, 534 | 'text': '谷爱凌'}], 535 | 'Score': [{'end': 39, 536 | 'probability': 0.9888507878270296, 537 | 'start': 32, 538 | 'text': '188.25分'}], 539 | 'Time': [{'end': 6, 540 | 'probability': 0.9784080036931151, 541 | 'start': 0, 542 | 'text': '2月8日上午'}]}, 543 | {'Competition': [{'end': 35, 544 | 'probability': 0.9851549932171295, 545 | 'start': 18, 546 | 'text': 'French Open Final'}], 547 | 'Player': [{'end': 12, 548 | 'probability': 0.9379371275888104, 549 | 'start': 0, 550 | 'text': 'Rafael Nadal'}]}] 551 | ``` 552 | 553 | 554 | 555 | #### 3.8 更多配置 556 | 557 | ```python 558 | >>> from uie_predictor import UIEPredictor 559 | 560 | >>> ie = UIEPredictor('uie_nano', 561 | schema=schema) 562 | ``` 563 | 564 | * `model`:选择任务使用的模型,默认为`uie-base`,可选有`uie-base`, `uie-medium`, `uie-mini`, `uie-micro`, `uie-nano`和`uie-medical-base`, `uie-base-en`。 565 | * `schema`:定义任务抽取目标,可参考开箱即用中不同任务的调用示例进行配置。 566 | * `schema_lang`:设置schema的语言,默认为`zh`, 可选有`zh`和`en`。因为中英schema的构造有所不同,因此需要指定schema的语言。该参数只对`uie-m-base`和`uie-m-large`模型有效。 567 | * `batch_size`:批处理大小,请结合机器情况进行调整,默认为1。 568 | * `task_path`:设定自定义的模型。 569 | * `position_prob`:模型对于span的起始位置/终止位置的结果概率在0~1之间,返回结果去掉小于这个阈值的结果,默认为0.5,span的最终概率输出为起始位置概率和终止位置概率的乘积。 570 | * `use_fp16`:是否使用`fp16`进行加速,默认关闭。`fp16`推理速度更快。如果选择`fp16`,请先确保机器正确安装NVIDIA相关驱动和基础软件,**确保CUDA>=11.2,cuDNN>=8.1.1**,初次使用需按照提示安装相关依赖。其次,需要确保GPU设备的CUDA计算能力(CUDA Compute Capability)大于7.0,典型的设备包括V100、T4、A10、A100、GTX 20系列和30系列显卡等。更多关于CUDA Compute Capability和精度支持情况请参考NVIDIA文档:[GPU硬件与支持精度对照表](https://docs.nvidia.com/deeplearning/tensorrt/archives/tensorrt-840-ea/support-matrix/index.html#hardware-precision-matrix)。 571 | 572 | 573 | 574 | ## 4. 训练定制 575 | 576 | 对于简单的抽取目标可以直接使用```UIEPredictor```实现零样本(zero-shot)抽取,对于细分场景我们推荐使用轻定制功能(标注少量数据进行模型微调)以进一步提升效果。下面通过`报销工单信息抽取`的例子展示如何通过5条训练数据进行UIE模型微调。 577 | 578 | 579 | #### 4.1 代码结构 580 | 581 | ```shell 582 | . 583 | ├── utils.py # 数据处理工具 584 | ├── model.py # 模型组网脚本 585 | ├── doccano.py # 数据标注脚本 586 | ├── doccano.md # 数据标注文档 587 | ├── finetune.py # 模型微调脚本 588 | ├── evaluate.py # 模型评估脚本 589 | └── README.md 590 | ``` 591 | 592 | 593 | 594 | #### 4.2 数据标注 595 | 596 | 我们推荐使用数据标注平台[doccano](https://github.com/doccano/doccano) 进行数据标注,本示例也打通了从标注到训练的通道,即doccano导出数据后可通过[doccano.py](./doccano.py)脚本轻松将数据转换为输入模型时需要的形式,实现无缝衔接。标注方法的详细介绍请参考[doccano数据标注指南](doccano.md)。 597 | 598 | 原始数据示例: 599 | 600 | ```text 601 | 深大到双龙28块钱4月24号交通费 602 | ``` 603 | 604 | 抽取的目标(schema)为: 605 | 606 | ```python 607 | schema = ['出发地', '目的地', '费用', '时间'] 608 | ``` 609 | 610 | 标注步骤如下: 611 | 612 | - 在doccano平台上,创建一个类型为``序列标注``的标注项目。 613 | - 定义实体标签类别,上例中需要定义的实体标签有``出发地``、``目的地``、``费用``和``时间``。 614 | - 使用以上定义的标签开始标注数据,下面展示了一个doccano标注示例: 615 | 616 |
617 | 618 |
619 | 620 | - 标注完成后,在doccano平台上导出文件,并将其重命名为``doccano_ext.json``后,放入``./data``目录下。 621 | 622 | - 这里我们提供预先标注好的文件[doccano_ext.json](https://bj.bcebos.com/paddlenlp/datasets/uie/doccano_ext.json),可直接下载并放入`./data`目录。执行以下脚本进行数据转换,执行后会在`./data`目录下生成训练/验证/测试集文件。 623 | 624 | ```shell 625 | python doccano.py \ 626 | --doccano_file ./data/doccano_ext.json \ 627 | --task_type ext \ 628 | --save_dir ./data \ 629 | --splits 0.8 0.2 0 630 | ``` 631 | 632 | 633 | 可配置参数说明: 634 | 635 | - ``doccano_file``: 从doccano导出的数据标注文件。 636 | - ``save_dir``: 训练数据的保存目录,默认存储在``data``目录下。 637 | - ``negative_ratio``: 最大负例比例,该参数只对抽取类型任务有效,适当构造负例可提升模型效果。负例数量和实际的标签数量有关,最大负例数量 = negative_ratio * 正例数量。该参数只对训练集有效,默认为5。为了保证评估指标的准确性,验证集和测试集默认构造全负例。 638 | - ``splits``: 划分数据集时训练集、验证集所占的比例。默认为[0.8, 0.1, 0.1]表示按照``8:1:1``的比例将数据划分为训练集、验证集和测试集。 639 | - ``task_type``: 选择任务类型,可选有抽取和分类两种类型的任务。 640 | - ``options``: 指定分类任务的类别标签,该参数只对分类类型任务有效。默认为["正向", "负向"]。 641 | - ``prompt_prefix``: 声明分类任务的prompt前缀信息,该参数只对分类类型任务有效。默认为"情感倾向"。 642 | - ``is_shuffle``: 是否对数据集进行随机打散,默认为True。 643 | - ``seed``: 随机种子,默认为1000. 644 | - ``separator``: 实体类别/评价维度与分类标签的分隔符,该参数只对实体/评价维度级分类任务有效。默认为"##"。 645 | 646 | 备注: 647 | - 默认情况下 [doccano.py](./doccano.py) 脚本会按照比例将数据划分为 train/dev/test 数据集 648 | - 每次执行 [doccano.py](./doccano.py) 脚本,将会覆盖已有的同名数据文件 649 | - 在模型训练阶段我们推荐构造一些负例以提升模型效果,在数据转换阶段我们内置了这一功能。可通过`negative_ratio`控制自动构造的负样本比例;负样本数量 = negative_ratio * 正样本数量。 650 | - 对于从doccano导出的文件,默认文件中的每条数据都是经过人工正确标注的。 651 | 652 | 更多**不同类型任务(关系抽取、事件抽取、评价观点抽取等)的标注规则及参数说明**,请参考[doccano数据标注指南](doccano.md)。 653 | 654 | 此外,也可以通过数据标注平台 [Label Studio](https://labelstud.io/) 进行数据标注。本示例提供了 [labelstudio2doccano.py](./labelstudio2doccano.py) 脚本,将 label studio 导出的 JSON 数据文件格式转换成 doccano 导出的数据文件格式,后续的数据转换与模型微调等操作不变。 655 | 656 | ```shell 657 | python labelstudio2doccano.py --labelstudio_file label-studio.json 658 | ``` 659 | 660 | 可配置参数说明: 661 | 662 | - ``labelstudio_file``: label studio 的导出文件路径(仅支持 JSON 格式)。 663 | - ``doccano_file``: doccano 格式的数据文件保存路径,默认为 "doccano_ext.jsonl"。 664 | - ``task_type``: 任务类型,可选有抽取("ext")和分类("cls")两种类型的任务,默认为 "ext"。 665 | 666 | 667 | 668 | #### 4.3 模型微调 669 | 670 | 通过运行以下命令进行模型微调: 671 | 672 | ```shell 673 | python finetune.py \ 674 | --train_path "./data/train.txt" \ 675 | --dev_path "./data/dev.txt" \ 676 | --save_dir "./checkpoint" \ 677 | --learning_rate 1e-5 \ 678 | --batch_size 16 \ 679 | --max_seq_len 512 \ 680 | --num_epochs 100 \ 681 | --model "uie_base_pytorch" \ 682 | --seed 1000 \ 683 | --logging_steps 10 \ 684 | --valid_steps 100 \ 685 | --device "gpu" 686 | ``` 687 | 688 | 可配置参数说明: 689 | 690 | - `train_path`: 训练集文件路径。 691 | - `dev_path`: 验证集文件路径。 692 | - `save_dir`: 模型存储路径,默认为`./checkpoint`。 693 | - `learning_rate`: 学习率,默认为1e-5。 694 | - `batch_size`: 批处理大小,请结合机器情况进行调整,默认为16。 695 | - `max_seq_len`: 文本最大切分长度,输入超过最大长度时会对输入文本进行自动切分,默认为512。 696 | - `num_epochs`: 训练轮数,默认为100。 697 | - `model`: 选择模型,程序会基于选择的模型进行模型微调,默认为`uie_base_pytorch`。 698 | - `seed`: 随机种子,默认为1000. 
699 | - `logging_steps`: 日志打印的间隔steps数,默认10。 700 | - `valid_steps`: evaluate的间隔steps数,默认100。 701 | - `device`: 选用什么设备进行训练,可选cpu或gpu。 702 | - `max_model_num`: 保存的模型的个数,不包含`model_best`和`early_stopping`保存的模型,默认为5。 703 | - `early_stopping`: 是否采用提前停止(Early Stopping),默认不使用。 704 | 705 | 706 | 707 | #### 4.4 模型评估 708 | 709 | 通过运行以下命令进行模型评估: 710 | 711 | ```shell 712 | python evaluate.py \ 713 | --model_path ./checkpoint/model_best \ 714 | --test_path ./data/dev.txt \ 715 | --batch_size 16 \ 716 | --max_seq_len 512 717 | ``` 718 | 719 | 评估方式说明:采用单阶段评价的方式,即关系抽取、事件抽取等需要分阶段预测的任务对每一阶段的预测结果进行分别评价。验证/测试集默认会利用同一层级的所有标签来构造出全部负例。 720 | 721 | 可开启`debug`模式对每个正例类别分别进行评估,该模式仅用于模型调试: 722 | 723 | ```shell 724 | python evaluate.py \ 725 | --model_path ./checkpoint/model_best \ 726 | --test_path ./data/dev.txt \ 727 | --debug 728 | ``` 729 | 730 | 输出打印示例: 731 | 732 | ```text 733 | [2022-09-14 03:13:58,877] [ INFO] - ----------------------------- 734 | [2022-09-14 03:13:58,877] [ INFO] - Class Name: 疾病 735 | [2022-09-14 03:13:58,877] [ INFO] - Evaluation Precision: 0.89744 | Recall: 0.83333 | F1: 0.86420 736 | [2022-09-14 03:13:59,145] [ INFO] - ----------------------------- 737 | [2022-09-14 03:13:59,145] [ INFO] - Class Name: 手术治疗 738 | [2022-09-14 03:13:59,145] [ INFO] - Evaluation Precision: 0.90000 | Recall: 0.85714 | F1: 0.87805 739 | [2022-09-14 03:13:59,439] [ INFO] - ----------------------------- 740 | [2022-09-14 03:13:59,440] [ INFO] - Class Name: 检查 741 | [2022-09-14 03:13:59,440] [ INFO] - Evaluation Precision: 0.77778 | Recall: 0.56757 | F1: 0.65625 742 | [2022-09-14 03:13:59,708] [ INFO] - ----------------------------- 743 | [2022-09-14 03:13:59,709] [ INFO] - Class Name: X的手术治疗 744 | [2022-09-14 03:13:59,709] [ INFO] - Evaluation Precision: 0.90000 | Recall: 0.85714 | F1: 0.87805 745 | [2022-09-14 03:13:59,893] [ INFO] - ----------------------------- 746 | [2022-09-14 03:13:59,893] [ INFO] - Class Name: X的实验室检查 747 | [2022-09-14 03:13:59,894] [ INFO] - Evaluation Precision: 0.71429 | Recall: 0.55556 | F1: 0.62500 748 | [2022-09-14 03:14:00,057] [ INFO] - ----------------------------- 749 | [2022-09-14 03:14:00,058] [ INFO] - Class Name: X的影像学检查 750 | [2022-09-14 03:14:00,058] [ INFO] - Evaluation Precision: 0.69231 | Recall: 0.45000 | F1: 0.54545 751 | ``` 752 | 753 | 可配置参数说明: 754 | 755 | - `model_path`: 进行评估的模型文件夹路径,路径下需包含模型权重文件`pytorch_model.bin`及配置文件`config.json`。 756 | - `test_path`: 进行评估的测试集文件。 757 | - `batch_size`: 批处理大小,请结合机器情况进行调整,默认为16。 758 | - `max_seq_len`: 文本最大切分长度,输入超过最大长度时会对输入文本进行自动切分,默认为512。 759 | - `device`: 选用进行训练的设备,可选`cpu`或`gpu`。 760 | 761 | 762 | 763 | #### 4.5 定制模型一键预测 764 | 765 | `UIEPredictor`装载定制模型,通过`task_path`指定模型权重文件的路径,路径下需要包含训练好的模型权重文件`pytorch_model.bin`。 766 | 767 | ```python 768 | >>> from pprint import pprint 769 | >>> from uie_predictor import UIEPredictor 770 | 771 | >>> schema = ['出发地', '目的地', '费用', '时间'] 772 | # 设定抽取目标和定制化模型权重路径 773 | >>> my_ie = UIEPredictor(model='uie-base',task_path='./checkpoint/model_best', schema=schema) 774 | >>> pprint(my_ie("城市内交通费7月5日金额114广州至佛山")) 775 | [{'出发地': [{'end': 17, 776 | 'probability': 0.9975287467835301, 777 | 'start': 15, 778 | 'text': '广州'}], 779 | '时间': [{'end': 10, 780 | 'probability': 0.9999476678061399, 781 | 'start': 6, 782 | 'text': '7月5日'}], 783 | '目的地': [{'end': 20, 784 | 'probability': 0.9998511131226735, 785 | 'start': 18, 786 | 'text': '佛山'}], 787 | '费用': [{'end': 15, 788 | 'probability': 0.9994474579292856, 789 | 'start': 12, 790 | 'text': '114'}]}] 791 | ``` 792 | 793 | 794 | 795 | #### 4.6 实验指标 796 | 797 | 
我们在互联网、医疗、金融三大垂类自建测试集上进行了实验: 798 | 799 | 800 |
| 模型 | 金融 0-shot | 金融 5-shot | 医疗 0-shot | 医疗 5-shot | 互联网 0-shot | 互联网 5-shot |
801 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
802 | | uie-base (12L768H) | 46.43 | 70.92 | 71.83 | 85.72 | 78.33 | 81.86 |
803 | | uie-medium (6L768H) | 41.11 | 64.53 | 65.40 | 75.72 | 78.32 | 79.68 |
804 | | uie-mini (6L384H) | 37.04 | 64.65 | 60.50 | 78.36 | 72.09 | 76.38 |
805 | | uie-micro (4L384H) | 37.53 | 62.11 | 57.04 | 75.92 | 66.00 | 70.22 |
806 | | uie-nano (4L312H) | 38.94 | 66.83 | 48.29 | 76.74 | 62.86 | 72.35 |
807 | | uie-m-large (24L1024H) | 49.35 | 74.55 | 70.50 | 92.66 | 78.49 | 83.02 |
808 | | uie-m-base (12L768H) | 38.46 | 74.31 | 63.37 | 87.32 | 76.27 | 80.13 |
809 |
810 | 811 | 0-shot表示无训练数据直接通过```UIEPredictor```进行预测,5-shot表示每个类别包含5条标注数据进行模型微调。**实验表明UIE在垂类场景可以通过少量数据(few-shot)进一步提升效果**。 812 | 813 | 814 | 815 | #### 4.7 模型部署 816 | 817 | 以下是UIE Python端的部署流程,包括环境准备、模型导出和使用示例。 818 | 819 | - 环境准备 820 | UIE的部署分为CPU和GPU两种情况,请根据你的部署环境安装对应的依赖。 821 | 822 | - CPU端 823 | 824 | CPU端的部署请使用如下命令安装所需依赖 825 | 826 | ```shell 827 | pip install onnx onnxruntime 828 | ``` 829 | - GPU端 830 | 831 | 为了在GPU上获得最佳的推理性能和稳定性,请先确保机器已正确安装NVIDIA相关驱动和基础软件,确保**CUDA >= 11.2,cuDNN >= 8.1.1**,并使用以下命令安装所需依赖 832 | 833 | ```shell 834 | pip install onnx onnxconverter_common onnxruntime-gpu 835 | ``` 836 | 837 | 如需使用半精度(FP16)部署,请确保GPU设备的CUDA计算能力 (CUDA Compute Capability) 大于7.0,典型的设备包括V100、T4、A10、A100、GTX 20系列和30系列显卡等。 838 | 更多关于CUDA Compute Capability和精度支持情况请参考NVIDIA文档:[GPU硬件与支持精度对照表](https://docs.nvidia.com/deeplearning/tensorrt/archives/tensorrt-840-ea/support-matrix/index.html#hardware-precision-matrix) 839 | 840 | 841 | - 模型导出 842 | 843 | 将训练后的动态图参数导出为静态图参数: 844 | 845 | ```shell 846 | python export_model.py --model_path ./checkpoint/model_best --output_path ./export 847 | ``` 848 | 849 | 可配置参数说明: 850 | 851 | - `model_path`: 动态图训练保存的参数路径,路径下包含模型参数文件`pytorch_model.bin`和模型配置文件`config.json`。 852 | - `output_path`: 静态图参数导出路径,默认导出路径为`model_path`,即保存到输入模型同目录下。 853 | 854 | - 推理 855 | 856 | - CPU端推理样例 857 | 858 | 在CPU端,请使用如下命令进行部署 859 | 860 | ```shell 861 | python uie_predictor.py --task_path ./export --engine onnx --device cpu 862 | ``` 863 | 864 | 可配置参数说明: 865 | - `model`:选择任务使用的模型,默认为`uie-base`,可选有`uie-base`, `uie-medium`, `uie-mini`, `uie-micro`, `uie-nano`和`uie-medical-base`, `uie-base-en`。 866 | - `task_path`: 用于推理的ONNX模型文件所在文件夹。例如模型文件路径为`./export/inference.onnx`,则传入`./export`。如果不设置,将自动下载`model`对应的模型。 867 | - `position_prob`:模型对于span的起始位置/终止位置的结果概率0~1之间,返回结果去掉小于这个阈值的结果,默认为0.5,span的最终概率输出为起始位置概率和终止位置概率的乘积。 868 | - `max_seq_len`: 文本最大切分长度,输入超过最大长度时会对输入文本进行自动切分,默认为512。 869 | - `engine`: 可选值为`pytorch`和`onnx`。推理使用的推理引擎。 870 | 871 | - GPU端推理样例 872 | 873 | 在GPU端,请使用如下命令进行部署 874 | 875 | ```shell 876 | python uie_predictor.py --task_path ./export --engine onnx --device gpu --use_fp16 877 | ``` 878 | 879 | 可配置参数说明: 880 | - `model`:选择任务使用的模型,默认为`uie-base`,可选有`uie-base`, `uie-medium`, `uie-mini`, `uie-micro`, `uie-nano`和`uie-medical-base`, `uie-base-en`。 881 | - `task_path`: 用于推理的ONNX模型文件所在文件夹。例如模型文件路径为`./export/inference.onnx`,则传入`./export/inference`。如果不设置,将自动下载`model`对应的模型。 882 | - `use_fp16`: 是否使用FP16进行加速,默认关闭。 883 | - `position_prob`:模型对于span的起始位置/终止位置的结果概率0~1之间,返回结果去掉小于这个阈值的结果,默认为0.5,span的最终概率输出为起始位置概率和终止位置概率的乘积。 884 | - `max_seq_len`: 文本最大切分长度,输入超过最大长度时会对输入文本进行自动切分,默认为512。 885 | - `engine`: 可选值为`pytorch`和`onnx`。推理使用的推理引擎。 886 | -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
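#
# Example usage (the flags below mirror the argparse options defined at the end of this script):
#   python convert.py                                    # download and convert the default `uie-base` model
#   python convert.py -i uie-m-base -o uie_m_base_pytorch
#   python convert.py --no_validate_output               # skip output validation when paddlenlp is not installed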
14 | 15 | import argparse 16 | import collections 17 | import json 18 | import os 19 | import pickle 20 | import shutil 21 | import numpy as np 22 | from base64 import b64decode 23 | 24 | import torch 25 | try: 26 | import paddle 27 | from paddle.utils.download import get_path_from_url 28 | paddle_installed = True 29 | except (ImportError, ModuleNotFoundError): 30 | from utils import get_path_from_url 31 | paddle_installed = False 32 | 33 | from model import UIE, UIEM 34 | from utils import logger 35 | 36 | MODEL_MAP = { 37 | # vocab.txt/special_tokens_map.json/tokenizer_config.json are common to the default model. 38 | "uie-base": { 39 | "resource_file_urls": { 40 | "model_state.pdparams": 41 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base_v1.0/model_state.pdparams", 42 | "model_config.json": 43 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/model_config.json", 44 | "vocab.txt": 45 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt", 46 | "special_tokens_map.json": 47 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json", 48 | "tokenizer_config.json": 49 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json" 50 | } 51 | }, 52 | "uie-medium": { 53 | "resource_file_urls": { 54 | "model_state.pdparams": 55 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medium_v1.0/model_state.pdparams", 56 | "model_config.json": 57 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medium/model_config.json", 58 | "vocab.txt": 59 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt", 60 | "special_tokens_map.json": 61 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json", 62 | "tokenizer_config.json": 63 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json", 64 | } 65 | }, 66 | "uie-mini": { 67 | "resource_file_urls": { 68 | "model_state.pdparams": 69 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_mini_v1.0/model_state.pdparams", 70 | "model_config.json": 71 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_mini/model_config.json", 72 | "vocab.txt": 73 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt", 74 | "special_tokens_map.json": 75 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json", 76 | "tokenizer_config.json": 77 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json", 78 | } 79 | }, 80 | "uie-micro": { 81 | "resource_file_urls": { 82 | "model_state.pdparams": 83 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_micro_v1.0/model_state.pdparams", 84 | "model_config.json": 85 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_micro/model_config.json", 86 | "vocab.txt": 87 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt", 88 | "special_tokens_map.json": 89 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json", 90 | "tokenizer_config.json": 91 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json", 92 | } 93 | }, 94 | "uie-nano": { 95 | "resource_file_urls": { 96 | "model_state.pdparams": 97 | 
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_nano_v1.0/model_state.pdparams", 98 | "model_config.json": 99 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_nano/model_config.json", 100 | "vocab.txt": 101 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt", 102 | "special_tokens_map.json": 103 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json", 104 | "tokenizer_config.json": 105 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json", 106 | } 107 | }, 108 | "uie-medical-base": { 109 | "resource_file_urls": { 110 | "model_state.pdparams": 111 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medical_base_v0.1/model_state.pdparams", 112 | "model_config.json": 113 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/model_config.json", 114 | "vocab.txt": 115 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt", 116 | "special_tokens_map.json": 117 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json", 118 | "tokenizer_config.json": 119 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json", 120 | } 121 | }, 122 | "uie-base-en": { 123 | "resource_file_urls": { 124 | "model_state.pdparams": 125 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base_en_v1.1/model_state.pdparams", 126 | "model_config.json": 127 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base_en/model_config.json", 128 | "vocab.txt": 129 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base_en/vocab.txt", 130 | "special_tokens_map.json": 131 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base_en/special_tokens_map.json", 132 | "tokenizer_config.json": 133 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base_en/tokenizer_config.json", 134 | } 135 | }, 136 | # uie-m模型需要Ernie-M模型 137 | "uie-m-base": { 138 | "resource_file_urls": { 139 | "model_state.pdparams": 140 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_m_base_v1.0/model_state.pdparams", 141 | "model_config.json": 142 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_m_base/model_config.json", 143 | "vocab.txt": 144 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_m_base/vocab.txt", 145 | "special_tokens_map.json": 146 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_m_base/special_tokens_map.json", 147 | "tokenizer_config.json": 148 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_m_base/tokenizer_config.json", 149 | "sentencepiece.bpe.model": 150 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_m_base/sentencepiece.bpe.model" 151 | 152 | } 153 | }, 154 | "uie-m-large": { 155 | "resource_file_urls": { 156 | "model_state.pdparams": 157 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_m_large_v1.0/model_state.pdparams", 158 | "model_config.json": 159 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_m_large/model_config.json", 160 | "vocab.txt": 161 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_m_large/vocab.txt", 162 | "special_tokens_map.json": 163 | 
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_m_large/special_tokens_map.json", 164 | "tokenizer_config.json": 165 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_m_large/tokenizer_config.json", 166 | "sentencepiece.bpe.model": 167 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_m_base/sentencepiece.bpe.model" 168 | } 169 | }, 170 | # Rename to `uie-medium` and the name of `uie-tiny` will be deprecated in future. 171 | "uie-tiny": { 172 | "resource_file_urls": { 173 | "model_state.pdparams": 174 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny_v0.1/model_state.pdparams", 175 | "model_config.json": 176 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/model_config.json", 177 | "vocab.txt": 178 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/vocab.txt", 179 | "special_tokens_map.json": 180 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/special_tokens_map.json", 181 | "tokenizer_config.json": 182 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/tokenizer_config.json" 183 | } 184 | }, 185 | "ernie-3.0-base-zh": { 186 | "resource_file_urls": { 187 | "model_state.pdparams": 188 | "https://bj.bcebos.com/paddlenlp/models/transformers/ernie_3.0/ernie_3.0_base_zh.pdparams", 189 | "model_config.json": 190 | "base64:ew0KICAiYXR0ZW50aW9uX3Byb2JzX2Ryb3BvdXRfcHJvYiI6IDAuMSwNCiAgImhpZGRlbl9hY3QiOiAiZ2VsdSIsDQogICJoaWRkZW5fZHJvcG91dF9wcm9iIjogMC4xLA0KICAiaGlkZGVuX3NpemUiOiA3NjgsDQogICJpbml0aWFsaXplcl9yYW5nZSI6IDAuMDIsDQogICJtYXhfcG9zaXRpb25fZW1iZWRkaW5ncyI6IDIwNDgsDQogICJudW1fYXR0ZW50aW9uX2hlYWRzIjogMTIsDQogICJudW1faGlkZGVuX2xheWVycyI6IDEyLA0KICAidGFza190eXBlX3ZvY2FiX3NpemUiOiAzLA0KICAidHlwZV92b2NhYl9zaXplIjogNCwNCiAgInVzZV90YXNrX2lkIjogdHJ1ZSwNCiAgInZvY2FiX3NpemUiOiA0MDAwMCwNCiAgImluaXRfY2xhc3MiOiAiRXJuaWVNb2RlbCINCn0=", 191 | "vocab.txt": 192 | "https://bj.bcebos.com/paddlenlp/models/transformers/ernie_3.0/ernie_3.0_base_zh_vocab.txt", 193 | "special_tokens_map.json": 194 | "base64:eyJ1bmtfdG9rZW4iOiAiW1VOS10iLCAic2VwX3Rva2VuIjogIltTRVBdIiwgInBhZF90b2tlbiI6ICJbUEFEXSIsICJjbHNfdG9rZW4iOiAiW0NMU10iLCAibWFza190b2tlbiI6ICJbTUFTS10ifQ==", 195 | "tokenizer_config.json": 196 | "base64:eyJkb19sb3dlcl9jYXNlIjogdHJ1ZSwgInVua190b2tlbiI6ICJbVU5LXSIsICJzZXBfdG9rZW4iOiAiW1NFUF0iLCAicGFkX3Rva2VuIjogIltQQURdIiwgImNsc190b2tlbiI6ICJbQ0xTXSIsICJtYXNrX3Rva2VuIjogIltNQVNLXSIsICJ0b2tlbml6ZXJfY2xhc3MiOiAiRXJuaWVUb2tlbml6ZXIifQ==" 197 | } 198 | } 199 | } 200 | 201 | 202 | def build_params_map(model_prefix='encoder', attention_num=12): 203 | """ 204 | build params map from paddle-paddle's ERNIE to transformer's BERT 205 | :return: 206 | """ 207 | weight_map = collections.OrderedDict({ 208 | f'{model_prefix}.embeddings.word_embeddings.weight': "encoder.embeddings.word_embeddings.weight", 209 | f'{model_prefix}.embeddings.position_embeddings.weight': "encoder.embeddings.position_embeddings.weight", 210 | f'{model_prefix}.embeddings.token_type_embeddings.weight': "encoder.embeddings.token_type_embeddings.weight", 211 | f'{model_prefix}.embeddings.task_type_embeddings.weight': "encoder.embeddings.task_type_embeddings.weight", 212 | f'{model_prefix}.embeddings.layer_norm.weight': 'encoder.embeddings.LayerNorm.gamma', 213 | f'{model_prefix}.embeddings.layer_norm.bias': 'encoder.embeddings.LayerNorm.beta', 214 | }) 215 | # add attention layers 216 | for i in range(attention_num): 217 | 
weight_map[f'{model_prefix}.encoder.layers.{i}.self_attn.q_proj.weight'] = f'encoder.encoder.layer.{i}.attention.self.query.weight' 218 | weight_map[f'{model_prefix}.encoder.layers.{i}.self_attn.q_proj.bias'] = f'encoder.encoder.layer.{i}.attention.self.query.bias' 219 | weight_map[f'{model_prefix}.encoder.layers.{i}.self_attn.k_proj.weight'] = f'encoder.encoder.layer.{i}.attention.self.key.weight' 220 | weight_map[f'{model_prefix}.encoder.layers.{i}.self_attn.k_proj.bias'] = f'encoder.encoder.layer.{i}.attention.self.key.bias' 221 | weight_map[f'{model_prefix}.encoder.layers.{i}.self_attn.v_proj.weight'] = f'encoder.encoder.layer.{i}.attention.self.value.weight' 222 | weight_map[f'{model_prefix}.encoder.layers.{i}.self_attn.v_proj.bias'] = f'encoder.encoder.layer.{i}.attention.self.value.bias' 223 | weight_map[f'{model_prefix}.encoder.layers.{i}.self_attn.out_proj.weight'] = f'encoder.encoder.layer.{i}.attention.output.dense.weight' 224 | weight_map[f'{model_prefix}.encoder.layers.{i}.self_attn.out_proj.bias'] = f'encoder.encoder.layer.{i}.attention.output.dense.bias' 225 | weight_map[f'{model_prefix}.encoder.layers.{i}.norm1.weight'] = f'encoder.encoder.layer.{i}.attention.output.LayerNorm.gamma' 226 | weight_map[f'{model_prefix}.encoder.layers.{i}.norm1.bias'] = f'encoder.encoder.layer.{i}.attention.output.LayerNorm.beta' 227 | weight_map[f'{model_prefix}.encoder.layers.{i}.linear1.weight'] = f'encoder.encoder.layer.{i}.intermediate.dense.weight' 228 | weight_map[f'{model_prefix}.encoder.layers.{i}.linear1.bias'] = f'encoder.encoder.layer.{i}.intermediate.dense.bias' 229 | weight_map[f'{model_prefix}.encoder.layers.{i}.linear2.weight'] = f'encoder.encoder.layer.{i}.output.dense.weight' 230 | weight_map[f'{model_prefix}.encoder.layers.{i}.linear2.bias'] = f'encoder.encoder.layer.{i}.output.dense.bias' 231 | weight_map[f'{model_prefix}.encoder.layers.{i}.norm2.weight'] = f'encoder.encoder.layer.{i}.output.LayerNorm.gamma' 232 | weight_map[f'{model_prefix}.encoder.layers.{i}.norm2.bias'] = f'encoder.encoder.layer.{i}.output.LayerNorm.beta' 233 | # add pooler 234 | weight_map.update( 235 | { 236 | f'{model_prefix}.pooler.dense.weight': 'encoder.pooler.dense.weight', 237 | f'{model_prefix}.pooler.dense.bias': 'encoder.pooler.dense.bias', 238 | 'linear_start.weight': 'linear_start.weight', 239 | 'linear_start.bias': 'linear_start.bias', 240 | 'linear_end.weight': 'linear_end.weight', 241 | 'linear_end.bias': 'linear_end.bias', 242 | } 243 | ) 244 | return weight_map 245 | 246 | 247 | def extract_and_convert(input_dir, output_dir, verbose=False): 248 | if not os.path.exists(output_dir): 249 | os.makedirs(output_dir) 250 | if verbose: 251 | logger.info('=' * 20 + 'save config file' + '=' * 20) 252 | config = json.load( 253 | open(os.path.join(input_dir, 'model_config.json'), 'rt', encoding='utf-8')) 254 | if 'init_args' in config: 255 | config = config['init_args'][0] 256 | config["architectures"] = ["UIE"] 257 | config['layer_norm_eps'] = 1e-12 258 | del config['init_class'] 259 | if 'sent_type_vocab_size' in config: 260 | config['type_vocab_size'] = config['sent_type_vocab_size'] 261 | config['intermediate_size'] = 4 * config['hidden_size'] 262 | json.dump(config, open(os.path.join(output_dir, 'config.json'), 263 | 'wt', encoding='utf-8'), indent=4) 264 | if verbose: 265 | logger.info('=' * 20 + 'save vocab file' + '=' * 20) 266 | shutil.copy(os.path.join(input_dir, 'vocab.txt'), 267 | os.path.join(output_dir, 'vocab.txt')) 268 | special_tokens_map = json.load(open(os.path.join( 269 | 
input_dir, 'special_tokens_map.json'), 'rt', encoding='utf-8')) 270 | json.dump(special_tokens_map, open(os.path.join(output_dir, 'special_tokens_map.json'), 271 | 'wt', encoding='utf-8')) 272 | tokenizer_config = json.load( 273 | open(os.path.join(input_dir, 'tokenizer_config.json'), 'rt', encoding='utf-8')) 274 | if tokenizer_config['tokenizer_class'] == 'ErnieTokenizer': 275 | tokenizer_config['tokenizer_class'] = "BertTokenizer" 276 | json.dump(tokenizer_config, open(os.path.join(output_dir, 'tokenizer_config.json'), 277 | 'wt', encoding='utf-8')) 278 | spm_file = os.path.join(input_dir, 'sentencepiece.bpe.model') 279 | if os.path.exists(spm_file): 280 | shutil.copy(spm_file, os.path.join( 281 | output_dir, 'sentencepiece.bpe.model')) 282 | if verbose: 283 | logger.info('=' * 20 + 'extract weights' + '=' * 20) 284 | state_dict = collections.OrderedDict() 285 | weight_map = build_params_map(attention_num=config['num_hidden_layers']) 286 | weight_map.update(build_params_map( 287 | 'ernie', attention_num=config['num_hidden_layers'])) 288 | if paddle_installed: 289 | import paddle.fluid.dygraph as D 290 | from paddle import fluid 291 | with fluid.dygraph.guard(): 292 | paddle_paddle_params, _ = D.load_dygraph( 293 | os.path.join(input_dir, 'model_state')) 294 | else: 295 | paddle_paddle_params = pickle.load( 296 | open(os.path.join(input_dir, 'model_state.pdparams'), 'rb')) 297 | del paddle_paddle_params['StructuredToParameterName@@'] 298 | for weight_name, weight_value in paddle_paddle_params.items(): 299 | transposed = '' 300 | if 'weight' in weight_name: 301 | if '.encoder' in weight_name or 'pooler' in weight_name or 'linear' in weight_name: 302 | weight_value = weight_value.transpose() 303 | transposed = '.T' 304 | # Fix: embedding error 305 | if 'word_embeddings.weight' in weight_name: 306 | weight_value[0, :] = 0 307 | if weight_name not in weight_map: 308 | if verbose: 309 | logger.info(f"{'='*20} [SKIP] {weight_name} {'='*20}") 310 | continue 311 | state_dict[weight_map[weight_name]] = torch.FloatTensor(weight_value) 312 | if verbose: 313 | logger.info( 314 | f"{weight_name}{transposed} -> {weight_map[weight_name]} {weight_value.shape}") 315 | torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin")) 316 | 317 | 318 | def check_model(input_model): 319 | if not os.path.exists(input_model): 320 | if input_model not in MODEL_MAP: 321 | raise ValueError('input_model not exists!') 322 | 323 | resource_file_urls = MODEL_MAP[input_model]['resource_file_urls'] 324 | logger.info("Downloading resource files...") 325 | 326 | for key, val in resource_file_urls.items(): 327 | file_path = os.path.join(input_model, key) 328 | if not os.path.exists(file_path): 329 | if val.startswith('base64:'): 330 | base64data = b64decode(val.replace( 331 | 'base64:', '').encode('utf-8')) 332 | with open(file_path, 'wb') as f: 333 | f.write(base64data) 334 | else: 335 | download_path = get_path_from_url(val, input_model) 336 | if download_path != file_path: 337 | shutil.move(download_path, file_path) 338 | 339 | 340 | def validate_model(tokenizer, pt_model, pd_model, model_type='uie', atol: float = 1e-5): 341 | 342 | logger.info("Validating PyTorch model...") 343 | 344 | batch_size = 2 345 | seq_length = 6 346 | seq_length_with_token = seq_length+2 347 | max_seq_length = 512 348 | dummy_input = [" ".join([tokenizer.unk_token]) 349 | * seq_length] * batch_size 350 | encoded_inputs = dict(tokenizer(dummy_input, pad_to_max_seq_len=True, max_seq_len=512, return_attention_mask=True, 351 | 
return_position_ids=True)) 352 | paddle_inputs = {} 353 | for name, value in encoded_inputs.items(): 354 | if name == "attention_mask": 355 | if model_type == 'uie-m': 356 | continue 357 | name = "att_mask" 358 | if name == "position_ids": 359 | name = "pos_ids" 360 | paddle_inputs[name] = paddle.to_tensor(value, dtype=paddle.int64) 361 | 362 | paddle_named_outputs = ['start_prob', 'end_prob'] 363 | paddle_outputs = pd_model(**paddle_inputs) 364 | 365 | torch_inputs = {} 366 | for name, value in encoded_inputs.items(): 367 | if name == "attention_mask": 368 | if model_type == 'uie-m': 369 | continue 370 | torch_inputs[name] = torch.tensor(value, dtype=torch.int64) 371 | torch_outputs = pt_model(**torch_inputs) 372 | torch_outputs_dict = {} 373 | 374 | for name, value in torch_outputs.items(): 375 | torch_outputs_dict[name] = value 376 | 377 | torch_outputs_set, ref_outputs_set = set( 378 | torch_outputs_dict.keys()), set(paddle_named_outputs) 379 | if not torch_outputs_set.issubset(ref_outputs_set): 380 | logger.info( 381 | f"\t-[x] Pytorch model output names {torch_outputs_set} do not match reference model {ref_outputs_set}" 382 | ) 383 | 384 | raise ValueError( 385 | "Outputs doesn't match between reference model and Pytorch converted model: " 386 | f"{torch_outputs_set.difference(ref_outputs_set)}" 387 | ) 388 | else: 389 | logger.info( 390 | f"\t-[✓] Pytorch model output names match reference model ({torch_outputs_set})") 391 | 392 | # Check the shape and values match 393 | for name, ref_value in zip(paddle_named_outputs, paddle_outputs): 394 | ref_value = ref_value.numpy() 395 | pt_value = torch_outputs_dict[name].detach().numpy() 396 | logger.info(f'\t- Validating PyTorch Model output "{name}":') 397 | 398 | # Shape 399 | if not pt_value.shape == ref_value.shape: 400 | logger.info( 401 | f"\t\t-[x] shape {pt_value.shape} doesn't match {ref_value.shape}") 402 | raise ValueError( 403 | "Outputs shape doesn't match between reference model and Pytorch converted model: " 404 | f"Got {ref_value.shape} (reference) and {pt_value.shape} (PyTorch)" 405 | ) 406 | else: 407 | logger.info( 408 | f"\t\t-[✓] {pt_value.shape} matches {ref_value.shape}") 409 | 410 | # Values 411 | if not np.allclose(ref_value, pt_value, atol=atol): 412 | logger.info( 413 | f"\t\t-[x] values not close enough (atol: {atol})") 414 | raise ValueError( 415 | "Outputs values doesn't match between reference model and Pytorch converted model: " 416 | f"Got max absolute difference of: {np.amax(np.abs(ref_value - pt_value))}" 417 | ) 418 | else: 419 | logger.info( 420 | f"\t\t-[✓] all values close (atol: {atol})") 421 | 422 | 423 | def do_main(): 424 | if args.output_model is None: 425 | args.output_model = args.input_model.replace('-', '_')+'_pytorch' 426 | check_model(args.input_model) 427 | extract_and_convert(args.input_model, args.output_model, verbose=True) 428 | if not (args.no_validate_output or 'ernie' in args.input_model): 429 | if paddle_installed: 430 | try: 431 | from paddlenlp.transformers import ErnieTokenizer, ErnieMTokenizer 432 | from paddlenlp.taskflow.models import UIE as UIEPaddle, UIEM as UIEMPaddle 433 | except (ImportError, ModuleNotFoundError) as e: 434 | raise ModuleNotFoundError( 435 | 'Module PaddleNLP is not installed. 
Try install paddlenlp or run convert.py with --no_validate_output') from e 436 | if 'uie-m' in args.input_model: 437 | tokenizer: ErnieMTokenizer = ErnieMTokenizer.from_pretrained( 438 | args.input_model) 439 | model = UIEM.from_pretrained(args.output_model) 440 | model.eval() 441 | paddle_model = UIEMPaddle.from_pretrained(args.input_model) 442 | paddle_model.eval() 443 | model_type = 'uie-m' 444 | else: 445 | tokenizer: ErnieTokenizer = ErnieTokenizer.from_pretrained( 446 | args.input_model) 447 | model = UIE.from_pretrained(args.output_model) 448 | model.eval() 449 | paddle_model = UIEPaddle.from_pretrained(args.input_model) 450 | paddle_model.eval() 451 | model_type = 'uie' 452 | validate_model(tokenizer, model, paddle_model, model_type) 453 | else: 454 | logger.warning("Skipping validating PyTorch model because paddle is not installed. " 455 | "The outputs of the model may not be the same as Paddle model.") 456 | 457 | 458 | if __name__ == '__main__': 459 | parser = argparse.ArgumentParser() 460 | parser.add_argument("-i", "--input_model", default="uie-base", type=str, 461 | help="Directory of input paddle model.\n Will auto download model [uie-base/uie-tiny]") 462 | parser.add_argument("-o", "--output_model", default=None, type=str, 463 | help="Directory of output pytorch model") 464 | parser.add_argument("--no_validate_output", action="store_true", 465 | help="Directory of output pytorch model") 466 | args = parser.parse_args() 467 | 468 | do_main() 469 | -------------------------------------------------------------------------------- /doccano.md: -------------------------------------------------------------------------------- 1 | # doccano 2 | 3 | **目录** 4 | 5 | * [1. 安装](#安装) 6 | * [2. 项目创建](#项目创建) 7 | * [3. 数据上传](#数据上传) 8 | * [4. 标签构建](#标签构建) 9 | * [5. 任务标注](#任务标注) 10 | * [6. 数据导出](#数据导出) 11 | * [7. 数据转换](#数据转换) 12 | 13 | 14 | 15 | ## 1. 安装 16 | 17 | 参考[doccano官方文档](https://github.com/doccano/doccano) 完成doccano的安装与初始配置。 18 | 19 | **以下标注示例用到的环境配置:** 20 | 21 | - doccano 1.6.2 22 | 23 | 24 | 25 | ## 2. 项目创建 26 | 27 | UIE支持抽取与分类两种类型的任务,根据实际需要创建一个新的项目: 28 | 29 | #### 2.1 抽取式任务项目创建 30 | 31 | 创建项目时选择**序列标注**任务,并勾选**Allow overlapping entity**及**Use relation Labeling**。适配**命名实体识别、关系抽取、事件抽取、评价观点抽取**等任务。 32 | 33 |
34 | 35 |
36 | 37 | #### 2.2 分类式任务项目创建 38 | 39 | 创建项目时选择**文本分类**任务。适配**文本分类、句子级情感倾向分类**等任务。 40 | 41 |
42 | 43 |
44 | 45 | 46 | 47 | ## 3. 数据上传 48 | 49 | 上传的文件为txt格式,每一行为一条待标注文本,示例: 50 | 51 | ```text 52 | 2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌 53 | 第十四届全运会在西安举办 54 | ``` 55 | 56 | 上传数据类型**选择TextLine**: 57 | 58 |
59 | 60 |
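The snippet below is a supplementary sketch (not part of the original doccano guide) showing one way to produce a TextLine-format upload file from a Python list of raw texts; the `texts` list and the `corpus.txt` filename are illustrative placeholders only.

```python
# Minimal sketch: write one raw text per line into a UTF-8 txt file that can
# be uploaded to doccano as TextLine data. `texts` and "corpus.txt" are
# placeholders, not part of the original workflow.
texts = [
    "2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌",
    "第十四届全运会在西安举办",
]
with open("corpus.txt", "w", encoding="utf-8") as f:
    for text in texts:
        # Keep each sample on a single line, as required by the TextLine format.
        f.write(text.strip().replace("\n", " ") + "\n")
```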
61 | 62 | **NOTE**:doccano支持`TextFile`、`TextLine`、`JSONL`和`CoNLL`四种数据上传格式,UIE定制训练中**统一使用TextLine**这一文件格式,即上传的文件需要为txt格式,且在数据标注时,该文件的每一行待标注文本显示为一页内容。 63 | 64 | 65 | 66 | ## 4. 标签构建 67 | 68 | #### 4.1 构建抽取式任务标签 69 | 70 | 抽取式任务包含**Span**与**Relation**两种标签类型,Span指**原文本中的目标信息片段**,如实体识别中某个类型的实体,事件抽取中的触发词和论元;Relation指**原文本中Span之间的关系**,如关系抽取中两个实体(Subject&Object)之间的关系,事件抽取中论元和触发词之间的关系。 71 | 72 | Span类型标签构建示例: 73 | 74 |
75 | 76 |
77 | 78 | Relation类型标签构建示例: 79 | 80 |
81 | 82 |
83 | 84 | #### 4.2 构建分类式任务标签 85 | 86 | 添加分类类别标签: 87 | 88 |
89 | 90 |
91 | 92 | 93 | 94 | ## 5. 任务标注 95 | 96 | #### 5.1 命名实体识别 97 | 98 | 命名实体识别(Named Entity Recognition,简称NER),是指识别文本中具有特定意义的实体。在开放域信息抽取中,**抽取的类别没有限制,用户可以自己定义**。 99 | 100 | 标注示例: 101 | 102 |
103 | 104 |
105 | 106 | 示例中定义了`时间`、`选手`、`赛事名称`和`得分`四种Span类型标签。 107 | 108 | ```text 109 | schema = [ 110 | '时间', 111 | '选手', 112 | '赛事名称', 113 | '得分' 114 | ] 115 | ``` 116 | 117 | #### 5.2 关系抽取 118 | 119 | 关系抽取(Relation Extraction,简称RE),是指从文本中识别实体并抽取实体之间的语义关系,即抽取三元组(实体一,关系类型,实体二)。 120 | 121 | 标注示例: 122 | 123 |
124 | 125 |
126 | 127 | 示例中定义了`作品名`、`人物名`和`时间`三种Span类型标签,以及`歌手`、`发行时间`和`所属专辑`三种Relation标签。Relation标签**由Subject对应实体指向Object对应实体**。 128 | 129 | 该标注示例对应的schema为: 130 | 131 | ```text 132 | schema = { 133 | '作品名': [ 134 | '歌手', 135 | '发行时间', 136 | '所属专辑' 137 | ] 138 | } 139 | ``` 140 | 141 | #### 5.3 事件抽取 142 | 143 | 事件抽取 (Event Extraction, 简称EE),是指从自然语言文本中抽取事件并识别事件类型和事件论元的技术。UIE所包含的事件抽取任务,是指根据已知事件类型,抽取该事件所包含的事件论元。 144 | 145 | 标注示例: 146 | 147 |
148 | 149 |
150 | 151 | 示例中定义了`地震触发词`(触发词)、`等级`(事件论元)和`时间`(事件论元)三种Span标签,以及`时间`和`震级`两种Relation标签。触发词标签**统一格式为`XX触发词`**,`XX`表示具体事件类型,上例中的事件类型是`地震`,则对应触发词为`地震触发词`。Relation标签**由触发词指向对应的事件论元**。 152 | 153 | 该标注示例对应的schema为: 154 | 155 | ```text 156 | schema = { 157 | '地震触发词': [ 158 | '时间', 159 | '震级' 160 | ] 161 | } 162 | ``` 163 | 164 | #### 5.4 评价观点抽取 165 | 166 | 评论观点抽取,是指抽取文本中包含的评价维度、观点词。 167 | 168 | 标注示例: 169 | 170 |
171 | 172 |
173 | 174 | 示例中定义了`评价维度`和`观点词`两种Span标签,以及`观点词`一种Relation标签。Relation标签**由评价维度指向观点词**。 175 | 176 | 该标注示例对应的schema为: 177 | 178 | ```text 179 | schema = { 180 | '评价维度': '观点词' 181 | } 182 | ``` 183 | 184 | #### 5.5 句子级分类任务 185 | 186 | 标注示例: 187 | 188 |
189 | 190 |
191 | 192 | 示例中定义了`正向`和`负向`两种类别标签对文本的情感倾向进行分类。 193 | 194 | 该标注示例对应的schema为: 195 | 196 | ```text 197 | schema = '情感倾向[正向,负向]' 198 | ``` 199 | 200 | #### 5.6 实体/评价维度级分类任务 201 | 202 |
203 | 204 |
205 | 206 | 标注示例: 207 | 208 | 示例中定义了`评价维度##正向`,`评价维度##负向`和`观点词`三种Span标签以及`观点词`一种Relation标签。其中,`##`是实体类别/评价维度与分类标签的分隔符(可通过doccano.py中的separator参数自定义)。 209 | 210 | 该标注示例对应的schema为: 211 | 212 | ```text 213 | schema = { 214 | '评价维度': [ 215 | '观点词', 216 | '情感倾向[正向,负向]' 217 | ] 218 | } 219 | ``` 220 | 221 | 222 | 223 | ## 6. 数据导出 224 | 225 | #### 6.1 导出抽取式和实体/评价维度级分类任务数据 226 | 227 | 选择导出的文件类型为``JSONL(relation)``,导出数据示例: 228 | 229 | ```text 230 | { 231 | "id": 38, 232 | "text": "百科名片你知道我要什么,是歌手高明骏演唱的一首歌曲,1989年发行,收录于个人专辑《丛林男孩》中", 233 | "relations": [ 234 | { 235 | "id": 20, 236 | "from_id": 51, 237 | "to_id": 53, 238 | "type": "歌手" 239 | }, 240 | { 241 | "id": 21, 242 | "from_id": 51, 243 | "to_id": 55, 244 | "type": "发行时间" 245 | }, 246 | { 247 | "id": 22, 248 | "from_id": 51, 249 | "to_id": 54, 250 | "type": "所属专辑" 251 | } 252 | ], 253 | "entities": [ 254 | { 255 | "id": 51, 256 | "start_offset": 4, 257 | "end_offset": 11, 258 | "label": "作品名" 259 | }, 260 | { 261 | "id": 53, 262 | "start_offset": 15, 263 | "end_offset": 18, 264 | "label": "人物名" 265 | }, 266 | { 267 | "id": 54, 268 | "start_offset": 42, 269 | "end_offset": 46, 270 | "label": "作品名" 271 | }, 272 | { 273 | "id": 55, 274 | "start_offset": 26, 275 | "end_offset": 31, 276 | "label": "时间" 277 | } 278 | ] 279 | } 280 | ``` 281 | 282 | 标注数据保存在同一个文本文件中,每条样例占一行且存储为``json``格式,其包含以下字段 283 | - ``id``: 样本在数据集中的唯一标识ID。 284 | - ``text``: 原始文本数据。 285 | - ``entities``: 数据中包含的Span标签,每个Span标签包含四个字段: 286 | - ``id``: Span在数据集中的唯一标识ID。 287 | - ``start_offset``: Span的起始token在文本中的下标。 288 | - ``end_offset``: Span的结束token在文本中下标的下一个位置。 289 | - ``label``: Span类型。 290 | - ``relations``: 数据中包含的Relation标签,每个Relation标签包含四个字段: 291 | - ``id``: (Span1, Relation, Span2)三元组在数据集中的唯一标识ID,不同样本中的相同三元组对应同一个ID。 292 | - ``from_id``: Span1对应的标识ID。 293 | - ``to_id``: Span2对应的标识ID。 294 | - ``type``: Relation类型。 295 | 296 | #### 6.2 导出句子级分类任务数据 297 | 298 | 选择导出的文件类型为``JSONL``,导出数据示例: 299 | 300 | ```text 301 | { 302 | "id": 41, 303 | "data": "大年初一就把车前保险杠给碰坏了,保险杠和保险公司 真够倒霉的,我决定步行反省。", 304 | "label": [ 305 | "负向" 306 | ] 307 | } 308 | ``` 309 | 310 | 标注数据保存在同一个文本文件中,每条样例占一行且存储为``json``格式,其包含以下字段 311 | - ``id``: 样本在数据集中的唯一标识ID。 312 | - ``data``: 原始文本数据。 313 | - ``label``: 文本对应类别标签。 314 | 315 | 316 | 317 | ## 7.数据转换 318 | 319 | 该章节详细说明如何通过`doccano.py`脚本对doccano平台导出的标注数据进行转换,一键生成训练/验证/测试集。 320 | 321 | #### 7.1 抽取式任务数据转换 322 | 323 | - 当标注完成后,在 doccano 平台上导出 `JSONL(relation)` 形式的文件,并将其重命名为 `doccano_ext.json` 后,放入 `./data` 目录下。 324 | - 通过 [doccano.py](./doccano.py) 脚本进行数据形式转换,然后便可以开始进行相应模型训练。 325 | 326 | ```shell 327 | python doccano.py \ 328 | --doccano_file ./data/doccano_ext.json \ 329 | --task_type "ext" \ 330 | --save_dir ./data \ 331 | --negative_ratio 5 332 | ``` 333 | 334 | #### 7.2 句子级分类任务数据转换 335 | 336 | - 当标注完成后,在 doccano 平台上导出 `JSON` 形式的文件,并将其重命名为 `doccano_cls.json` 后,放入 `./data` 目录下。 337 | - 在数据转换阶段,我们会自动构造用于模型训练的prompt信息。例如句子级情感分类中,prompt为``情感倾向[正向,负向]``,可以通过`prompt_prefix`和`options`参数进行声明。 338 | - 通过 [doccano.py](./doccano.py) 脚本进行数据形式转换,然后便可以开始进行相应模型训练。 339 | 340 | ```shell 341 | python doccano.py \ 342 | --doccano_file ./data/doccano_cls.json \ 343 | --task_type "cls" \ 344 | --save_dir ./data \ 345 | --splits 0.8 0.1 0.1 \ 346 | --prompt_prefix "情感倾向" \ 347 | --options "正向" "负向" 348 | ``` 349 | 350 | #### 7.3 实体/评价维度级分类任务数据转换 351 | 352 | - 当标注完成后,在 doccano 平台上导出 `JSONL(relation)` 形式的文件,并将其重命名为 `doccano_ext.json` 后,放入 `./data` 目录下。 353 | - 在数据转换阶段,我们会自动构造用于模型训练的prompt信息。例如评价维度级情感分类中,prompt为``XXX的情感倾向[正向,负向]``,可以通过`prompt_prefix`和`options`参数进行声明。 354 | - 通过 
[doccano.py](./doccano.py) 脚本进行数据形式转换,然后便可以开始进行相应模型训练。 355 | 356 | ```shell 357 | python doccano.py \ 358 | --doccano_file ./data/doccano_ext.json \ 359 | --task_type "ext" \ 360 | --save_dir ./data \ 361 | --splits 0.8 0.1 0.1 \ 362 | --prompt_prefix "情感倾向" \ 363 | --options "正向" "负向" \ 364 | --separator "##" 365 | ``` 366 | 367 | 可配置参数说明: 368 | 369 | - ``doccano_file``: 从doccano导出的数据标注文件。 370 | - ``save_dir``: 训练数据的保存目录,默认存储在``data``目录下。 371 | - ``negative_ratio``: 最大负例比例,该参数只对抽取类型任务有效,适当构造负例可提升模型效果。负例数量和实际的标签数量有关,最大负例数量 = negative_ratio * 正例数量。该参数只对训练集有效,默认为5。为了保证评估指标的准确性,验证集和测试集默认构造全负例。 372 | - ``splits``: 划分数据集时训练集、验证集所占的比例。默认为[0.8, 0.1, 0.1]表示按照``8:1:1``的比例将数据划分为训练集、验证集和测试集。 373 | - ``task_type``: 选择任务类型,可选有抽取和分类两种类型的任务。 374 | - ``options``: 指定分类任务的类别标签,该参数只对分类类型任务有效。默认为["正向", "负向"]。 375 | - ``prompt_prefix``: 声明分类任务的prompt前缀信息,该参数只对分类类型任务有效。默认为"情感倾向"。 376 | - ``is_shuffle``: 是否对数据集进行随机打散,默认为True。 377 | - ``seed``: 随机种子,默认为1000. 378 | - ``separator``: 实体类别/评价维度与分类标签的分隔符,该参数只对实体/评价维度级分类任务有效。默认为"##"。 379 | 380 | 备注: 381 | - 默认情况下 [doccano.py](./doccano.py) 脚本会按照比例将数据划分为 train/dev/test 数据集 382 | - 每次执行 [doccano.py](./doccano.py) 脚本,将会覆盖已有的同名数据文件 383 | - 在模型训练阶段我们推荐构造一些负例以提升模型效果,在数据转换阶段我们内置了这一功能。可通过`negative_ratio`控制自动构造的负样本比例;负样本数量 = negative_ratio * 正样本数量。 384 | - 对于从doccano导出的文件,默认文件中的每条数据都是经过人工正确标注的。 385 | 386 | ## References 387 | - **[doccano](https://github.com/doccano/doccano)** 388 | -------------------------------------------------------------------------------- /doccano.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import os 17 | import time 18 | import argparse 19 | import json 20 | from decimal import Decimal 21 | import numpy as np 22 | 23 | from utils import set_seed, convert_ext_examples, convert_cls_examples, logger 24 | 25 | 26 | def do_convert(): 27 | set_seed(args.seed) 28 | 29 | tic_time = time.time() 30 | if not os.path.exists(args.doccano_file): 31 | raise ValueError("Please input the correct path of doccano file.") 32 | 33 | if not os.path.exists(args.save_dir): 34 | os.makedirs(args.save_dir) 35 | 36 | if len(args.splits) != 0 and len(args.splits) != 3: 37 | raise ValueError("Only []/ len(splits)==3 accepted for splits.") 38 | 39 | def _check_sum(splits): 40 | return Decimal(str(splits[0])) + Decimal(str(splits[1])) + Decimal( 41 | str(splits[2])) == Decimal("1") 42 | 43 | if len(args.splits) == 3 and not _check_sum(args.splits): 44 | raise ValueError( 45 | "Please set correct splits, sum of elements in splits should be equal to 1." 
46 | ) 47 | 48 | with open(args.doccano_file, "r", encoding="utf-8") as f: 49 | raw_examples = f.readlines() 50 | 51 | def _create_ext_examples(examples, 52 | negative_ratio, 53 | prompt_prefix="情感倾向", 54 | options=["正向", "负向"], 55 | separator="##", 56 | shuffle=False, 57 | is_train=True): 58 | entities, relations, aspects = convert_ext_examples( 59 | examples, negative_ratio, prompt_prefix, options, separator, 60 | is_train) 61 | examples = entities + relations + aspects 62 | if shuffle: 63 | indexes = np.random.permutation(len(examples)) 64 | examples = [examples[i] for i in indexes] 65 | return examples 66 | 67 | def _create_cls_examples(examples, prompt_prefix, options, shuffle=False): 68 | examples = convert_cls_examples(examples, prompt_prefix, options) 69 | if shuffle: 70 | indexes = np.random.permutation(len(examples)) 71 | examples = [examples[i] for i in indexes] 72 | return examples 73 | 74 | def _save_examples(save_dir, file_name, examples): 75 | count = 0 76 | save_path = os.path.join(save_dir, file_name) 77 | if not examples: 78 | logger.info("Skip saving %d examples to %s." % (0, save_path)) 79 | return 80 | with open(save_path, "w", encoding="utf-8") as f: 81 | for example in examples: 82 | f.write(json.dumps(example, ensure_ascii=False) + "\n") 83 | count += 1 84 | logger.info("Save %d examples to %s." % (count, save_path)) 85 | 86 | if len(args.splits) == 0: 87 | if args.task_type == "ext": 88 | examples = _create_ext_examples(raw_examples, args.negative_ratio, 89 | args.prompt_prefix, args.options, 90 | args.separator, args.is_shuffle) 91 | else: 92 | examples = _create_cls_examples(raw_examples, args.prompt_prefix, 93 | args.options, args.is_shuffle) 94 | _save_examples(args.save_dir, "train.txt", examples) 95 | else: 96 | if args.is_shuffle: 97 | indexes = np.random.permutation(len(raw_examples)) 98 | index_list = indexes.tolist() 99 | raw_examples = [raw_examples[i] for i in indexes] 100 | 101 | i1, i2, _ = args.splits 102 | p1 = int(len(raw_examples) * i1) 103 | p2 = int(len(raw_examples) * (i1 + i2)) 104 | 105 | train_ids = index_list[:p1] 106 | dev_ids = index_list[p1:p2] 107 | test_ids = index_list[p2:] 108 | 109 | with open(os.path.join(args.save_dir, "sample_index.json"), "w") as fp: 110 | maps = { 111 | "train_ids": train_ids, 112 | "dev_ids": dev_ids, 113 | "test_ids": test_ids 114 | } 115 | fp.write(json.dumps(maps)) 116 | 117 | if args.task_type == "ext": 118 | train_examples = _create_ext_examples(raw_examples[:p1], 119 | args.negative_ratio, 120 | args.prompt_prefix, 121 | args.options, args.separator, 122 | args.is_shuffle) 123 | dev_examples = _create_ext_examples(raw_examples[p1:p2], 124 | -1, 125 | args.prompt_prefix, 126 | args.options, 127 | args.separator, 128 | is_train=False) 129 | test_examples = _create_ext_examples(raw_examples[p2:], 130 | -1, 131 | args.prompt_prefix, 132 | args.options, 133 | args.separator, 134 | is_train=False) 135 | else: 136 | train_examples = _create_cls_examples(raw_examples[:p1], 137 | args.prompt_prefix, 138 | args.options) 139 | dev_examples = _create_cls_examples(raw_examples[p1:p2], 140 | args.prompt_prefix, 141 | args.options) 142 | test_examples = _create_cls_examples(raw_examples[p2:], 143 | args.prompt_prefix, 144 | args.options) 145 | 146 | _save_examples(args.save_dir, "train.txt", train_examples) 147 | _save_examples(args.save_dir, "dev.txt", dev_examples) 148 | _save_examples(args.save_dir, "test.txt", test_examples) 149 | 150 | logger.info('Finished! 
It takes %.2f seconds' % (time.time() - tic_time)) 151 | 152 | 153 | if __name__ == "__main__": 154 | # yapf: disable 155 | parser = argparse.ArgumentParser() 156 | 157 | parser.add_argument("-d", "--doccano_file", default="./data/doccano.json", 158 | type=str, help="The doccano file exported from doccano platform.") 159 | parser.add_argument("-s", "--save_dir", default="./data", 160 | type=str, help="The path of data that you wanna save.") 161 | parser.add_argument("--negative_ratio", default=5, type=int, 162 | help="Used only for the extraction task, the ratio of positive and negative samples, number of negtive samples = negative_ratio * number of positive samples") 163 | parser.add_argument("--splits", default=[0.8, 0.1, 0.1], type=float, nargs="*", 164 | help="The ratio of samples in datasets. [0.6, 0.2, 0.2] means 60%% samples used for training, 20%% for evaluation and 20%% for test.") 165 | parser.add_argument("--task_type", choices=['ext', 'cls'], default="ext", type=str, 166 | help="Select task type, ext for the extraction task and cls for the classification task, defaults to ext.") 167 | parser.add_argument("--options", default=["正向", "负向"], type=str, nargs="+", 168 | help="Used only for the classification task, the options for classification") 169 | parser.add_argument("--prompt_prefix", default="情感倾向", type=str, 170 | help="Used only for the classification task, the prompt prefix for classification") 171 | parser.add_argument("--is_shuffle", default=True, type=bool, 172 | help="Whether to shuffle the labeled dataset, defaults to True.") 173 | parser.add_argument("--seed", type=int, default=1000, 174 | help="Random seed for initialization") 175 | parser.add_argument("--separator", type=str, default='##', 176 | help="Used only for entity/aspect-level classification task, separator for entity label and classification label") 177 | 178 | args = parser.parse_args() 179 | # yapf: enable 180 | 181 | do_convert() 182 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from model import UIE 16 | import argparse 17 | from functools import partial 18 | 19 | import torch 20 | from transformers import BertTokenizerFast 21 | from torch.utils.data import DataLoader 22 | 23 | from utils import IEMapDataset, SpanEvaluator, IEDataset, convert_example, get_relation_type_dict, logger, tqdm, unify_prompt_name 24 | 25 | 26 | @torch.no_grad() 27 | def evaluate(model, metric, data_loader, device='gpu', loss_fn=None, show_bar=True): 28 | """ 29 | Given a dataset, it evals model and computes the metric. 30 | Args: 31 | model(obj:`torch.nn.Module`): A model to classify texts. 32 | metric(obj:`Metric`): The evaluation metric. 33 | data_loader(obj:`torch.utils.data.DataLoader`): The dataset loader which generates batches. 
34 | """ 35 | return_loss = False 36 | if loss_fn is not None: 37 | return_loss = True 38 | model.eval() 39 | metric.reset() 40 | loss_list = [] 41 | loss_sum = 0 42 | loss_num = 0 43 | if show_bar: 44 | data_loader = tqdm( 45 | data_loader, desc="Evaluating", unit='batch') 46 | for batch in data_loader: 47 | input_ids, token_type_ids, att_mask, start_ids, end_ids = batch 48 | if device == 'gpu': 49 | input_ids = input_ids.cuda() 50 | token_type_ids = token_type_ids.cuda() 51 | att_mask = att_mask.cuda() 52 | outputs = model(input_ids=input_ids, 53 | token_type_ids=token_type_ids, 54 | attention_mask=att_mask) 55 | start_prob, end_prob = outputs[0], outputs[1] 56 | if device == 'gpu': 57 | start_prob, end_prob = start_prob.cpu(), end_prob.cpu() 58 | start_ids = start_ids.type(torch.float32) 59 | end_ids = end_ids.type(torch.float32) 60 | 61 | if return_loss: 62 | # Calculate loss 63 | loss_start = loss_fn(start_prob, start_ids) 64 | loss_end = loss_fn(end_prob, end_ids) 65 | loss = (loss_start + loss_end) / 2.0 66 | loss = float(loss) 67 | loss_list.append(loss) 68 | loss_sum += loss 69 | loss_num += 1 70 | if show_bar: 71 | data_loader.set_postfix( 72 | { 73 | 'dev loss': f'{loss_sum / loss_num:.5f}' 74 | } 75 | ) 76 | 77 | # Calcalate metric 78 | num_correct, num_infer, num_label = metric.compute(start_prob, end_prob, 79 | start_ids, end_ids) 80 | metric.update(num_correct, num_infer, num_label) 81 | precision, recall, f1 = metric.accumulate() 82 | model.train() 83 | if return_loss: 84 | loss_avg = sum(loss_list) / len(loss_list) 85 | return loss_avg, precision, recall, f1 86 | else: 87 | return precision, recall, f1 88 | 89 | 90 | def do_eval(): 91 | tokenizer = BertTokenizerFast.from_pretrained(args.model_path) 92 | model = UIE.from_pretrained(args.model_path) 93 | if args.device == 'gpu': 94 | model = model.cuda() 95 | 96 | test_ds = IEDataset(args.test_path, tokenizer=tokenizer, 97 | max_seq_len=args.max_seq_len) 98 | 99 | test_data_loader = DataLoader( 100 | test_ds, batch_size=args.batch_size, shuffle=False) 101 | class_dict = {} 102 | relation_data = [] 103 | if args.debug: 104 | for data in test_ds.dataset: 105 | class_name = unify_prompt_name(data['prompt']) 106 | # Only positive examples are evaluated in debug mode 107 | if len(data['result_list']) != 0: 108 | if "的" not in data['prompt']: 109 | class_dict.setdefault(class_name, []).append(data) 110 | else: 111 | relation_data.append((data['prompt'], data)) 112 | relation_type_dict = get_relation_type_dict(relation_data) 113 | else: 114 | class_dict["all_classes"] = test_ds 115 | 116 | for key in class_dict.keys(): 117 | if args.debug: 118 | test_ds = IEMapDataset(class_dict[key], tokenizer=tokenizer, 119 | max_seq_len=args.max_seq_len) 120 | else: 121 | test_ds = class_dict[key] 122 | 123 | test_data_loader = DataLoader( 124 | test_ds, batch_size=args.batch_size, shuffle=False) 125 | metric = SpanEvaluator() 126 | precision, recall, f1 = evaluate( 127 | model, metric, test_data_loader, args.device) 128 | logger.info("-----------------------------") 129 | logger.info("Class Name: %s" % key) 130 | logger.info("Evaluation Precision: %.5f | Recall: %.5f | F1: %.5f" % 131 | (precision, recall, f1)) 132 | 133 | if args.debug and len(relation_type_dict.keys()) != 0: 134 | for key in relation_type_dict.keys(): 135 | test_ds = IEMapDataset(relation_type_dict[key], tokenizer=tokenizer, 136 | max_seq_len=args.max_seq_len) 137 | 138 | test_data_loader = DataLoader( 139 | test_ds, batch_size=args.batch_size, shuffle=False) 140 | metric = 
SpanEvaluator() 141 | precision, recall, f1 = evaluate( 142 | model, metric, test_data_loader, args.device) 143 | logger.info("-----------------------------") 144 | logger.info("Class Name: X的%s" % key) 145 | logger.info("Evaluation Precision: %.5f | Recall: %.5f | F1: %.5f" % 146 | (precision, recall, f1)) 147 | 148 | 149 | if __name__ == "__main__": 150 | # yapf: disable 151 | parser = argparse.ArgumentParser() 152 | 153 | parser.add_argument("-m", "--model_path", type=str, required=True, 154 | help="The path of saved model that you want to load.") 155 | parser.add_argument("-t", "--test_path", type=str, required=True, 156 | help="The path of test set.") 157 | parser.add_argument("-b", "--batch_size", type=int, default=16, 158 | help="Batch size per GPU/CPU for training.") 159 | parser.add_argument("--max_seq_len", type=int, default=512, 160 | help="The maximum total input sequence length after tokenization.") 161 | parser.add_argument("-D", '--device', choices=['cpu', 'gpu'], default="gpu", 162 | help="Select which device to run model, defaults to gpu.") 163 | parser.add_argument("--debug", action='store_true', 164 | help="Precision, recall and F1 score are calculated for each class separately if this option is enabled.") 165 | 166 | args = parser.parse_args() 167 | # yapf: enable 168 | 169 | do_eval() 170 | -------------------------------------------------------------------------------- /export_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | import os 17 | from itertools import chain 18 | from typing import List, Union 19 | import shutil 20 | from pathlib import Path 21 | 22 | import numpy as np 23 | import torch 24 | from transformers import (BertTokenizer, PreTrainedModel, 25 | PreTrainedTokenizerBase) 26 | 27 | from model import UIE 28 | from utils import logger 29 | 30 | 31 | def validate_onnx(tokenizer: PreTrainedTokenizerBase, pt_model: PreTrainedModel, onnx_path: Union[Path, str], strict: bool = True, atol: float = 1e-05): 32 | 33 | # 验证模型 34 | from onnxruntime import InferenceSession, SessionOptions 35 | from transformers import AutoTokenizer 36 | 37 | logger.info("Validating ONNX model...") 38 | if strict: 39 | ref_inputs = tokenizer('装备', "印媒所称的“印度第一艘国产航母”—“维克兰特”号", 40 | add_special_tokens=True, 41 | truncation=True, 42 | max_length=512, 43 | return_tensors="pt") 44 | else: 45 | batch_size = 2 46 | seq_length = 6 47 | dummy_input = [" ".join([tokenizer.unk_token]) 48 | * seq_length] * batch_size 49 | ref_inputs = dict(tokenizer(dummy_input, return_tensors="pt")) 50 | # ref_inputs = 51 | ref_outputs = pt_model(**ref_inputs) 52 | ref_outputs_dict = {} 53 | 54 | # We flatten potential collection of outputs (i.e. 
past_keys) to a flat structure 55 | for name, value in ref_outputs.items(): 56 | # Overwriting the output name as "present" since it is the name used for the ONNX outputs 57 | # ("past_key_values" being taken for the ONNX inputs) 58 | if name == "past_key_values": 59 | name = "present" 60 | ref_outputs_dict[name] = value 61 | 62 | # Create ONNX Runtime session 63 | options = SessionOptions() 64 | session = InferenceSession(str(onnx_path), options, providers=[ 65 | "CPUExecutionProvider"]) 66 | 67 | # We flatten potential collection of inputs (i.e. past_keys) 68 | onnx_inputs = {} 69 | for name, value in ref_inputs.items(): 70 | onnx_inputs[name] = value.numpy() 71 | onnx_named_outputs = ['start_prob', 'end_prob'] 72 | # Compute outputs from the ONNX model 73 | onnx_outputs = session.run(onnx_named_outputs, onnx_inputs) 74 | 75 | # Check we have a subset of the keys into onnx_outputs against ref_outputs 76 | ref_outputs_set, onnx_outputs_set = set( 77 | ref_outputs_dict.keys()), set(onnx_named_outputs) 78 | if not onnx_outputs_set.issubset(ref_outputs_set): 79 | logger.info( 80 | f"\t-[x] ONNX model output names {onnx_outputs_set} do not match reference model {ref_outputs_set}" 81 | ) 82 | 83 | raise ValueError( 84 | "Outputs doesn't match between reference model and ONNX exported model: " 85 | f"{onnx_outputs_set.difference(ref_outputs_set)}" 86 | ) 87 | else: 88 | logger.info( 89 | f"\t-[✓] ONNX model output names match reference model ({onnx_outputs_set})") 90 | 91 | # Check the shape and values match 92 | for name, ort_value in zip(onnx_named_outputs, onnx_outputs): 93 | ref_value = ref_outputs_dict[name].detach().numpy() 94 | 95 | logger.info(f'\t- Validating ONNX Model output "{name}":') 96 | 97 | # Shape 98 | if not ort_value.shape == ref_value.shape: 99 | logger.info( 100 | f"\t\t-[x] shape {ort_value.shape} doesn't match {ref_value.shape}") 101 | raise ValueError( 102 | "Outputs shape doesn't match between reference model and ONNX exported model: " 103 | f"Got {ref_value.shape} (reference) and {ort_value.shape} (ONNX)" 104 | ) 105 | else: 106 | logger.info( 107 | f"\t\t-[✓] {ort_value.shape} matches {ref_value.shape}") 108 | 109 | # Values 110 | if not np.allclose(ref_value, ort_value, atol=atol): 111 | logger.info(f"\t\t-[x] values not close enough (atol: {atol})") 112 | raise ValueError( 113 | "Outputs values doesn't match between reference model and ONNX exported model: " 114 | f"Got max absolute difference of: {np.amax(np.abs(ref_value - ort_value))}" 115 | ) 116 | else: 117 | logger.info(f"\t\t-[✓] all values close (atol: {atol})") 118 | 119 | 120 | def export_onnx(output_path: Union[Path, str], tokenizer: PreTrainedTokenizerBase, model: PreTrainedModel, device: torch.device, input_names: List[str], output_names: List[str]): 121 | with torch.no_grad(): 122 | model = model.to(device) 123 | model.eval() 124 | model.config.return_dict = True 125 | model.config.use_cache = False 126 | 127 | output_path = Path(output_path) 128 | 129 | # Create folder 130 | if not output_path.exists(): 131 | output_path.mkdir(parents=True) 132 | save_path = output_path / "inference.onnx" 133 | 134 | dynamic_axes = {name: {0: 'batch', 1: 'sequence'} 135 | for name in chain(input_names, output_names)} 136 | 137 | # Generate dummy input 138 | batch_size = 2 139 | seq_length = 6 140 | dummy_input = [" ".join([tokenizer.unk_token]) 141 | * seq_length] * batch_size 142 | inputs = dict(tokenizer(dummy_input, return_tensors="pt")) 143 | 144 | if save_path.exists(): 145 | logger.warning(f'Overwrite model 
{save_path.as_posix()}') 146 | save_path.unlink() 147 | 148 | torch.onnx.export(model, 149 | (inputs,), 150 | save_path, 151 | input_names=input_names, 152 | output_names=output_names, 153 | dynamic_axes=dynamic_axes, 154 | do_constant_folding=True, 155 | opset_version=11 156 | ) 157 | 158 | if not os.path.exists(save_path): 159 | logger.error(f'Export Failed!') 160 | 161 | return save_path 162 | 163 | 164 | def main(): 165 | parser = argparse.ArgumentParser() 166 | parser.add_argument("-m", "--model_path", type=Path, required=True, 167 | default='./checkpoint/model_best', help="The path to model parameters to be loaded.") 168 | parser.add_argument("-o", "--output_path", type=Path, default=None, 169 | help="The path of model parameter in static graph to be saved.") 170 | args = parser.parse_args() 171 | 172 | if args.output_path is None: 173 | args.output_path = args.model_path 174 | 175 | tokenizer = BertTokenizer.from_pretrained(args.model_path) 176 | model = UIE.from_pretrained(args.model_path) 177 | device = torch.device('cpu') 178 | input_names = [ 179 | 'input_ids', 180 | 'token_type_ids', 181 | 'attention_mask', 182 | ] 183 | output_names = [ 184 | 'start_prob', 185 | 'end_prob' 186 | ] 187 | 188 | logger.info("Export Tokenizer Config...") 189 | 190 | export_tokenizer(args) 191 | 192 | logger.info("Export ONNX Model...") 193 | 194 | save_path = export_onnx( 195 | args.output_path, tokenizer, model, device, input_names, output_names) 196 | validate_onnx(tokenizer, model, save_path) 197 | 198 | logger.info(f"All good, model saved at: {save_path.as_posix()}") 199 | 200 | 201 | def export_tokenizer(args): 202 | for tokenizer_fine in ['tokenizer_config.json', 'special_tokens_map.json', 'vocab.txt']: 203 | file_from = args.model_path / tokenizer_fine 204 | file_to = args.output_path/tokenizer_fine 205 | if file_from.resolve() == file_to.resolve(): 206 | continue 207 | shutil.copyfile(file_from, file_to) 208 | 209 | 210 | if __name__ == "__main__": 211 | 212 | main() 213 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
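# Example usage (a supplementary sketch; the paths below are illustrative).
# train.txt / dev.txt are the files produced by doccano.py, and
# uie_base_pytorch is the PyTorch base model typically produced by convert.py:
#
#   python finetune.py \
#       -t ./data/train.txt \
#       -d ./data/dev.txt \
#       -m uie_base_pytorch \
#       -s ./checkpoint \
#       --batch_size 16 \
#       --num_epochs 100 \
#       --device gpu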
14 | 15 | import argparse 16 | import shutil 17 | import sys 18 | import time 19 | import os 20 | import torch 21 | from torch.utils.data import DataLoader 22 | from transformers import BertTokenizerFast 23 | 24 | from utils import IEDataset, logger, tqdm 25 | from model import UIE 26 | from evaluate import evaluate 27 | from utils import set_seed, SpanEvaluator, EarlyStopping, logging_redirect_tqdm 28 | 29 | 30 | def do_train(): 31 | 32 | set_seed(args.seed) 33 | show_bar = True 34 | 35 | tokenizer = BertTokenizerFast.from_pretrained(args.model) 36 | model = UIE.from_pretrained(args.model) 37 | if args.device == 'gpu': 38 | model = model.cuda() 39 | train_ds = IEDataset(args.train_path, tokenizer=tokenizer, 40 | max_seq_len=args.max_seq_len) 41 | dev_ds = IEDataset(args.dev_path, tokenizer=tokenizer, 42 | max_seq_len=args.max_seq_len) 43 | 44 | train_data_loader = DataLoader( 45 | train_ds, batch_size=args.batch_size, shuffle=True) 46 | dev_data_loader = DataLoader( 47 | dev_ds, batch_size=args.batch_size, shuffle=True) 48 | 49 | optimizer = torch.optim.AdamW( 50 | lr=args.learning_rate, params=model.parameters()) 51 | 52 | criterion = torch.nn.functional.binary_cross_entropy 53 | metric = SpanEvaluator() 54 | 55 | if args.early_stopping: 56 | early_stopping_save_dir = os.path.join( 57 | args.save_dir, "early_stopping") 58 | if not os.path.exists(early_stopping_save_dir): 59 | os.makedirs(early_stopping_save_dir) 60 | if show_bar: 61 | def trace_func(*args, **kwargs): 62 | with logging_redirect_tqdm([logger.logger]): 63 | logger.info(*args, **kwargs) 64 | else: 65 | trace_func = logger.info 66 | early_stopping = EarlyStopping( 67 | patience=7, verbose=True, trace_func=trace_func, 68 | save_dir=early_stopping_save_dir) 69 | 70 | loss_list = [] 71 | loss_sum = 0 72 | loss_num = 0 73 | global_step = 0 74 | best_step = 0 75 | best_f1 = 0 76 | tic_train = time.time() 77 | epoch_iterator = range(1, args.num_epochs + 1) 78 | if show_bar: 79 | train_postfix_info = {'loss': 'unknown'} 80 | epoch_iterator = tqdm( 81 | epoch_iterator, desc='Training', unit='epoch') 82 | for epoch in epoch_iterator: 83 | train_data_iterator = train_data_loader 84 | if show_bar: 85 | train_data_iterator = tqdm(train_data_iterator, 86 | desc=f'Training Epoch {epoch}', unit='batch') 87 | train_data_iterator.set_postfix(train_postfix_info) 88 | for batch in train_data_iterator: 89 | if show_bar: 90 | epoch_iterator.refresh() 91 | input_ids, token_type_ids, att_mask, start_ids, end_ids = batch 92 | if args.device == 'gpu': 93 | input_ids = input_ids.cuda() 94 | token_type_ids = token_type_ids.cuda() 95 | att_mask = att_mask.cuda() 96 | start_ids = start_ids.cuda() 97 | end_ids = end_ids.cuda() 98 | outputs = model(input_ids=input_ids, 99 | token_type_ids=token_type_ids, 100 | attention_mask=att_mask) 101 | start_prob, end_prob = outputs[0], outputs[1] 102 | 103 | start_ids = start_ids.type(torch.float32) 104 | end_ids = end_ids.type(torch.float32) 105 | loss_start = criterion(start_prob, start_ids) 106 | loss_end = criterion(end_prob, end_ids) 107 | loss = (loss_start + loss_end) / 2.0 108 | loss.backward() 109 | optimizer.step() 110 | optimizer.zero_grad() 111 | loss_list.append(float(loss)) 112 | loss_sum += float(loss) 113 | loss_num += 1 114 | 115 | if show_bar: 116 | loss_avg = loss_sum / loss_num 117 | train_postfix_info.update({ 118 | 'loss': f'{loss_avg:.5f}' 119 | }) 120 | train_data_iterator.set_postfix(train_postfix_info) 121 | 122 | global_step += 1 123 | if global_step % args.logging_steps == 0: 124 | 
time_diff = time.time() - tic_train 125 | loss_avg = loss_sum / loss_num 126 | 127 | if show_bar: 128 | with logging_redirect_tqdm([logger.logger]): 129 | logger.info( 130 | "global step %d, epoch: %d, loss: %.5f, speed: %.2f step/s" 131 | % (global_step, epoch, loss_avg, 132 | args.logging_steps / time_diff)) 133 | else: 134 | logger.info( 135 | "global step %d, epoch: %d, loss: %.5f, speed: %.2f step/s" 136 | % (global_step, epoch, loss_avg, 137 | args.logging_steps / time_diff)) 138 | tic_train = time.time() 139 | 140 | if global_step % args.valid_steps == 0: 141 | save_dir = os.path.join( 142 | args.save_dir, "model_%d" % global_step) 143 | if not os.path.exists(save_dir): 144 | os.makedirs(save_dir) 145 | model_to_save = model 146 | model_to_save.save_pretrained(save_dir) 147 | tokenizer.save_pretrained(save_dir) 148 | if args.max_model_num: 149 | model_to_delete = global_step-args.max_model_num*args.valid_steps 150 | model_to_delete_path = os.path.join( 151 | args.save_dir, "model_%d" % model_to_delete) 152 | if model_to_delete > 0 and os.path.exists(model_to_delete_path): 153 | shutil.rmtree(model_to_delete_path) 154 | 155 | dev_loss_avg, precision, recall, f1 = evaluate( 156 | model, metric, data_loader=dev_data_loader, device=args.device, loss_fn=criterion) 157 | 158 | if show_bar: 159 | train_postfix_info.update({ 160 | 'F1': f'{f1:.3f}', 161 | 'dev loss': f'{dev_loss_avg:.5f}' 162 | }) 163 | train_data_iterator.set_postfix(train_postfix_info) 164 | with logging_redirect_tqdm([logger.logger]): 165 | logger.info("Evaluation precision: %.5f, recall: %.5f, F1: %.5f, dev loss: %.5f" 166 | % (precision, recall, f1, dev_loss_avg)) 167 | else: 168 | logger.info("Evaluation precision: %.5f, recall: %.5f, F1: %.5f, dev loss: %.5f" 169 | % (precision, recall, f1, dev_loss_avg)) 170 | # Save model which has best F1 171 | if f1 > best_f1: 172 | if show_bar: 173 | with logging_redirect_tqdm([logger.logger]): 174 | logger.info( 175 | f"best F1 performence has been updated: {best_f1:.5f} --> {f1:.5f}" 176 | ) 177 | else: 178 | logger.info( 179 | f"best F1 performence has been updated: {best_f1:.5f} --> {f1:.5f}" 180 | ) 181 | best_f1 = f1 182 | save_dir = os.path.join(args.save_dir, "model_best") 183 | model_to_save = model 184 | model_to_save.save_pretrained(save_dir) 185 | tokenizer.save_pretrained(save_dir) 186 | tic_train = time.time() 187 | 188 | if args.early_stopping: 189 | dev_loss_avg, precision, recall, f1 = evaluate( 190 | model, metric, data_loader=dev_data_loader, device=args.device, loss_fn=criterion) 191 | 192 | if show_bar: 193 | train_postfix_info.update({ 194 | 'F1': f'{f1:.3f}', 195 | 'dev loss': f'{dev_loss_avg:.5f}' 196 | }) 197 | train_data_iterator.set_postfix(train_postfix_info) 198 | with logging_redirect_tqdm([logger.logger]): 199 | logger.info("Evaluation precision: %.5f, recall: %.5f, F1: %.5f, dev loss: %.5f" 200 | % (precision, recall, f1, dev_loss_avg)) 201 | else: 202 | logger.info("Evaluation precision: %.5f, recall: %.5f, F1: %.5f, dev loss: %.5f" 203 | % (precision, recall, f1, dev_loss_avg)) 204 | 205 | # Early Stopping 206 | early_stopping(dev_loss_avg, model) 207 | if early_stopping.early_stop: 208 | if show_bar: 209 | with logging_redirect_tqdm([logger.logger]): 210 | logger.info("Early stopping") 211 | else: 212 | logger.info("Early stopping") 213 | tokenizer.save_pretrained(early_stopping_save_dir) 214 | sys.exit(0) 215 | 216 | 217 | if __name__ == "__main__": 218 | # yapf: disable 219 | parser = argparse.ArgumentParser() 220 | 221 | 
parser.add_argument("-b", "--batch_size", default=16, type=int, 222 | help="Batch size per GPU/CPU for training.") 223 | parser.add_argument("--learning_rate", default=1e-5, 224 | type=float, help="The initial learning rate for Adam.") 225 | parser.add_argument("-t", "--train_path", default=None, required=True, 226 | type=str, help="The path of train set.") 227 | parser.add_argument("-d", "--dev_path", default=None, required=True, 228 | type=str, help="The path of dev set.") 229 | parser.add_argument("-s", "--save_dir", default='./checkpoint', type=str, 230 | help="The output directory where the model checkpoints will be written.") 231 | parser.add_argument("--max_seq_len", default=512, type=int, help="The maximum input sequence length. " 232 | "Sequences longer than this will be split automatically.") 233 | parser.add_argument("--num_epochs", default=100, type=int, 234 | help="Total number of training epochs to perform.") 235 | parser.add_argument("--seed", default=1000, type=int, 236 | help="Random seed for initialization") 237 | parser.add_argument("--logging_steps", default=10, 238 | type=int, help="The interval steps to logging.") 239 | parser.add_argument("--valid_steps", default=100, type=int, 240 | help="The interval steps to evaluate model performance.") 241 | parser.add_argument("-D", '--device', choices=['cpu', 'gpu'], default="gpu", 242 | help="Select which device to train model, defaults to gpu.") 243 | parser.add_argument("-m", "--model", default="uie_base_pytorch", type=str, 244 | help="Select the pretrained model for few-shot learning.") 245 | parser.add_argument("--max_model_num", default=5, type=int, 246 | help="Max number of saved model. Best model and earlystopping model is not included.") 247 | parser.add_argument("--early_stopping", action='store_true', default=False, 248 | help="Use early stopping while training") 249 | 250 | args = parser.parse_args() 251 | # yapf: enable 252 | 253 | do_train() 254 | -------------------------------------------------------------------------------- /labelstudio2doccano.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import argparse 16 | import os 17 | import json 18 | 19 | 20 | def append_attrs(data, item, label_id, relation_id): 21 | 22 | mapp = {} 23 | 24 | for anno in data["annotations"][0]["result"]: 25 | if anno["type"] == "labels": 26 | label_id += 1 27 | item["entities"].append({ 28 | "id": label_id, 29 | "label": anno["value"]["labels"][0], 30 | "start_offset": anno["value"]["start"], 31 | "end_offset": anno["value"]["end"] 32 | }) 33 | mapp[anno["id"]] = label_id 34 | 35 | for anno in data["annotations"][0]["result"]: 36 | if anno["type"] == "relation": 37 | relation_id += 1 38 | item["relations"].append({ 39 | "id": relation_id, 40 | "from_id": mapp[anno["from_id"]], 41 | "to_id": mapp[anno["to_id"]], 42 | "type": anno["labels"][0] 43 | }) 44 | 45 | return item, label_id, relation_id 46 | 47 | 48 | def convert(dataset, task_type): 49 | results = [] 50 | outer_id = 0 51 | if task_type == "ext": 52 | label_id = 0 53 | relation_id = 0 54 | for data in dataset: 55 | outer_id += 1 56 | item = { 57 | "id": outer_id, 58 | "text": data["data"]["text"], 59 | "entities": [], 60 | "relations": [] 61 | } 62 | item, label_id, relation_id = append_attrs(data, item, label_id, 63 | relation_id) 64 | results.append(item) 65 | # for the classification task 66 | else: 67 | for data in dataset: 68 | outer_id += 1 69 | results.append({ 70 | "id": 71 | outer_id, 72 | "text": 73 | data["data"]["text"], 74 | "label": 75 | data["annotations"][0]["result"][0]["value"]["choices"] 76 | }) 77 | return results 78 | 79 | 80 | def do_convert(args): 81 | 82 | if not os.path.exists(args.labelstudio_file): 83 | raise ValueError("Please input the correct path of label studio file.") 84 | 85 | with open(args.labelstudio_file, "r", encoding="utf-8") as infile: 86 | for content in infile: 87 | dataset = json.loads(content) 88 | results = convert(dataset, args.task_type) 89 | 90 | with open(args.doccano_file, "w", encoding="utf-8") as outfile: 91 | for item in results: 92 | outline = json.dumps(item, ensure_ascii=False) 93 | outfile.write(outline + "\n") 94 | 95 | 96 | if __name__ == "__main__": 97 | 98 | parser = argparse.ArgumentParser() 99 | 100 | parser.add_argument( 101 | '--labelstudio_file', 102 | type=str, 103 | help= 104 | 'The export file path of label studio, only support the JSON format.') 105 | parser.add_argument('--doccano_file', 106 | type=str, 107 | default='doccano_ext.jsonl', 108 | help='Saving path in doccano format.') 109 | parser.add_argument( 110 | '--task_type', 111 | type=str, 112 | choices=['ext', 'cls'], 113 | default='ext', 114 | help= 115 | 'Select task type, ext for the extraction task and cls for the classification task, defaults to ext.' 116 | ) 117 | 118 | args = parser.parse_args() 119 | 120 | do_convert(args) 121 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | from dataclasses import dataclass 18 | from transformers import PretrainedConfig 19 | from transformers.utils import ModelOutput 20 | from typing import Optional, Tuple 21 | 22 | from ernie import ErnieModel, ErniePreTrainedModel 23 | from ernie_m import ErnieMModel, ErnieMPreTrainedModel 24 | 25 | 26 | @dataclass 27 | class UIEModelOutput(ModelOutput): 28 | """ 29 | Output class for outputs of UIE. 30 | Args: 31 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): 32 | Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. 33 | start_prob (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): 34 | Span-start scores (after Sigmoid). 35 | end_prob (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): 36 | Span-end scores (after Sigmoid). 37 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 38 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + 39 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 40 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 41 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 42 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 43 | sequence_length)`. 44 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 45 | heads. 46 | """ 47 | loss: Optional[torch.FloatTensor] = None 48 | start_prob: torch.FloatTensor = None 49 | end_prob: torch.FloatTensor = None 50 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 51 | attentions: Optional[Tuple[torch.FloatTensor]] = None 52 | 53 | 54 | class UIE(ErniePreTrainedModel): 55 | """ 56 | UIE model based on Bert model. 57 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 58 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 59 | etc.) 60 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 61 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 62 | and behavior. 63 | Parameters: 64 | config ([`PretrainedConfig`]): Model configuration class with all the parameters of the model. 65 | Initializing with a config file does not load the weights associated with the model, only the 66 | configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
67 | """ 68 | 69 | def __init__(self, config: PretrainedConfig): 70 | super(UIE, self).__init__(config) 71 | self.encoder = ErnieModel(config) 72 | self.config = config 73 | hidden_size = self.config.hidden_size 74 | 75 | self.linear_start = nn.Linear(hidden_size, 1) 76 | self.linear_end = nn.Linear(hidden_size, 1) 77 | self.sigmoid = nn.Sigmoid() 78 | 79 | # if hasattr(config, 'use_task_id') and config.use_task_id: 80 | # # Add task type embedding to BERT 81 | # task_type_embeddings = nn.Embedding( 82 | # config.task_type_vocab_size, config.hidden_size) 83 | # self.encoder.embeddings.task_type_embeddings = task_type_embeddings 84 | 85 | # def hook(module, input, output): 86 | # input = input[0] 87 | # return output+task_type_embeddings(torch.zeros(input.size(), dtype=torch.int64, device=input.device)) 88 | # self.encoder.embeddings.word_embeddings.register_forward_hook(hook) 89 | 90 | self.post_init() 91 | 92 | def forward(self, input_ids: Optional[torch.Tensor] = None, 93 | token_type_ids: Optional[torch.Tensor] = None, 94 | position_ids: Optional[torch.Tensor] = None, 95 | attention_mask: Optional[torch.Tensor] = None, 96 | head_mask: Optional[torch.Tensor] = None, 97 | inputs_embeds: Optional[torch.Tensor] = None, 98 | start_positions: Optional[torch.Tensor] = None, 99 | end_positions: Optional[torch.Tensor] = None, 100 | output_attentions: Optional[bool] = None, 101 | output_hidden_states: Optional[bool] = None, 102 | return_dict: Optional[bool] = None 103 | ): 104 | """ 105 | Args: 106 | input_ids (`torch.LongTensor` of shape `({0})`): 107 | Indices of input sequence tokens in the vocabulary. 108 | Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and 109 | [`PreTrainedTokenizer.__call__`] for details. 110 | [What are input IDs?](../glossary#input-ids) 111 | attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): 112 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 113 | - 1 for tokens that are **not masked**, 114 | - 0 for tokens that are **masked**. 115 | [What are attention masks?](../glossary#attention-mask) 116 | token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): 117 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 118 | 1]`: 119 | - 0 corresponds to a *sentence A* token, 120 | - 1 corresponds to a *sentence B* token. 121 | [What are token type IDs?](../glossary#token-type-ids) 122 | position_ids (`torch.LongTensor` of shape `({0})`, *optional*): 123 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 124 | config.max_position_embeddings - 1]`. 125 | [What are position IDs?](../glossary#position-ids) 126 | head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): 127 | Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: 128 | - 1 indicates the head is **not masked**, 129 | - 0 indicates the head is **masked**. 130 | inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): 131 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 132 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 133 | model's internal embedding lookup matrix. 
134 | start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 135 | Labels for position (index) of the start of the labelled span for computing the token classification loss. 136 | Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence 137 | are not taken into account for computing the loss. 138 | end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 139 | Labels for position (index) of the end of the labelled span for computing the token classification loss. 140 | Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence 141 | are not taken into account for computing the loss. 142 | output_attentions (`bool`, *optional*): 143 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 144 | tensors for more detail. 145 | output_hidden_states (`bool`, *optional*): 146 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 147 | more detail. 148 | return_dict (`bool`, *optional*): 149 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 150 | """ 151 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 152 | outputs = self.encoder( 153 | input_ids=input_ids, 154 | token_type_ids=token_type_ids, 155 | position_ids=position_ids, 156 | attention_mask=attention_mask, 157 | head_mask=head_mask, 158 | inputs_embeds=inputs_embeds, 159 | output_attentions=output_attentions, 160 | output_hidden_states=output_hidden_states, 161 | return_dict=return_dict 162 | ) 163 | sequence_output = outputs[0] 164 | 165 | start_logits = self.linear_start(sequence_output) 166 | start_logits = torch.squeeze(start_logits, -1) 167 | start_prob = self.sigmoid(start_logits) 168 | end_logits = self.linear_end(sequence_output) 169 | end_logits = torch.squeeze(end_logits, -1) 170 | end_prob = self.sigmoid(end_logits) 171 | 172 | total_loss = None 173 | if start_positions is not None and end_positions is not None: 174 | loss_fct = nn.BCELoss() 175 | start_loss = loss_fct(start_prob, start_positions) 176 | end_loss = loss_fct(end_prob, end_positions) 177 | total_loss = (start_loss + end_loss) / 2.0 178 | 179 | if not return_dict: 180 | output = (start_prob, end_prob) + outputs[2:] 181 | return ((total_loss,) + output) if total_loss is not None else output 182 | 183 | return UIEModelOutput( 184 | loss=total_loss, 185 | start_prob=start_prob, 186 | end_prob=end_prob, 187 | hidden_states=outputs.hidden_states, 188 | attentions=outputs.attentions, 189 | ) 190 | 191 | 192 | class UIEM(ErnieMPreTrainedModel): 193 | """ 194 | UIE model based on Bert model. 195 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 196 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 197 | etc.) 198 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 199 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 200 | and behavior. 201 | Parameters: 202 | config ([`PretrainedConfig`]): Model configuration class with all the parameters of the model. 203 | Initializing with a config file does not load the weights associated with the model, only the 204 | configuration. 
Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 205 | """ 206 | 207 | def __init__(self, config: PretrainedConfig): 208 | super(UIEM, self).__init__(config) 209 | self.encoder = ErnieMModel(config) 210 | self.config = config 211 | hidden_size = self.config.hidden_size 212 | 213 | self.linear_start = nn.Linear(hidden_size, 1) 214 | self.linear_end = nn.Linear(hidden_size, 1) 215 | self.sigmoid = nn.Sigmoid() 216 | 217 | self.post_init() 218 | 219 | def forward(self, input_ids: Optional[torch.Tensor] = None, 220 | position_ids: Optional[torch.Tensor] = None, 221 | attention_mask: Optional[torch.Tensor] = None, 222 | head_mask: Optional[torch.Tensor] = None, 223 | start_positions: Optional[torch.Tensor] = None, 224 | end_positions: Optional[torch.Tensor] = None, 225 | output_attentions: Optional[bool] = None, 226 | output_hidden_states: Optional[bool] = None, 227 | return_dict: Optional[bool] = None 228 | ): 229 | """ 230 | Args: 231 | input_ids (`torch.LongTensor` of shape `({0})`): 232 | Indices of input sequence tokens in the vocabulary. 233 | Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and 234 | [`PreTrainedTokenizer.__call__`] for details. 235 | [What are input IDs?](../glossary#input-ids) 236 | attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): 237 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 238 | - 1 for tokens that are **not masked**, 239 | - 0 for tokens that are **masked**. 240 | [What are attention masks?](../glossary#attention-mask) 241 | position_ids (`torch.LongTensor` of shape `({0})`, *optional*): 242 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 243 | config.max_position_embeddings - 1]`. 244 | [What are position IDs?](../glossary#position-ids) 245 | head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): 246 | Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: 247 | - 1 indicates the head is **not masked**, 248 | - 0 indicates the head is **masked**. 249 | inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): 250 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 251 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 252 | model's internal embedding lookup matrix. 253 | start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 254 | Labels for position (index) of the start of the labelled span for computing the token classification loss. 255 | Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence 256 | are not taken into account for computing the loss. 257 | end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 258 | Labels for position (index) of the end of the labelled span for computing the token classification loss. 259 | Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence 260 | are not taken into account for computing the loss. 261 | output_attentions (`bool`, *optional*): 262 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 263 | tensors for more detail. 
264 | output_hidden_states (`bool`, *optional*): 265 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 266 | more detail. 267 | return_dict (`bool`, *optional*): 268 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 269 | """ 270 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 271 | outputs = self.encoder( 272 | input_ids=input_ids, 273 | position_ids=position_ids, 274 | # attention_mask=attention_mask, 275 | # head_mask=head_mask, 276 | output_attentions=output_attentions, 277 | output_hidden_states=output_hidden_states, 278 | return_dict=return_dict 279 | ) 280 | sequence_output = outputs[0] 281 | 282 | start_logits = self.linear_start(sequence_output) 283 | start_logits = torch.squeeze(start_logits, -1) 284 | start_prob = self.sigmoid(start_logits) 285 | end_logits = self.linear_end(sequence_output) 286 | end_logits = torch.squeeze(end_logits, -1) 287 | end_prob = self.sigmoid(end_logits) 288 | 289 | total_loss = None 290 | if start_positions is not None and end_positions is not None: 291 | loss_fct = nn.BCELoss() 292 | start_loss = loss_fct(start_prob, start_positions) 293 | end_loss = loss_fct(end_prob, end_positions) 294 | total_loss = (start_loss + end_loss) / 2.0 295 | 296 | if not return_dict: 297 | output = (start_prob, end_prob) + outputs[2:] 298 | return ((total_loss,) + output) if total_loss is not None else output 299 | 300 | return UIEModelOutput( 301 | loss=total_loss, 302 | start_prob=start_prob, 303 | end_prob=end_prob, 304 | hidden_states=outputs.hidden_states, 305 | attentions=outputs.attentions, 306 | ) 307 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | numpy >=1.22 3 | six 4 | colorlog 5 | torch >=1.10,<2.0 # cpu版 6 | transformers >=4.18,<5.0 7 | packaging 8 | tqdm 9 | sentencepiece 10 | protobuf==3.19.0 11 | onnxruntime # cpu版 12 | 13 | # faster-tokenizer==0.2.0 # faster-tokenizer 已经废弃 14 | fast-tokenizer-python==1.0.0 15 | 16 | # paddlepaddle # paddlepaddle 是可选的 -------------------------------------------------------------------------------- /tokenizer.py: -------------------------------------------------------------------------------- 1 | 2 | import collections 3 | import json 4 | import os 5 | import unicodedata 6 | from shutil import copyfile 7 | from typing import Any, Dict, List, Optional, Tuple 8 | 9 | import sentencepiece as spm 10 | from transformers import SLOW_TO_FAST_CONVERTERS, PreTrainedTokenizerFast, requires_backends 11 | from transformers.convert_slow_tokenizer import Converter, SentencePieceExtractor 12 | 13 | try: 14 | from fast_tokenizer import Tokenizer, normalizers, pretokenizers, postprocessors 15 | from fast_tokenizer.models import BPE, Unigram 16 | except ImportError as e: 17 | print('fast_tokenizer 未安装! 
pip install fast-tokenizer-python==1.0.0') 18 | print(e) 19 | from faster_tokenizer import Tokenizer, normalizers, pretokenizers, postprocessors 20 | from faster_tokenizer.models import BPE, Unigram 21 | 22 | from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer 23 | from transformers.utils import SPIECE_UNDERLINE 24 | 25 | from utils import logger 26 | 27 | 28 | VOCAB_FILES_NAMES = { 29 | "sentencepiece_model_file": "sentencepiece.bpe.model", 30 | "vocab_file": "vocab.txt", 31 | } 32 | 33 | PRETRAINED_VOCAB_FILES_MAP = { 34 | "vocab_file": { 35 | "ernie-m-base": "https://bj.bcebos.com/paddlenlp/models/transformers/ernie_m/ernie_m.vocab.txt", 36 | "ernie-m-large": "https://bj.bcebos.com/paddlenlp/models/transformers/ernie_m/ernie_m.vocab.txt" 37 | }, 38 | "sentencepiece_model_file": { 39 | "ernie-m-base": "https://bj.bcebos.com/paddlenlp/models/transformers/ernie_m/ernie_m.sentencepiece.bpe.model", 40 | "ernie-m-large": "https://bj.bcebos.com/paddlenlp/models/transformers/ernie_m/ernie_m.sentencepiece.bpe.model", 41 | } 42 | } 43 | 44 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 45 | "ernie-m-base": 514, 46 | "ernie-m-large": 514, 47 | } 48 | 49 | 50 | def load_vocab(vocab_file): 51 | """Loads a vocabulary file into a dictionary.""" 52 | vocab = collections.OrderedDict() 53 | with open(vocab_file, "r", encoding="utf-8") as reader: 54 | tokens = reader.readlines() 55 | for index, token in enumerate(tokens): 56 | token = token.rstrip("\n") 57 | if token in vocab: 58 | print(f'{token} 重复!') 59 | vocab[token] = index 60 | return vocab 61 | 62 | 63 | class ErnieMTokenizer(PreTrainedTokenizer): 64 | """ 65 | Construct an Erine-M tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece). 66 | 67 | This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to 68 | this superclass for more information regarding those methods. 69 | 70 | Args: 71 | vocab_file (`str`): 72 | File containing the vocabulary. 73 | sentencepiece_model_file (`str`): 74 | [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .spm extension) that 75 | contains the vocabulary necessary to instantiate a tokenizer. 76 | do_lower_case (`bool`, *optional*, defaults to `True`): 77 | Whether to lowercase the input when tokenizing. 78 | unk_token (`str`, *optional*, defaults to `""`): 79 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this 80 | token instead. 81 | sep_token (`str`, *optional*, defaults to `""`): 82 | The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for 83 | sequence classification or for a text and a question for question answering. It is also used as the last 84 | token of a sequence built with special tokens. 85 | pad_token (`str`, *optional*, defaults to `""`): 86 | The token used for padding, for example when batching sequences of different lengths. 87 | cls_token (`str`, *optional*, defaults to `""`): 88 | The classifier token which is used when doing sequence classification (classification of the whole sequence 89 | instead of per-token classification). It is the first token of the sequence when built with special tokens. 90 | mask_token (`str`, *optional*, defaults to `""`): 91 | The token used for masking values. This is the token used when training this model with masked language 92 | modeling. This is the token which the model will try to predict. 
93 | sp_model_kwargs (`dict`, *optional*): 94 | Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for 95 | SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, 96 | to set: 97 | 98 | - `enable_sampling`: Enable subword regularization. 99 | - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. 100 | 101 | - `nbest_size = {0,1}`: No sampling is performed. 102 | - `nbest_size > 1`: samples from the nbest_size results. 103 | - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) 104 | using forward-filtering-and-backward-sampling algorithm. 105 | 106 | - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for 107 | BPE-dropout. 108 | 109 | Attributes: 110 | sp_model (`SentencePieceProcessor`): 111 | The *SentencePiece* processor that is used for every conversion (string, tokens and IDs). 112 | """ 113 | 114 | vocab_files_names = VOCAB_FILES_NAMES 115 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 116 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 117 | padding_side = "left" 118 | 119 | def __init__( 120 | self, 121 | vocab_file, 122 | sentencepiece_model_file, 123 | do_lower_case=False, 124 | unk_token="[UNK]", 125 | sep_token="[SEP]", 126 | pad_token="[PAD]", 127 | cls_token="[CLS]", 128 | mask_token="[MASK]", 129 | sp_model_kwargs: Optional[Dict[str, Any]] = None, 130 | **kwargs 131 | ) -> None: 132 | 133 | # Mask token behave like a normal word, i.e. include the space before it 134 | mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance( 135 | mask_token, str) else mask_token 136 | 137 | self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs 138 | 139 | super().__init__( 140 | do_lower_case=do_lower_case, 141 | unk_token=unk_token, 142 | sep_token=sep_token, 143 | pad_token=pad_token, 144 | cls_token=cls_token, 145 | mask_token=mask_token, 146 | sp_model_kwargs=self.sp_model_kwargs, 147 | **kwargs, 148 | ) 149 | 150 | self.do_lower_case = do_lower_case 151 | self.sentencepiece_model_file = sentencepiece_model_file 152 | if not os.path.isfile(sentencepiece_model_file): 153 | raise ValueError( 154 | f"Can't find a vocabulary file at path '{sentencepiece_model_file}'. To load the vocabulary from a Google pretrained " 155 | "model use `tokenizer = ErnieMTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" 156 | ) 157 | 158 | self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) 159 | self.sp_model.Load(sentencepiece_model_file) 160 | 161 | if not os.path.isfile(vocab_file): 162 | raise ValueError( 163 | f"Can't find a vocabulary file at path '{vocab_file}'. 
To load the vocabulary from a Google pretrained " 164 | "model use `tokenizer = ErnieMTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" 165 | ) 166 | self.vocab = load_vocab(vocab_file) 167 | self.ids_to_tokens = collections.OrderedDict( 168 | [(ids, tok) for tok, ids in self.vocab.items()]) 169 | 170 | self.SP_CHAR_MAPPING = {} 171 | 172 | for ch in range(65281, 65375): 173 | if ch in [ord(u'~')]: 174 | self.SP_CHAR_MAPPING[chr(ch)] = chr(ch) 175 | continue 176 | self.SP_CHAR_MAPPING[chr(ch)] = chr(ch - 65248) 177 | 178 | @property 179 | def vocab_size(self): 180 | return len(self.sp_model) 181 | 182 | def get_vocab(self): 183 | return dict(self.vocab, **self.added_tokens_encoder) 184 | 185 | def __getstate__(self): 186 | state = self.__dict__.copy() 187 | state["sp_model"] = None 188 | return state 189 | 190 | def __setstate__(self, d): 191 | self.__dict__ = d 192 | 193 | # for backward compatibility 194 | if not hasattr(self, "sp_model_kwargs"): 195 | self.sp_model_kwargs = {} 196 | 197 | self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) 198 | self.sp_model.Load(self.vocab_file) 199 | 200 | def preprocess_text(self, inputs): 201 | outputs = ''.join((self.SP_CHAR_MAPPING.get(c, c) for c in inputs)) 202 | outputs = outputs.replace("``", '"').replace("''", '"') 203 | 204 | outputs = unicodedata.normalize("NFKD", outputs) 205 | outputs = "".join( 206 | [c for c in outputs if not unicodedata.combining(c)]) 207 | if self.do_lower_case: 208 | outputs = outputs.lower() 209 | 210 | return outputs 211 | 212 | def _tokenize(self, text: str) -> List[str]: 213 | """Tokenize a string.""" 214 | text = self.preprocess_text(text) 215 | 216 | pieces = self.sp_model.EncodeAsPieces(text) 217 | 218 | new_pieces = [] 219 | for piece in pieces: 220 | if piece == SPIECE_UNDERLINE: 221 | continue 222 | lst_i = 0 223 | for i, c in enumerate(piece): 224 | if c == SPIECE_UNDERLINE: 225 | continue 226 | if self.is_ch_char(c) or self.is_punct(c): 227 | if i > lst_i and piece[lst_i:i] != SPIECE_UNDERLINE: 228 | new_pieces.append(piece[lst_i:i]) 229 | new_pieces.append(c) 230 | lst_i = i + 1 231 | elif c.isdigit() and i > 0 and not piece[i - 1].isdigit(): 232 | if i > lst_i and piece[lst_i:i] != SPIECE_UNDERLINE: 233 | new_pieces.append(piece[lst_i:i]) 234 | lst_i = i 235 | elif not c.isdigit() and i > 0 and piece[i - 1].isdigit(): 236 | if i > lst_i and piece[lst_i:i] != SPIECE_UNDERLINE: 237 | new_pieces.append(piece[lst_i:i]) 238 | lst_i = i 239 | if len(piece) > lst_i: 240 | new_pieces.append(piece[lst_i:]) 241 | 242 | return new_pieces 243 | 244 | def _convert_token_to_id(self, token): 245 | """Converts a token (str) in an id using the vocab.""" 246 | 247 | return self.vocab.get(token, self.vocab.get(self.unk_token)) 248 | 249 | def _convert_id_to_token(self, index): 250 | """Converts an index (integer) in a token (str) using the vocab.""" 251 | return self.ids_to_tokens.get(index, self.unk_token) 252 | 253 | def convert_tokens_to_string(self, tokens): 254 | """Converts a sequence of tokens (strings for sub-words) in a single string.""" 255 | out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip() 256 | return out_string 257 | 258 | def build_inputs_with_special_tokens( 259 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 260 | ) -> List[int]: 261 | """ 262 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and 263 | adding special tokens. 
An Erine-M sequence has the following format: 264 | 265 | - single sequence: `X ` 266 | - pair of sequences: `A B ` 267 | 268 | Args: 269 | token_ids_0 (`List[int]`): 270 | List of IDs to which the special tokens will be added. 271 | token_ids_1 (`List[int]`, *optional*): 272 | Optional second list of IDs for sequence pairs. 273 | 274 | Returns: 275 | `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. 276 | """ 277 | sep = [self.sep_token_id] 278 | cls = [self.cls_token_id] 279 | 280 | if token_ids_1 is None: 281 | return cls + token_ids_0 + sep 282 | 283 | return cls + token_ids_0 + sep + sep + token_ids_1 + sep 284 | 285 | def get_special_tokens_mask( 286 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False 287 | ) -> List[int]: 288 | """ 289 | Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding 290 | special tokens using the tokenizer `prepare_for_model` method. 291 | 292 | Args: 293 | token_ids_0 (`List[int]`): 294 | List of IDs. 295 | token_ids_1 (`List[int]`, *optional*): 296 | Optional second list of IDs for sequence pairs. 297 | already_has_special_tokens (`bool`, *optional*, defaults to `False`): 298 | Whether or not the token list is already formatted with special tokens for the model. 299 | 300 | Returns: 301 | `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 302 | """ 303 | 304 | if already_has_special_tokens: 305 | return super().get_special_tokens_mask( 306 | token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True 307 | ) 308 | 309 | if token_ids_1 is not None: 310 | return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] 311 | return [1] + ([0] * len(token_ids_0)) + [1] 312 | 313 | def create_token_type_ids_from_sequences( 314 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 315 | ) -> List[int]: 316 | """ 317 | Create a mask from the two sequences passed to be used in a sequence-pair classification task. An Erine-M 318 | sequence pair mask has the following format: 319 | 320 | ``` 321 | 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 322 | | first sequence | second sequence | 323 | ``` 324 | 325 | If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). 326 | 327 | Args: 328 | token_ids_0 (`List[int]`): 329 | List of IDs. 330 | token_ids_1 (`List[int]`, *optional*): 331 | Optional second list of IDs for sequence pairs. 332 | 333 | Returns: 334 | `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). 
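        Example (editor's illustration, not part of the original docstring; the token
        IDs are made up and `tokenizer` is assumed to be an `ErnieMTokenizer` instance):

            tokenizer.create_token_type_ids_from_sequences([10, 11, 12])
            # -> [0, 0, 0, 0, 0]               (len(token_ids_0) + 2 zeros)
            tokenizer.create_token_type_ids_from_sequences([10, 11, 12], [20, 21])
            # -> [0, 0, 0, 0, 1, 1, 1, 1, 1]   (len(token_ids_0) + 1 zeros, len(token_ids_1) + 3 ones)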
335 | """ 336 | 337 | if token_ids_1 is None: 338 | # [CLS] X [SEP] 339 | return (len(token_ids_0) + 2) * [0] 340 | 341 | # [CLS] A [SEP] [SEP] B [SEP] 342 | return [0] * (len(token_ids_0) + 1) + [1] * (len(token_ids_1) + 3) 343 | 344 | def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: 345 | if not os.path.isdir(save_directory): 346 | logger.error( 347 | f"Vocabulary path ({save_directory}) should be a directory") 348 | return 349 | sentencepiece_model_file = os.path.join( 350 | save_directory, (filename_prefix + "-" if filename_prefix else "") + 351 | VOCAB_FILES_NAMES["sentencepiece_model_file"] 352 | ) 353 | vocab_file = (filename_prefix + 354 | "-" if filename_prefix else "") + save_directory 355 | 356 | if os.path.abspath(self.vocab_file) != os.path.abspath(sentencepiece_model_file) and os.path.isfile(self.vocab_file): 357 | copyfile(self.vocab_file, sentencepiece_model_file) 358 | elif not os.path.isfile(self.vocab_file): 359 | with open(sentencepiece_model_file, "wb") as fi: 360 | content_spiece_model = self.sp_model.serialized_model_proto() 361 | fi.write(content_spiece_model) 362 | 363 | index = 0 364 | with open(vocab_file, "w", encoding="utf-8") as writer: 365 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): 366 | if index != token_index: 367 | logger.warning( 368 | f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive." 369 | " Please check that the vocabulary is not corrupted!" 370 | ) 371 | index = token_index 372 | writer.write(token + "\n") 373 | index += 1 374 | return vocab_file, sentencepiece_model_file 375 | 376 | def is_ch_char(self, char): 377 | """ 378 | is_ch_char 379 | """ 380 | if u'\u4e00' <= char <= u'\u9fff': 381 | return True 382 | return False 383 | 384 | def is_alpha(self, char): 385 | """ 386 | is_alpha 387 | """ 388 | if 'a' <= char <= 'z': 389 | return True 390 | if 'A' <= char <= 'Z': 391 | return True 392 | return False 393 | 394 | def is_punct(self, char): 395 | """ 396 | is_punct 397 | """ 398 | if char in u",;:.?!~,;:。?!《》【】": 399 | return True 400 | return False 401 | 402 | def is_whitespace(self, char): 403 | """ 404 | is whitespace 405 | """ 406 | if char == " " or char == "\t" or char == "\n" or char == "\r": 407 | return True 408 | if len(char) == 1: 409 | cat = unicodedata.category(char) 410 | if cat == "Zs": 411 | return True 412 | return False 413 | 414 | 415 | class ErnieMTokenizerFast(PreTrainedTokenizerFast): 416 | r""" 417 | Construct a "fast" ERNIE-M tokenizer (backed by HuggingFace's *tokenizers* library). Based on WordPiece. 418 | This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should 419 | refer to this superclass for more information regarding those methods. 420 | Args: 421 | vocab_file (`str`): 422 | File containing the vocabulary. 423 | sentencepiece_model_file (`str`): 424 | [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .spm extension) that 425 | contains the vocabulary necessary to instantiate a tokenizer. 426 | do_lower_case (`bool`, *optional*, defaults to `True`): 427 | Whether or not to lowercase the input when tokenizing. 428 | unk_token (`str`, *optional*, defaults to `"[UNK]"`): 429 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this 430 | token instead. 
431 | sep_token (`str`, *optional*, defaults to `"[SEP]"`): 432 | The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for 433 | sequence classification or for a text and a question for question answering. It is also used as the last 434 | token of a sequence built with special tokens. 435 | pad_token (`str`, *optional*, defaults to `"[PAD]"`): 436 | The token used for padding, for example when batching sequences of different lengths. 437 | cls_token (`str`, *optional*, defaults to `"[CLS]"`): 438 | The classifier token which is used when doing sequence classification (classification of the whole sequence 439 | instead of per-token classification). It is the first token of the sequence when built with special tokens. 440 | mask_token (`str`, *optional*, defaults to `"[MASK]"`): 441 | The token used for masking values. This is the token used when training this model with masked language 442 | modeling. This is the token which the model will try to predict. 443 | clean_text (`bool`, *optional*, defaults to `True`): 444 | Whether or not to clean the text before tokenization by removing any control characters and replacing all 445 | whitespaces by the classic one. 446 | tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): 447 | Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see [this 448 | issue](https://github.com/huggingface/transformers/issues/328)). 449 | strip_accents (`bool`, *optional*): 450 | Whether or not to strip all accents. If this option is not specified, then it will be determined by the 451 | value for `lowercase` (as in the original ERNIE-M). 452 | wordpieces_prefix (`str`, *optional*, defaults to `"##"`): 453 | The prefix for subwords. 
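    Example (editor's sketch, not part of the original docstring; `uie_m_base_pytorch`
    is the directory name that uie_predictor.py derives for the `uie-m-base`
    checkpoint -- substitute any directory holding the converted tokenizer files.
    The call mirrors the one made in `UIEPredictor._single_stage_predict`):

        tokenizer = ErnieMTokenizerFast.from_pretrained("uie_m_base_pytorch")
        encoded = tokenizer(text="时间",                      # prompt
                            text_pair="2022年5月7日上午11时",  # raw text
                            return_offsets_mapping=True)
        encoded["input_ids"]       # [CLS] prompt [SEP] [SEP] text [SEP]
        encoded["offset_mapping"]  # character offsets used to map predicted spans back to the text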
454 | """ 455 | 456 | vocab_files_names = VOCAB_FILES_NAMES 457 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 458 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 459 | slow_tokenizer_class = ErnieMTokenizer 460 | 461 | def __init__( 462 | self, 463 | vocab_file=None, 464 | sentencepiece_model_file=None, 465 | tokenizer_file=None, 466 | do_lower_case=True, 467 | unk_token="[UNK]", 468 | sep_token="[SEP]", 469 | pad_token="[PAD]", 470 | cls_token="[CLS]", 471 | mask_token="[MASK]", 472 | tokenize_chinese_chars=True, 473 | strip_accents=None, 474 | **kwargs 475 | ): 476 | super().__init__( 477 | vocab_file, 478 | sentencepiece_model_file, 479 | tokenizer_file=tokenizer_file, 480 | do_lower_case=do_lower_case, 481 | unk_token=unk_token, 482 | sep_token=sep_token, 483 | pad_token=pad_token, 484 | cls_token=cls_token, 485 | mask_token=mask_token, 486 | tokenize_chinese_chars=tokenize_chinese_chars, 487 | strip_accents=strip_accents, 488 | **kwargs, 489 | ) 490 | 491 | normalizer_state = json.loads( 492 | self.backend_tokenizer.normalizer.__getstate__()) 493 | if ( 494 | normalizer_state.get("lowercase", do_lower_case) != do_lower_case 495 | or normalizer_state.get("strip_accents", strip_accents) != strip_accents 496 | or normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars 497 | ): 498 | normalizer_class = getattr( 499 | normalizers, normalizer_state.pop("type")) 500 | normalizer_state["lowercase"] = do_lower_case 501 | normalizer_state["strip_accents"] = strip_accents 502 | normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars 503 | self.backend_tokenizer.normalizer = normalizer_class( 504 | **normalizer_state) 505 | 506 | self.do_lower_case = do_lower_case 507 | 508 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): 509 | """ 510 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and 511 | adding special tokens. A ERNIE-M sequence has the following format: 512 | - single sequence: `[CLS] X [SEP]` 513 | - pair of sequences: `[CLS] A [SEP] B [SEP]` 514 | Args: 515 | token_ids_0 (`List[int]`): 516 | List of IDs to which the special tokens will be added. 517 | token_ids_1 (`List[int]`, *optional*): 518 | Optional second list of IDs for sequence pairs. 519 | Returns: 520 | `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. 521 | """ 522 | output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id] 523 | 524 | if token_ids_1: 525 | output += [self.sep_token_id] + token_ids_1 + [self.sep_token_id] 526 | 527 | return output 528 | 529 | def create_token_type_ids_from_sequences( 530 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 531 | ) -> List[int]: 532 | """ 533 | Create a mask from the two sequences passed to be used in a sequence-pair classification task. A ERNIE-M sequence 534 | pair mask has the following format: 535 | ``` 536 | 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 537 | | first sequence | second sequence | 538 | ``` 539 | If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). 540 | Args: 541 | token_ids_0 (`List[int]`): 542 | List of IDs. 543 | token_ids_1 (`List[int]`, *optional*): 544 | Optional second list of IDs for sequence pairs. 545 | Returns: 546 | `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). 
547 | """ 548 | 549 | if token_ids_1 is None: 550 | return (len(token_ids_0) + 2) * [0] 551 | return [0] * (len(token_ids_0) + 1) + [1] * (len(token_ids_1) + 3) 552 | 553 | def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: 554 | files = self._tokenizer.model.save( 555 | save_directory, name=filename_prefix) 556 | return tuple(files) 557 | 558 | 559 | class TokenizerProxy: 560 | def __init__(self, tokenizer): 561 | self._tokenizer = tokenizer 562 | self.no_padding = self._tokenizer.disable_padding 563 | self.no_truncation = self._tokenizer.disable_truncation 564 | 565 | def __getattr__(self, __name: str) -> Any: 566 | return getattr(self._tokenizer, __name) 567 | 568 | 569 | class ErnieMConverter(Converter): 570 | def __init__(self, *args): 571 | requires_backends(self, "protobuf") 572 | 573 | super().__init__(*args) 574 | 575 | from transformers.utils import sentencepiece_model_pb2 as model_pb2 576 | 577 | m = model_pb2.ModelProto() 578 | with open(self.original_tokenizer.sentencepiece_model_file, "rb") as f: 579 | m.ParseFromString(f.read()) 580 | self.proto = m 581 | 582 | def vocab(self, proto): 583 | word_score_dict = {} 584 | for piece in proto.pieces: 585 | word_score_dict[piece.piece] = piece.score 586 | vocab_list = [None] * len(self.original_tokenizer.ids_to_tokens) 587 | original_vocab = self.original_tokenizer.vocab 588 | for _token, _id in original_vocab.items(): 589 | if _token in word_score_dict: 590 | vocab_list[_id] = (_token, word_score_dict[_token]) 591 | else: 592 | vocab_list[_id] = (_token, 0.0) 593 | return vocab_list 594 | 595 | def post_processor(self): 596 | ''' 597 | An ERNIE-M sequence has the following format: 598 | - single sequence: ``[CLS] X [SEP]`` 599 | - pair of sequences: ``[CLS] A [SEP] [SEP] B [SEP]`` 600 | ''' 601 | return postprocessors.TemplatePostProcessor( 602 | single="[CLS]:0 $A:0 [SEP]:0", 603 | pair="[CLS]:0 $A:0 [SEP]:0 [SEP]:1 $B:1 [SEP]:1", 604 | special_tokens=[ 605 | ("[CLS]", 606 | self.original_tokenizer.convert_tokens_to_ids("[CLS]")), 607 | ("[SEP]", 608 | self.original_tokenizer.convert_tokens_to_ids("[SEP]")), 609 | ], 610 | ) 611 | 612 | def normalizer(self, proto): 613 | list_normalizers = [] 614 | precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap 615 | list_normalizers.append( 616 | normalizers.PrecompiledNormalizer(precompiled_charsmap)) 617 | return normalizers.SequenceNormalizer(list_normalizers) 618 | 619 | def unk_id(self, proto): 620 | return self.original_tokenizer.convert_tokens_to_ids( 621 | str(self.original_tokenizer.unk_token)) 622 | 623 | def pre_tokenizer(self, replacement, add_prefix_space): 624 | return pretokenizers.SequencePreTokenizer([ 625 | pretokenizers.WhitespacePreTokenizer(), 626 | pretokenizers.MetaSpacePreTokenizer( 627 | replacement=replacement, add_prefix_space=add_prefix_space) 628 | ]) 629 | 630 | def converted(self) -> Tokenizer: 631 | tokenizer = self.tokenizer(self.proto) 632 | 633 | SPLICE_UNDERLINE = SPIECE_UNDERLINE 634 | tokenizer.model.set_filter_token(SPLICE_UNDERLINE) 635 | chinese_chars = r"\x{4e00}-\x{9fff}" 636 | punc_chars = r",;:.?!~,;:。?!《》【】" 637 | digits = r"0-9" 638 | tokenizer.model.set_split_rule( 639 | fr"[{chinese_chars}]|[{punc_chars}]|[{digits}]+|[^{chinese_chars}{punc_chars}{digits}]+" 640 | ) 641 | 642 | # Tokenizer assemble 643 | tokenizer.normalizer = self.normalizer(self.proto) 644 | 645 | replacement = "▁" 646 | add_prefix_space = True 647 | tokenizer.pretokenizer = self.pre_tokenizer( 648 | replacement, 
add_prefix_space) 649 | 650 | post_processor = self.post_processor() 651 | if post_processor: 652 | tokenizer.postprocessor = post_processor 653 | 654 | tokenizer = TokenizerProxy(tokenizer) 655 | return tokenizer 656 | 657 | def tokenizer(self, proto): 658 | model_type = proto.trainer_spec.model_type 659 | vocab = self.vocab(proto) 660 | unk_id = self.unk_id(proto) 661 | 662 | if model_type == 1: 663 | tokenizer = Tokenizer(Unigram(vocab, unk_id)) 664 | elif model_type == 2: 665 | _, merges = SentencePieceExtractor( 666 | self.original_tokenizer.sentencepiece_model_file).extract() 667 | bpe_vocab = {word: i for i, (word, score) in enumerate(vocab)} 668 | tokenizer = Tokenizer( 669 | BPE( 670 | bpe_vocab, 671 | merges, 672 | unk_token=proto.trainer_spec.unk_piece, 673 | fuse_unk=True, 674 | ) 675 | ) 676 | else: 677 | raise Exception( 678 | "You're trying to run a `Unigram` model but you're file was trained with a different algorithm" 679 | ) 680 | 681 | return tokenizer 682 | 683 | 684 | SLOW_TO_FAST_CONVERTERS["ErnieMTokenizer"] = ErnieMConverter 685 | -------------------------------------------------------------------------------- /uie_predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
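# NOTE (editorial addition, not part of the original file): the
# `SLOW_TO_FAST_CONVERTERS["ErnieMTokenizer"] = ErnieMConverter` registration at the
# end of tokenizer.py above is what lets `ErnieMTokenizerFast.from_pretrained(...)`
# build its fast backend from the slow tokenizer's sentencepiece model when no
# prebuilt `tokenizer_file` is available; the predictor defined in this file relies
# on that conversion for the multilingual `uie-m-*` checkpoints.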
14 | 15 | import re 16 | import numpy as np 17 | import six 18 | 19 | 20 | import math 21 | import argparse 22 | import os.path 23 | 24 | from utils import logger, get_bool_ids_greater_than, get_span, get_id_and_prob, cut_chinese_sent, dbc2sbc 25 | 26 | 27 | class ONNXInferBackend(object): 28 | def __init__(self, 29 | model_path_prefix, 30 | device='cpu', 31 | use_fp16=False): 32 | from onnxruntime import InferenceSession, SessionOptions 33 | logger.info(">>> [ONNXInferBackend] Creating Engine ...") 34 | onnx_model = float_onnx_file = os.path.join( 35 | model_path_prefix, "inference.onnx") 36 | if not os.path.exists(onnx_model): 37 | raise OSError(f'{onnx_model} not exists!') 38 | infer_model_dir = model_path_prefix 39 | 40 | if device == "gpu": 41 | providers = ['CUDAExecutionProvider'] 42 | logger.info(">>> [ONNXInferBackend] Use GPU to inference ...") 43 | if use_fp16: 44 | logger.info(">>> [ONNXInferBackend] Use FP16 to inference ...") 45 | from onnxconverter_common import float16 46 | import onnx 47 | fp16_model_file = os.path.join(infer_model_dir, 48 | "fp16_model.onnx") 49 | onnx_model = onnx.load_model(float_onnx_file) 50 | trans_model = float16.convert_float_to_float16( 51 | onnx_model, keep_io_types=True) 52 | onnx.save_model(trans_model, fp16_model_file) 53 | onnx_model = fp16_model_file 54 | else: 55 | providers = ['CPUExecutionProvider'] 56 | logger.info(">>> [ONNXInferBackend] Use CPU to inference ...") 57 | 58 | sess_options = SessionOptions() 59 | self.predictor = InferenceSession( 60 | onnx_model, sess_options=sess_options, providers=providers) 61 | if device == "gpu": 62 | try: 63 | assert 'CUDAExecutionProvider' in self.predictor.get_providers() 64 | except AssertionError: 65 | raise AssertionError( 66 | f"The environment for GPU inference is not set properly. " 67 | "A possible cause is that you had installed both onnxruntime and onnxruntime-gpu. 
" 68 | "Please run the following commands to reinstall: \n " 69 | "1) pip uninstall -y onnxruntime onnxruntime-gpu \n 2) pip install onnxruntime-gpu" 70 | ) 71 | logger.info(">>> [InferBackend] Engine Created ...") 72 | 73 | def infer(self, input_dict: dict): 74 | result = self.predictor.run(None, dict(input_dict)) 75 | return result 76 | 77 | 78 | class PyTorchInferBackend: 79 | def __init__(self, 80 | model_path_prefix, 81 | multilingual=False, 82 | device='cpu', 83 | use_fp16=False): 84 | from model import UIE, UIEM 85 | logger.info(">>> [PyTorchInferBackend] Creating Engine ...") 86 | if multilingual: 87 | self.model = UIEM.from_pretrained(model_path_prefix) 88 | else: 89 | self.model = UIE.from_pretrained(model_path_prefix) 90 | self.model.eval() 91 | self.device = device 92 | if self.device == 'gpu': 93 | logger.info(">>> [PyTorchInferBackend] Use GPU to inference ...") 94 | if use_fp16: 95 | logger.info( 96 | ">>> [PyTorchInferBackend] Use FP16 to inference ...") 97 | self.model = self.model.half() 98 | self.model = self.model.cuda() 99 | else: 100 | logger.info(">>> [PyTorchInferBackend] Use CPU to inference ...") 101 | logger.info(">>> [PyTorchInferBackend] Engine Created ...") 102 | 103 | def infer(self, input_dict): 104 | import torch 105 | for input_name, input_value in input_dict.items(): 106 | input_value = torch.LongTensor(input_value) 107 | if self.device == 'gpu': 108 | input_value = input_value.cuda() 109 | input_dict[input_name] = input_value 110 | 111 | outputs = self.model(**input_dict) 112 | start_prob, end_prob = outputs[0], outputs[1] 113 | if self.device == 'gpu': 114 | start_prob, end_prob = start_prob.cpu(), end_prob.cpu() 115 | start_prob = start_prob.detach().numpy() 116 | end_prob = end_prob.detach().numpy() 117 | return start_prob, end_prob 118 | 119 | 120 | class UIEPredictor(object): 121 | 122 | def __init__(self, model, schema, task_path=None, schema_lang="zh", engine='pytorch', device='cpu', position_prob=0.5, max_seq_len=512, batch_size=64, split_sentence=False, use_fp16=False): 123 | 124 | assert isinstance( 125 | device, six.string_types 126 | ), "The type of device must be string." 127 | assert device in [ 128 | 'cpu', 'gpu'], "The device must be cpu or gpu." 129 | if model in ['uie-m-base', 'uie-m-large']: 130 | self._multilingual = True 131 | else: 132 | self._multilingual = False 133 | self._model = model 134 | self._engine = engine 135 | self._task_path = task_path 136 | self._device = device 137 | self._position_prob = position_prob 138 | self._max_seq_len = max_seq_len 139 | self._batch_size = batch_size 140 | self._split_sentence = split_sentence 141 | self._use_fp16 = use_fp16 142 | 143 | self._schema_tree = None 144 | self._is_en = True if model in ['uie-base-en' 145 | ] or schema_lang == 'en' else False 146 | self.set_schema(schema) 147 | self._prepare_predictor() 148 | 149 | def _prepare_predictor(self): 150 | assert self._engine in ['pytorch', 151 | 'onnx'], "engine must be pytorch or onnx!" 
152 | 153 | if self._task_path is None: 154 | self._task_path = self._model.replace('-', '_')+'_pytorch' 155 | if not os.path.exists(self._task_path): 156 | from convert import check_model, extract_and_convert 157 | check_model(self._model) 158 | extract_and_convert(self._model, self._task_path) 159 | 160 | if self._multilingual: 161 | from tokenizer import ErnieMTokenizerFast 162 | self._tokenizer = ErnieMTokenizerFast.from_pretrained( 163 | self._task_path) 164 | else: 165 | from transformers import BertTokenizerFast 166 | self._tokenizer = BertTokenizerFast.from_pretrained( 167 | self._task_path) 168 | 169 | if self._engine == 'pytorch': 170 | self.inference_backend = PyTorchInferBackend( 171 | self._task_path, multilingual=self._multilingual, device=self._device, use_fp16=self._use_fp16) 172 | 173 | if self._engine == 'onnx': 174 | if os.path.exists(os.path.join(self._task_path, "pytorch_model.bin")) and not os.path.exists(os.path.join(self._task_path, "inference.onnx")): 175 | from export_model import export_onnx 176 | from model import UIE, UIEM 177 | if self._multilingual: 178 | model = UIEM.from_pretrained(self._task_path) 179 | else: 180 | model = UIE.from_pretrained(self._task_path) 181 | input_names = [ 182 | 'input_ids', 183 | 'token_type_ids', 184 | 'attention_mask', 185 | ] 186 | output_names = [ 187 | 'start_prob', 188 | 'end_prob' 189 | ] 190 | logger.info( 191 | "Converting to the inference model cost a little time.") 192 | save_path = export_onnx( 193 | self._task_path, self._tokenizer, model, 'cpu', input_names, output_names) 194 | logger.info( 195 | "The inference model save in the path:{}".format(save_path)) 196 | del model 197 | self.inference_backend = ONNXInferBackend( 198 | self._task_path, device=self._device, use_fp16=self._use_fp16) 199 | 200 | def set_schema(self, schema): 201 | if isinstance(schema, dict) or isinstance(schema, str): 202 | schema = [schema] 203 | self._schema_tree = self._build_tree(schema) 204 | 205 | def __call__(self, inputs): 206 | texts = inputs 207 | if isinstance(texts, str): 208 | texts = [texts] 209 | results = self._multi_stage_predict(texts) 210 | return results 211 | 212 | def _multi_stage_predict(self, datas): 213 | """ 214 | Traversal the schema tree and do multi-stage prediction. 
215 | Args: 216 | datas (list): a list of strings 217 | Returns: 218 | list: a list of predictions, where the list's length 219 | equals to the length of `datas` 220 | """ 221 | results = [{} for _ in range(len(datas))] 222 | # input check to early return 223 | if len(datas) < 1 or self._schema_tree is None: 224 | return results 225 | 226 | # copy to stay `self._schema_tree` unchanged 227 | schema_list = self._schema_tree.children[:] 228 | while len(schema_list) > 0: 229 | node = schema_list.pop(0) 230 | examples = [] 231 | input_map = {} 232 | cnt = 0 233 | idx = 0 234 | if not node.prefix: 235 | for data in datas: 236 | examples.append({ 237 | "text": data, 238 | "prompt": dbc2sbc(node.name) 239 | }) 240 | input_map[cnt] = [idx] 241 | idx += 1 242 | cnt += 1 243 | else: 244 | for pre, data in zip(node.prefix, datas): 245 | if len(pre) == 0: 246 | input_map[cnt] = [] 247 | else: 248 | for p in pre: 249 | if self._is_en: 250 | if re.search(r'\[.*?\]$', node.name): 251 | prompt_prefix = node.name[:node.name.find( 252 | "[", 1)].strip() 253 | cls_options = re.search( 254 | r'\[.*?\]$', node.name).group() 255 | # Sentiment classification of xxx [positive, negative] 256 | prompt = prompt_prefix + p + " " + cls_options 257 | else: 258 | prompt = node.name + p 259 | else: 260 | prompt = p + node.name 261 | examples.append({ 262 | "text": data, 263 | "prompt": dbc2sbc(prompt) 264 | }) 265 | input_map[cnt] = [i + idx for i in range(len(pre))] 266 | idx += len(pre) 267 | cnt += 1 268 | if len(examples) == 0: 269 | result_list = [] 270 | else: 271 | result_list = self._single_stage_predict(examples) 272 | 273 | if not node.parent_relations: 274 | relations = [[] for i in range(len(datas))] 275 | for k, v in input_map.items(): 276 | for idx in v: 277 | if len(result_list[idx]) == 0: 278 | continue 279 | if node.name not in results[k].keys(): 280 | results[k][node.name] = result_list[idx] 281 | else: 282 | results[k][node.name].extend(result_list[idx]) 283 | if node.name in results[k].keys(): 284 | relations[k].extend(results[k][node.name]) 285 | else: 286 | relations = node.parent_relations 287 | for k, v in input_map.items(): 288 | for i in range(len(v)): 289 | if len(result_list[v[i]]) == 0: 290 | continue 291 | if "relations" not in relations[k][i].keys(): 292 | relations[k][i]["relations"] = { 293 | node.name: result_list[v[i]] 294 | } 295 | elif node.name not in relations[k][i]["relations"].keys( 296 | ): 297 | relations[k][i]["relations"][ 298 | node.name] = result_list[v[i]] 299 | else: 300 | relations[k][i]["relations"][node.name].extend( 301 | result_list[v[i]]) 302 | 303 | new_relations = [[] for i in range(len(datas))] 304 | for i in range(len(relations)): 305 | for j in range(len(relations[i])): 306 | if "relations" in relations[i][j].keys( 307 | ) and node.name in relations[i][j]["relations"].keys(): 308 | for k in range( 309 | len(relations[i][j]["relations"][ 310 | node.name])): 311 | new_relations[i].append(relations[i][j][ 312 | "relations"][node.name][k]) 313 | relations = new_relations 314 | 315 | prefix = [[] for _ in range(len(datas))] 316 | for k, v in input_map.items(): 317 | for idx in v: 318 | for i in range(len(result_list[idx])): 319 | if self._is_en: 320 | prefix[k].append(" of " + 321 | result_list[idx][i]["text"]) 322 | else: 323 | prefix[k].append(result_list[idx][i]["text"] + "的") 324 | 325 | for child in node.children: 326 | child.prefix = prefix 327 | child.parent_relations = relations 328 | schema_list.append(child) 329 | return results 330 | 331 | def 
_convert_ids_to_results(self, examples, sentence_ids, probs): 332 | """ 333 | Convert ids to raw text in a single stage. 334 | """ 335 | results = [] 336 | for example, sentence_id, prob in zip(examples, sentence_ids, probs): 337 | if len(sentence_id) == 0: 338 | results.append([]) 339 | continue 340 | result_list = [] 341 | text = example["text"] 342 | prompt = example["prompt"] 343 | for i in range(len(sentence_id)): 344 | start, end = sentence_id[i] 345 | if start < 0 and end >= 0: 346 | continue 347 | if end < 0: 348 | start += (len(prompt) + 1) 349 | end += (len(prompt) + 1) 350 | result = {"text": prompt[start:end], 351 | "probability": prob[i]} 352 | result_list.append(result) 353 | else: 354 | result = { 355 | "text": text[start:end], 356 | "start": start, 357 | "end": end, 358 | "probability": prob[i] 359 | } 360 | result_list.append(result) 361 | results.append(result_list) 362 | return results 363 | 364 | def _auto_splitter(self, input_texts, max_text_len, split_sentence=False): 365 | ''' 366 | Split the raw texts automatically for model inference. 367 | Args: 368 | input_texts (List[str]): input raw texts. 369 | max_text_len (int): cutting length. 370 | split_sentence (bool): If True, sentence-level split will be performed. 371 | return: 372 | short_input_texts (List[str]): the short input texts for model inference. 373 | input_mapping (dict): mapping between raw text and short input texts. 374 | ''' 375 | input_mapping = {} 376 | short_input_texts = [] 377 | cnt_org = 0 378 | cnt_short = 0 379 | for text in input_texts: 380 | if not split_sentence: 381 | sens = [text] 382 | else: 383 | sens = cut_chinese_sent(text) 384 | for sen in sens: 385 | lens = len(sen) 386 | if lens <= max_text_len: 387 | short_input_texts.append(sen) 388 | if cnt_org not in input_mapping.keys(): 389 | input_mapping[cnt_org] = [cnt_short] 390 | else: 391 | input_mapping[cnt_org].append(cnt_short) 392 | cnt_short += 1 393 | else: 394 | temp_text_list = [ 395 | sen[i:i + max_text_len] 396 | for i in range(0, lens, max_text_len) 397 | ] 398 | short_input_texts.extend(temp_text_list) 399 | short_idx = cnt_short 400 | cnt_short += math.ceil(lens / max_text_len) 401 | temp_text_id = [ 402 | short_idx + i for i in range(cnt_short - short_idx) 403 | ] 404 | if cnt_org not in input_mapping.keys(): 405 | input_mapping[cnt_org] = temp_text_id 406 | else: 407 | input_mapping[cnt_org].extend(temp_text_id) 408 | cnt_org += 1 409 | return short_input_texts, input_mapping 410 | 411 | def _single_stage_predict(self, inputs): 412 | input_texts = [] 413 | prompts = [] 414 | for i in range(len(inputs)): 415 | input_texts.append(inputs[i]["text"]) 416 | prompts.append(inputs[i]["prompt"]) 417 | # max predict length should exclude the length of prompt and summary tokens 418 | max_predict_len = self._max_seq_len - len(max(prompts)) - 3 419 | 420 | short_input_texts, self.input_mapping = self._auto_splitter( 421 | input_texts, max_predict_len, split_sentence=self._split_sentence) 422 | 423 | short_texts_prompts = [] 424 | for k, v in self.input_mapping.items(): 425 | short_texts_prompts.extend([prompts[k] for i in range(len(v))]) 426 | short_inputs = [{ 427 | "text": short_input_texts[i], 428 | "prompt": short_texts_prompts[i] 429 | } for i in range(len(short_input_texts))] 430 | 431 | sentence_ids = [] 432 | probs = [] 433 | 434 | input_ids = [] 435 | token_type_ids = [] 436 | attention_mask = [] 437 | offset_maps = [] 438 | 439 | if self._multilingual: 440 | padding_type = "max_length" 441 | else: 442 | padding_type = 
"longest" 443 | encoded_inputs = self._tokenizer( 444 | text=short_texts_prompts, 445 | text_pair=short_input_texts, 446 | stride=2, 447 | truncation=True, 448 | max_length=self._max_seq_len, 449 | padding=padding_type, 450 | add_special_tokens=True, 451 | return_offsets_mapping=True, 452 | return_tensors="np") 453 | 454 | start_prob_concat, end_prob_concat = [], [] 455 | for batch_start in range(0, len(short_input_texts), self._batch_size): 456 | input_ids = encoded_inputs["input_ids"][batch_start:batch_start+self._batch_size] 457 | token_type_ids = encoded_inputs["token_type_ids"][batch_start:batch_start+self._batch_size] 458 | attention_mask = encoded_inputs["attention_mask"][batch_start:batch_start+self._batch_size] 459 | offset_maps = encoded_inputs["offset_mapping"][batch_start:batch_start+self._batch_size] 460 | if self._multilingual: 461 | input_ids = np.array( 462 | input_ids, dtype="int64") 463 | attention_mask = np.array( 464 | attention_mask, dtype="int64") 465 | position_ids = (np.cumsum(np.ones_like(input_ids), axis=1) 466 | - np.ones_like(input_ids))*attention_mask 467 | input_dict = { 468 | "input_ids": input_ids, 469 | "attention_mask": attention_mask, 470 | "position_ids": position_ids 471 | } 472 | else: 473 | input_dict = { 474 | "input_ids": np.array( 475 | input_ids, dtype="int64"), 476 | "token_type_ids": np.array( 477 | token_type_ids, dtype="int64"), 478 | "attention_mask": np.array( 479 | attention_mask, dtype="int64") 480 | } 481 | 482 | outputs = self.inference_backend.infer(input_dict) 483 | start_prob, end_prob = outputs[0], outputs[1] 484 | start_prob_concat.append(start_prob) 485 | end_prob_concat.append(end_prob) 486 | start_prob_concat = np.concatenate(start_prob_concat) 487 | end_prob_concat = np.concatenate(end_prob_concat) 488 | 489 | start_ids_list = get_bool_ids_greater_than( 490 | start_prob_concat, limit=self._position_prob, return_prob=True) 491 | end_ids_list = get_bool_ids_greater_than( 492 | end_prob_concat, limit=self._position_prob, return_prob=True) 493 | 494 | input_ids = input_dict['input_ids'] 495 | sentence_ids = [] 496 | probs = [] 497 | for start_ids, end_ids, ids, offset_map in zip(start_ids_list, 498 | end_ids_list, 499 | input_ids.tolist(), 500 | offset_maps): 501 | for i in reversed(range(len(ids))): 502 | if ids[i] != 0: 503 | ids = ids[:i] 504 | break 505 | span_list = get_span(start_ids, end_ids, with_prob=True) 506 | sentence_id, prob = get_id_and_prob(span_list, offset_map.tolist()) 507 | sentence_ids.append(sentence_id) 508 | probs.append(prob) 509 | 510 | results = self._convert_ids_to_results(short_inputs, sentence_ids, 511 | probs) 512 | results = self._auto_joiner(results, short_input_texts, 513 | self.input_mapping) 514 | return results 515 | 516 | def _auto_joiner(self, short_results, short_inputs, input_mapping): 517 | concat_results = [] 518 | is_cls_task = False 519 | for short_result in short_results: 520 | if short_result == []: 521 | continue 522 | elif 'start' not in short_result[0].keys( 523 | ) and 'end' not in short_result[0].keys(): 524 | is_cls_task = True 525 | break 526 | else: 527 | break 528 | for k, vs in input_mapping.items(): 529 | if is_cls_task: 530 | cls_options = {} 531 | single_results = [] 532 | for v in vs: 533 | if len(short_results[v]) == 0: 534 | continue 535 | if short_results[v][0]['text'] not in cls_options.keys(): 536 | cls_options[short_results[v][0][ 537 | 'text']] = [1, short_results[v][0]['probability']] 538 | else: 539 | cls_options[short_results[v][0]['text']][0] += 1 540 | 
cls_options[short_results[v][0]['text']][ 541 | 1] += short_results[v][0]['probability'] 542 | if len(cls_options) != 0: 543 | cls_res, cls_info = max(cls_options.items(), 544 | key=lambda x: x[1]) 545 | concat_results.append([{ 546 | 'text': cls_res, 547 | 'probability': cls_info[1] / cls_info[0] 548 | }]) 549 | else: 550 | concat_results.append([]) 551 | else: 552 | offset = 0 553 | single_results = [] 554 | for v in vs: 555 | if v == 0: 556 | single_results = short_results[v] 557 | offset += len(short_inputs[v]) 558 | else: 559 | for i in range(len(short_results[v])): 560 | if 'start' not in short_results[v][ 561 | i] or 'end' not in short_results[v][i]: 562 | continue 563 | short_results[v][i]['start'] += offset 564 | short_results[v][i]['end'] += offset 565 | offset += len(short_inputs[v]) 566 | single_results.extend(short_results[v]) 567 | concat_results.append(single_results) 568 | return concat_results 569 | 570 | def predict(self, input_data): 571 | results = self._multi_stage_predict(input_data) 572 | return results 573 | 574 | @classmethod 575 | def _build_tree(cls, schema, name='root'): 576 | """ 577 | Build the schema tree. 578 | """ 579 | schema_tree = SchemaTree(name) 580 | for s in schema: 581 | if isinstance(s, str): 582 | schema_tree.add_child(SchemaTree(s)) 583 | elif isinstance(s, dict): 584 | for k, v in s.items(): 585 | if isinstance(v, str): 586 | child = [v] 587 | elif isinstance(v, list): 588 | child = v 589 | else: 590 | raise TypeError( 591 | "Invalid schema, value for each key:value pairs should be list or string" 592 | "but {} received".format(type(v))) 593 | schema_tree.add_child(cls._build_tree(child, name=k)) 594 | else: 595 | raise TypeError( 596 | "Invalid schema, element should be string or dict, " 597 | "but {} received".format(type(s))) 598 | return schema_tree 599 | 600 | 601 | class SchemaTree(object): 602 | """ 603 | Implementataion of SchemaTree 604 | """ 605 | 606 | def __init__(self, name='root', children=None): 607 | self.name = name 608 | self.children = [] 609 | self.prefix = None 610 | self.parent_relations = None 611 | if children is not None: 612 | for child in children: 613 | self.add_child(child) 614 | 615 | def __repr__(self): 616 | return self.name 617 | 618 | def add_child(self, node): 619 | assert isinstance( 620 | node, SchemaTree 621 | ), "The children of a node should be an instacne of SchemaTree." 622 | self.children.append(node) 623 | 624 | 625 | def parse_args(): 626 | parser = argparse.ArgumentParser() 627 | # Required parameters 628 | parser.add_argument( 629 | "-m", 630 | "--model", 631 | type=str, 632 | default='uie-base', 633 | help="The model to be used.", ) 634 | parser.add_argument( 635 | "-t", 636 | "--task_path", 637 | type=str, 638 | default=None, 639 | help="The path prefix of custom inference model to be used.", ) 640 | parser.add_argument( 641 | "-p", 642 | "--position_prob", 643 | default=0.5, 644 | type=float, 645 | help="Probability threshold for start/end index probabiliry.", ) 646 | parser.add_argument( 647 | "--use_fp16", 648 | action='store_true', 649 | help="Whether to use fp16 inference, only takes effect when deploying on gpu.", 650 | ) 651 | parser.add_argument( 652 | "--max_seq_len", 653 | default=512, 654 | type=int, 655 | help="The maximum input sequence length. Sequences longer than this will be split automatically.", 656 | ) 657 | parser.add_argument( 658 | "-D", 659 | "--device", 660 | choices=['cpu', 'gpu'], 661 | default="gpu", 662 | help="Select which device to run model, defaults to gpu." 
663 | ) 664 | parser.add_argument( 665 | "-e", 666 | "--engine", 667 | choices=['pytorch', 'onnx'], 668 | default="pytorch", 669 | help="Select which engine to run model, defaults to pytorch." 670 | ) 671 | args = parser.parse_args() 672 | return args 673 | 674 | 675 | if __name__ == '__main__': 676 | args = parse_args() 677 | args.schema = ['航母'] 678 | args.schema_lang = "en" 679 | uie = UIEPredictor(model=args.model, task_path=args.task_path, schema_lang=args.schema_lang, schema=args.schema, engine=args.engine, device=args.device, 680 | position_prob=args.position_prob, max_seq_len=args.max_seq_len, batch_size=64, split_sentence=False, use_fp16=args.use_fp16) 681 | print(uie("印媒所称的“印度第一艘国产航母”—“维克兰特”号")) 682 | --------------------------------------------------------------------------------
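Usage sketch (editorial addition): the `__main__` block of uie_predictor.py above hard-codes a single-entity schema; the snippet below shows the same `UIEPredictor` API called from Python with a nested (relation-style) schema. The model name and keyword arguments match the defaults shown in the class, while the schema and the input sentences are purely illustrative.

    from uie_predictor import UIEPredictor

    # Nested schema: extract "竞赛名称" entities and, for each one found,
    # its "主办方" and "时间" (expanded by SchemaTree / _multi_stage_predict).
    schema = [{'竞赛名称': ['主办方', '时间']}]

    uie = UIEPredictor(
        model='uie-base',      # converted into the local uie_base_pytorch directory on first use
        schema=schema,
        engine='pytorch',
        device='cpu',
        position_prob=0.5,     # threshold applied to the start/end probabilities
        max_seq_len=512,
    )

    # A single string or a list of strings is accepted; one result dict per input text.
    print(uie('2022语言与智能技术竞赛由中国中文信息学会和中国计算机学会联合主办。'))

    # The schema can be swapped later without rebuilding the predictor.
    uie.set_schema(['时间', '地点'])
    print(uie(['会议于今天上午十点在北京召开。']))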