├── .gitignore
├── LICENSE
├── README.md
├── bbcm
│   ├── __init__.py
│   ├── config
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   └── defaults.cpython-37.pyc
│   │   ├── config.py
│   │   └── defaults.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   └── build.cpython-37.pyc
│   │   ├── build.py
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   └── csc.cpython-37.pyc
│   │   │   └── csc.py
│   │   ├── loaders
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   └── csc.cpython-37.pyc
│   │   │   ├── collator.py
│   │   │   └── csc.py
│   │   └── processors
│   │       ├── __init__.py
│   │       ├── __pycache__
│   │       │   ├── __init__.cpython-37.pyc
│   │       │   └── csc.cpython-37.pyc
│   │       └── csc.py
│   ├── engine
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── bases.cpython-37.pyc
│   │   │   └── csc_trainer.cpython-37.pyc
│   │   ├── bases.py
│   │   └── csc_trainer.py
│   ├── layers
│   │   └── __init__.py
│   ├── modeling
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   └── __init__.cpython-37.pyc
│   │   └── csc
│   │       ├── __init__.py
│   │       ├── __pycache__
│   │       │   ├── __init__.cpython-37.pyc
│   │       │   ├── modeling_bert4csc.cpython-37.pyc
│   │       │   └── modeling_soft_masked_bert.cpython-37.pyc
│   │       ├── modeling_abert4csc.py
│   │       ├── modeling_bert4csc.py
│   │       └── modeling_soft_masked_bert.py
│   ├── solver
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   └── build.cpython-37.pyc
│   │   ├── build.py
│   │   ├── losses.py
│   │   └── lr_scheduler.py
│   └── utils
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-37.pyc
│       │   ├── evaluations.cpython-37.pyc
│       │   ├── file_io.cpython-37.pyc
│       │   └── logger.cpython-37.pyc
│       ├── evaluations.py
│       ├── file_io.py
│       └── logger.py
├── checkpoints
│   └── .gitignore
├── configs
│   ├── csc
│   │   ├── train_SoftMaskedBert.yml
│   │   ├── train_bert4csc.yml
│   │   └── train_macbert4csc.yml
│   └── dict
│       └── white_name_list.json
├── datasets
│   ├── .gitignore
│   └── csc
│       └── .gitignore
├── requirements.txt
├── tests
│   └── test_embedding.py
└── tools
    ├── bases.py
    ├── convert_to_pure_state_dict.py
    ├── inference.py
    ├── train_csc.py
    └── train_csc.sh

/.gitignore:
--------------------------------------------------------------------------------
1 | .idea 2 | .DS_Store 3 | __pycache__ 4 | lightning_logs 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | env/ 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # dotenv 88 | .env 89 | 90 | # virtualenv 91 | .venv 92 | venv/ 93 | ENV/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# BertBasedCorrectionModels

BERT-based models for Chinese text (spelling) correction, implemented in PyTorch.

## Data Preparation
1. Download the SIGHAN datasets from [http://nlp.ee.ncu.edu.tw/resource/csc.html](http://nlp.ee.ncu.edu.tw/resource/csc.html)
2. Unzip the archives and copy all `.sgml` files to the `datasets/csc/` directory
3. Copy `SIGHAN15_CSC_TestInput.txt` and `SIGHAN15_CSC_TestTruth.txt` to the `datasets/csc/` directory
4. Download [https://github.com/wdimmy/Automatic-Corpus-Generation/blob/master/corpus/train.sgml](https://github.com/wdimmy/Automatic-Corpus-Generation/blob/master/corpus/train.sgml) to the `datasets/csc` directory
5. Make sure the following files are all present in `datasets/csc`:
```
train.sgml
B1_training.sgml
C1_training.sgml
SIGHAN15_CSC_A2_Training.sgml
SIGHAN15_CSC_B2_Training.sgml
SIGHAN15_CSC_TestInput.txt
SIGHAN15_CSC_TestTruth.txt
```
Note: the official SIGHAN download link is no longer valid; a redistributed copy is available on GitHub, see [https://github.com/NYCU-NLP/SIGHAN-CSC](https://github.com/NYCU-NLP/SIGHAN-CSC)
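The first training run converts the corpora above into `train.json`, `dev.json` and `test.json` under `datasets/csc/` (see `bbcm/data/processors/csc.py`). A minimal sketch of the resulting record format (the field names are taken from the processors; the sentence and index below are purely illustrative):

```json
[
  {
    "id": "-",
    "original_text": "我今天很高心",
    "wrong_ids": [5],
    "correct_text": "我今天很高兴"
  }
]
```

`wrong_ids` lists the 0-based character positions at which the original and corrected sentences differ.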
## Environment Setup
1. Use an existing Python environment, or (recommended) create a new one with `conda create -n <env_name> python=3.7`
2. Clone this project and change into the project root directory
3. Install the dependencies: `pip install -r requirements.txt`
4. If you get a "GLIBC version too low" error, install an older version of OpenCC (e.g. 1.1.0) instead of upgrading GLIBC (GLIBC upgrades are accident-prone and not recommended)
5. In the current shell, add the project root to the module search path: `export PYTHONPATH=.`


## Training

Run the following command to train a model; the first run will preprocess the data automatically.
```shell
python tools/train_csc.py --config_file csc/train_SoftMaskedBert.yml
```

Choose a different config file to train a different model; the following configs are currently supported:
- train_bert4csc.yml
- train_macbert4csc.yml
- train_SoftMaskedBert.yml

For anything else, adjust the parameters in the config files as needed, as sketched below.
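These YAML files override the defaults defined in `bbcm/config/defaults.py`, loaded through the `_BASE_`-aware `CfgNode` in `bbcm/config/config.py`. A hypothetical override sketch (the key names are real config keys; the values are illustrative, not the shipped settings):

```yaml
MODEL:
  NAME: bert4csc
  BERT_CKPT: bert-base-chinese   # any BERT checkpoint name or local path
  DEVICE: cuda
  HYPER_PARAMS: [0.3]            # loss weight w used by the CSC trainer
DATASETS:
  TRAIN: datasets/csc/train.json
  VALID: datasets/csc/dev.json
  TEST: datasets/csc/test.json
SOLVER:
  BASE_LR: 0.0001
  BATCH_SIZE: 16
  MAX_EPOCHS: 10
  ACCUMULATE_GRAD_BATCHES: 4
OUTPUT_DIR: checkpoints/bert4csc
```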
## Experimental Results

### SoftMaskedBert
|component|sentence level acc|p|r|f|
|:-:|:-:|:-:|:-:|:-:|
|Detection|0.5045|0.8252|0.8416|0.8333|
|Correction|0.8055|0.9395|0.8748|0.9060|

### BERT-based models
#### char level
|MODEL|p|r|f|
|:-:|:-:|:-:|:-:|
|BERT4CSC|0.9269|0.8651|0.8949|
|MACBERT4CSC|0.9380|0.8736|0.9047|

#### sentence level
|model|acc|p|r|f|
|:-:|:-:|:-:|:-:|:-:|
|BERT4CSC|0.7990|0.8482|0.7214|0.7797|
|MACBERT4CSC|0.8027|0.8525|0.7251|0.7836|

## Inference
### Method 1: the inference script
```shell
cd tools
python inference.py --ckpt_fn epoch=0-val_loss=0.03.ckpt --texts "我今天很高心"
# output: ['我今天很高兴']

# or pass the path of a line-by-line text file
cd tools
python inference.py --ckpt_fn epoch=0-val_loss=0.03.ckpt --text_file ./ml/data/text.txt
# output: ['我今天很高兴', '你这个辣鸡模型只能做错别字纠正']
```
where the file /ml/data/text.txt contains:
```text
我今天很高心
你这个辣鸡模型只能做错别字纠正
```

### Method 2: calling the model directly
```python
from tools.inference import *
ckpt_fn = 'SoftMaskedBert/epoch=02-val_loss=0.02904.ckpt'  # find it in checkpoints/
config_file = 'csc/train_SoftMaskedBert.yml'  # find it in configs/
model = load_model_directly(ckpt_fn, config_file)
texts = ['今天我很高心', '测试', '继续测试']
model.predict(texts)
# output: ['今天我很高兴', '测试', '继续测试']
```

### Method 3: exporting the BERT weights for use with transformers or pycorrector
1. Export the BERT weights with convert_to_pure_state_dict.py
2. For the remaining steps, follow [https://github.com/shibing624/pycorrector/blob/master/examples/macbert/README.md](https://github.com/shibing624/pycorrector/blob/master/examples/macbert/README.md)

## Model Download
1. The model folders and training parameters for SoftMaskedBert, macbert4csc and bert4csc can be downloaded from [模型文件](https://pan.baidu.com/s/1TKFFTLuEFXNh-g7xBY0IOg?pwd=za92) and placed under BertBasedCorrectionModels/checkpoints/ for inference.



## Citation
If you use this project in your research, please cite it as follows:

```
@article{cai2020pre,
  title={BERT Based Correction Models},
  author={Cai, Heng and Chen, Dian},
  journal={GitHub. Note: https://github.com/gitabtion/BertBasedCorrectionModels},
  year={2020}
}
```

## License
This source code is licensed under the Apache License 2.0 and is free for commercial use. Please include a link to this project and a copy of the license in your product documentation. The project is protected by copyright law; infringement will be prosecuted.


## Changelog

### 20240104
1. Updated the web link in the README, as the previous one is no longer valid.

### 20220517
1. Trained SoftMaskedBert, macbert4csc and bert4csc, and added download info for the trained checkpoint files to README.md.

### 20220513
1. Added a model-inference post-processing module to tools/inference.py.

### 20220511
1. Updated README.md with the code for inference methods 1 and 2.

### 20210618
1. Fixed an encoding error in data processing.

### 20210518
1. Switched the BERT4CSC detection task to FocalLoss.
2. Updated the experimental results of the modified models.
3. Lowered the probability of keeping the original text during data processing.

### 20210517
1. Added a detection task to the BERT4CSC model.
2. Added inference from line-by-line files.

## References
1. [Spelling Error Correction with Soft-Masked BERT](https://arxiv.org/abs/2005.07421)
2. [http://ir.itc.ntnu.edu.tw/lre/sighan8csc.html](http://ir.itc.ntnu.edu.tw/lre/sighan8csc.html)
3. [https://github.com/wdimmy/Automatic-Corpus-Generation](https://github.com/wdimmy/Automatic-Corpus-Generation)
4. [transformers](https://huggingface.co/)
5. [https://github.com/sunnyqiny/Confusionset-guided-Pointer-Networks-for-Chinese-Spelling-Check](https://github.com/sunnyqiny/Confusionset-guided-Pointer-Networks-for-Chinese-Spelling-Check)
6. [SoftMaskedBert-PyTorch](https://github.com/gitabtion/SoftMaskedBert-PyTorch)
7. [Deep-Learning-Project-Template](https://github.com/L1aoXingyu/Deep-Learning-Project-Template)
8. [https://github.com/lonePatient/TorchBlocks](https://github.com/lonePatient/TorchBlocks)
9. [https://github.com/shibing624/pycorrector](https://github.com/shibing624/pycorrector)

--------------------------------------------------------------------------------
/bbcm/__init__.py:
--------------------------------------------------------------------------------
"""
@Time   : 2021-01-21 10:35:24
@File   : __init__.py
@Author : Abtion
@Email  : abtion{at}outlook.com
"""

--------------------------------------------------------------------------------
/bbcm/config/__init__.py:
--------------------------------------------------------------------------------
"""
@Time   : 2021-01-21 10:37:06
@File   : __init__.py
@Author : Abtion
@Email  : abtion{at}outlook.com
"""
from .defaults import _C as cfg

--------------------------------------------------------------------------------
/bbcm/config/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gitabtion/BertBasedCorrectionModels/f5e1bacf4f4f9b4a9ca434e9d40249285172a06a/bbcm/config/__pycache__/__init__.cpython-37.pyc

--------------------------------------------------------------------------------
/bbcm/config/__pycache__/defaults.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gitabtion/BertBasedCorrectionModels/f5e1bacf4f4f9b4a9ca434e9d40249285172a06a/bbcm/config/__pycache__/defaults.cpython-37.pyc

--------------------------------------------------------------------------------
/bbcm/config/config.py:
--------------------------------------------------------------------------------
1 | """ 2 | @Time : 2021-01-21 10:37:24 3 | @File : config.py 4 | @Author : Abtion 5 | @Email : abtion{at}outlook.com 6 | """ 7 | import logging 8 | import os 9 | from typing import Any 10 | 11 | import yaml 12 | from yacs.config import CfgNode as _CfgNode 13 | 14 | from ..utils.file_io import PathManager 15 | 16 | BASE_KEY = "_BASE_" 17 | 18 | 19 | class CfgNode(_CfgNode): 20 | """ 21 | Our own extended version
of :class:`yacs.config.CfgNode`. 22 | It contains the following extra features: 23 | 1. The :meth:`merge_from_file` method supports the "_BASE_" key, 24 | which allows the new CfgNode to inherit all the attributes from the 25 | base configuration file. 26 | 2. Keys that start with "COMPUTED_" are treated as insertion-only 27 | "computed" attributes. They can be inserted regardless of whether 28 | the CfgNode is frozen or not. 29 | 3. With "allow_unsafe=True", it supports pyyaml tags that evaluate 30 | expressions in config. See examples in 31 | https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types 32 | Note that this may lead to arbitrary code execution: you must not 33 | load a config file from untrusted sources before manually inspecting 34 | the content of the file. 35 | """ 36 | 37 | @staticmethod 38 | def load_yaml_with_base(filename: str, allow_unsafe: bool = False): 39 | """ 40 | Just like `yaml.load(open(filename))`, but inherit attributes from its 41 | `_BASE_`. 42 | Args: 43 | filename (str): the file name of the current config. Will be used to 44 | find the base config file. 45 | allow_unsafe (bool): whether to allow loading the config file with 46 | `yaml.unsafe_load`. 47 | Returns: 48 | (dict): the loaded yaml 49 | """ 50 | with PathManager.open(filename, "r") as f: 51 | try: 52 | cfg = yaml.safe_load(f) 53 | except yaml.constructor.ConstructorError: 54 | if not allow_unsafe: 55 | raise 56 | logger = logging.getLogger(__name__) 57 | logger.warning( 58 | "Loading config {} with yaml.unsafe_load. Your machine may " 59 | "be at risk if the file contains malicious content.".format( 60 | filename 61 | ) 62 | ) 63 | f.close() 64 | with open(filename, "r") as f: 65 | cfg = yaml.unsafe_load(f) 66 | 67 | def merge_a_into_b(a, b): 68 | # merge dict a into dict b. values in a will overwrite b. 69 | for k, v in a.items(): 70 | if isinstance(v, dict) and k in b: 71 | assert isinstance( 72 | b[k], dict 73 | ), "Cannot inherit key '{}' from base!".format(k) 74 | merge_a_into_b(v, b[k]) 75 | else: 76 | b[k] = v 77 | 78 | if BASE_KEY in cfg: 79 | base_cfg_file = cfg[BASE_KEY] 80 | if base_cfg_file.startswith("~"): 81 | base_cfg_file = os.path.expanduser(base_cfg_file) 82 | if not any( 83 | map(base_cfg_file.startswith, ["/", "https://", "http://"]) 84 | ): 85 | # the path to base cfg is relative to the config file itself. 86 | base_cfg_file = os.path.join( 87 | os.path.dirname(filename), base_cfg_file 88 | ) 89 | base_cfg = CfgNode.load_yaml_with_base( 90 | base_cfg_file, allow_unsafe=allow_unsafe 91 | ) 92 | del cfg[BASE_KEY] 93 | 94 | merge_a_into_b(cfg, base_cfg) 95 | return base_cfg 96 | return cfg 97 | 98 | def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = False): 99 | """ 100 | Merge configs from a given yaml file. 101 | Args: 102 | cfg_filename: the file name of the yaml config. 103 | allow_unsafe: whether to allow loading the config file with 104 | `yaml.unsafe_load`. 105 | """ 106 | loaded_cfg = CfgNode.load_yaml_with_base( 107 | cfg_filename, allow_unsafe=allow_unsafe 108 | ) 109 | loaded_cfg = type(self)(loaded_cfg) 110 | self.merge_from_other_cfg(loaded_cfg) 111 | 112 | # Forward the following calls to base, but with a check on the BASE_KEY. 113 | def merge_from_other_cfg(self, cfg_other): 114 | """ 115 | Args: 116 | cfg_other (CfgNode): configs to merge from. 
117 | """ 118 | assert ( 119 | BASE_KEY not in cfg_other 120 | ), "The reserved key '{}' can only be used in files!".format(BASE_KEY) 121 | return super().merge_from_other_cfg(cfg_other) 122 | 123 | def merge_from_list(self, cfg_list: list): 124 | """ 125 | Args: 126 | cfg_list (list): list of configs to merge from. 127 | """ 128 | keys = set(cfg_list[0::2]) 129 | assert ( 130 | BASE_KEY not in keys 131 | ), "The reserved key '{}' can only be used in files!".format(BASE_KEY) 132 | return super().merge_from_list(cfg_list) 133 | 134 | def __setattr__(self, name: str, val: Any): 135 | if name.startswith("COMPUTED_"): 136 | if name in self: 137 | old_val = self[name] 138 | if old_val == val: 139 | return 140 | raise KeyError( 141 | "Computed attributed '{}' already exists " 142 | "with a different value! old={}, new={}.".format( 143 | name, old_val, val 144 | ) 145 | ) 146 | self[name] = val 147 | else: 148 | super().__setattr__(name, val) 149 | 150 | 151 | def get_cfg() -> CfgNode: 152 | """ 153 | Get a copy of the default config. 154 | Returns: 155 | a fastreid CfgNode instance. 156 | """ 157 | from .defaults import _C 158 | 159 | return _C.clone() 160 | -------------------------------------------------------------------------------- /bbcm/config/defaults.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Time : 2021-01-21 10:37:36 3 | @File : defaults.py 4 | @Author : Abtion 5 | @Email : abtion{at}outlook.com 6 | """ 7 | from yacs.config import CfgNode as CN 8 | 9 | # ----------------------------------------------------------------------------- 10 | # Convention about Training / Test specific parameters 11 | # ----------------------------------------------------------------------------- 12 | # Whenever an argument can be either used for training or for testing, the 13 | # corresponding name will be post-fixed by a _TRAIN for a training parameter, 14 | # or _TEST for a test-specific parameter. 15 | # For example, the number of images during training will be 16 | # IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be 17 | # IMAGES_PER_BATCH_TEST 18 | 19 | # ----------------------------------------------------------------------------- 20 | # Config definition 21 | # ----------------------------------------------------------------------------- 22 | 23 | _C = CN() 24 | 25 | _C.MODEL = CN() 26 | _C.MODEL.DEVICE = "cpu" 27 | _C.MODEL.GPU_IDS = [0] 28 | _C.MODEL.NUM_CLASSES = 10 29 | _C.MODEL.BERT_CKPT = 'bert-base-chinese' 30 | _C.MODEL.NAME = '' 31 | _C.MODEL.WEIGHTS = '' 32 | _C.MODEL.HYPER_PARAMS = [] 33 | 34 | # ----------------------------------------------------------------------------- 35 | # INPUT 36 | # ----------------------------------------------------------------------------- 37 | _C.INPUT = CN() 38 | # Max length of input text. 
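# Individual experiments are expected to override these defaults through the
# YAML files under configs/csc/, merged via CfgNode.merge_from_file (which also
# resolves the "_BASE_" inheritance key defined in bbcm/config/config.py), so
# e.g. INPUT.MAX_LEN or SOLVER.BATCH_SIZE can be changed without editing this
# file. (Assumption: the merge happens in tools/train_csc.py, not shown here.)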
39 | _C.INPUT.MAX_LEN = 512 40 | 41 | 42 | # ----------------------------------------------------------------------------- 43 | # Dataset 44 | # ----------------------------------------------------------------------------- 45 | _C.DATASETS = CN() 46 | # List of the dataset names for training, as present in paths_catalog.py 47 | _C.DATASETS.TRAIN = "" 48 | # List of the dataset names for validation, as present in paths_catalog.py 49 | _C.DATASETS.VALID = "" 50 | # List of the dataset names for testing, as present in paths_catalog.py 51 | _C.DATASETS.TEST = "" 52 | 53 | # ----------------------------------------------------------------------------- 54 | # DataLoader 55 | # ----------------------------------------------------------------------------- 56 | _C.DATALOADER = CN() 57 | # Number of data loading threads 58 | _C.DATALOADER.NUM_WORKERS = 4 59 | 60 | # ---------------------------------------------------------------------------- # 61 | # Solver 62 | # ---------------------------------------------------------------------------- # 63 | _C.SOLVER = CN() 64 | _C.SOLVER.OPTIMIZER_NAME = "AdamW" 65 | 66 | _C.SOLVER.MAX_EPOCHS = 50 67 | 68 | _C.SOLVER.BASE_LR = 0.001 69 | _C.SOLVER.BIAS_LR_FACTOR = 2 70 | 71 | _C.SOLVER.MOMENTUM = 0.9 72 | 73 | _C.SOLVER.WEIGHT_DECAY = 0.0005 74 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0 75 | 76 | _C.SOLVER.GAMMA = 0.9999 77 | _C.SOLVER.STEPS = (10,) 78 | _C.SOLVER.SCHED = "WarmupExponentialLR" 79 | _C.SOLVER.WARMUP_FACTOR = 0.01 80 | _C.SOLVER.WARMUP_ITERS = 2 81 | _C.SOLVER.WARMUP_EPOCHS = 1024 82 | _C.SOLVER.WARMUP_METHOD = "linear" 83 | _C.SOLVER.DELAY_ITERS = 0 84 | _C.SOLVER.ETA_MIN_LR = 3e-7 85 | _C.SOLVER.MAX_ITER = 10 86 | _C.SOLVER.INTERVAL = 'step' 87 | 88 | _C.SOLVER.CHECKPOINT_PERIOD = 10 89 | _C.SOLVER.LOG_PERIOD = 100 90 | _C.SOLVER.ACCUMULATE_GRAD_BATCHES = 1 91 | # Number of images per batch 92 | # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will 93 | # see 2 images per batch 94 | _C.SOLVER.BATCH_SIZE = 16 95 | 96 | 97 | _C.TEST = CN() 98 | _C.TEST.BATCH_SIZE = 8 99 | _C.TEST.CKPT_FN = "" 100 | 101 | # ---------------------------------------------------------------------------- # 102 | # Task specific 103 | # ---------------------------------------------------------------------------- # 104 | _C.TASK = CN() 105 | _C.TASK.NAME = "CSC" 106 | 107 | 108 | # ---------------------------------------------------------------------------- # 109 | # Misc options 110 | # ---------------------------------------------------------------------------- # 111 | _C.OUTPUT_DIR = "" 112 | _C.MODE = ['train', 'test'] 113 | 114 | 115 | -------------------------------------------------------------------------------- /bbcm/data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Time : 2021-01-21 11:21:51 3 | @File : __init__.py.py 4 | @Author : Abtion 5 | @Email : abtion{at}outlook.com 6 | """ 7 | -------------------------------------------------------------------------------- /bbcm/data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gitabtion/BertBasedCorrectionModels/f5e1bacf4f4f9b4a9ca434e9d40249285172a06a/bbcm/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /bbcm/data/__pycache__/build.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gitabtion/BertBasedCorrectionModels/f5e1bacf4f4f9b4a9ca434e9d40249285172a06a/bbcm/data/__pycache__/build.cpython-37.pyc -------------------------------------------------------------------------------- /bbcm/data/build.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Time : 2021-01-21 14:20:50 3 | @File : build.py 4 | @Author : Abtion 5 | @Email : abtion{at}outlook.com 6 | """ 7 | from bbcm.utils import get_abs_path 8 | 9 | 10 | def make_loaders(cfg, get_loader_fn, **kwargs): 11 | if cfg.DATASETS.TRAIN == '': 12 | train_loader = None 13 | else: 14 | train_loader = get_loader_fn(get_abs_path(cfg.DATASETS.TRAIN), 15 | batch_size=cfg.SOLVER.BATCH_SIZE, 16 | shuffle=True, 17 | num_workers=cfg.DATALOADER.NUM_WORKERS, **kwargs) 18 | if cfg.DATASETS.VALID == '': 19 | valid_loader = None 20 | else: 21 | valid_loader = get_loader_fn(get_abs_path(cfg.DATASETS.VALID), 22 | batch_size=cfg.TEST.BATCH_SIZE, 23 | shuffle=False, 24 | num_workers=cfg.DATALOADER.NUM_WORKERS, **kwargs) 25 | if cfg.DATASETS.TEST == '': 26 | test_loader = None 27 | else: 28 | test_loader = get_loader_fn(get_abs_path(cfg.DATASETS.TEST), 29 | batch_size=cfg.TEST.BATCH_SIZE, 30 | shuffle=False, 31 | num_workers=cfg.DATALOADER.NUM_WORKERS, **kwargs) 32 | return train_loader, valid_loader, test_loader 33 | -------------------------------------------------------------------------------- /bbcm/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Time : 2021-01-21 11:23:35 3 | @File : __init__.py.py 4 | @Author : Abtion 5 | @Email : abtion{at}outlook.com 6 | """ 7 | -------------------------------------------------------------------------------- /bbcm/data/datasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gitabtion/BertBasedCorrectionModels/f5e1bacf4f4f9b4a9ca434e9d40249285172a06a/bbcm/data/datasets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /bbcm/data/datasets/__pycache__/csc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gitabtion/BertBasedCorrectionModels/f5e1bacf4f4f9b4a9ca434e9d40249285172a06a/bbcm/data/datasets/__pycache__/csc.cpython-37.pyc -------------------------------------------------------------------------------- /bbcm/data/datasets/csc.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Time : 2021-01-21 11:24:00 3 | @File : csc.py 4 | @Author : Abtion 5 | @Email : abtion{at}outlook.com 6 | """ 7 | from torch.utils.data import Dataset 8 | 9 | from bbcm.utils import load_json 10 | 11 | 12 | class CscDataset(Dataset): 13 | def __init__(self, fp): 14 | self.data = load_json(fp) 15 | 16 | def __len__(self): 17 | return len(self.data) 18 | 19 | def __getitem__(self, index): 20 | return self.data[index]['original_text'], self.data[index]['correct_text'], self.data[index]['wrong_ids'] 21 | -------------------------------------------------------------------------------- /bbcm/data/loaders/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Time : 2021-01-21 14:58:14 3 | @File : __init__.py.py 4 | @Author : Abtion 5 | @Email : abtion{at}outlook.com 6 | """ 7 | from .csc import * 
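For orientation, a minimal usage sketch of the loader and collator defined in this package (the data path, batch size and worker count are illustrative placeholders):

```python
from transformers import BertTokenizer

from bbcm.data.loaders import get_csc_loader
from bbcm.data.loaders.collator import DataCollatorForCsc

# The collator tokenizes on the fly and turns character-level wrong_ids into
# token-level detection labels (see collator.py below).
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
collator = DataCollatorForCsc(tokenizer=tokenizer)

# datasets/csc/train.json is produced by bbcm/data/processors/csc.py.
loader = get_csc_loader('datasets/csc/train.json', collator,
                        batch_size=16, shuffle=True, num_workers=4)

for ori_texts, cor_texts, det_labels in loader:
    pass  # each batch feeds the models' forward(texts, cor_labels, det_labels)
```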
--------------------------------------------------------------------------------
/bbcm/data/loaders/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gitabtion/BertBasedCorrectionModels/f5e1bacf4f4f9b4a9ca434e9d40249285172a06a/bbcm/data/loaders/__pycache__/__init__.cpython-37.pyc

--------------------------------------------------------------------------------
/bbcm/data/loaders/__pycache__/csc.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gitabtion/BertBasedCorrectionModels/f5e1bacf4f4f9b4a9ca434e9d40249285172a06a/bbcm/data/loaders/__pycache__/csc.cpython-37.pyc

--------------------------------------------------------------------------------
/bbcm/data/loaders/collator.py:
--------------------------------------------------------------------------------
import torch


class DataCollatorForCsc:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, data):
        ori_texts, cor_texts, wrong_idss = zip(*data)
        encoded_texts = [self.tokenizer.tokenize(t) for t in ori_texts]
        # +2 accounts for the [CLS]/[SEP] tokens added when the texts are
        # encoded later.
        max_len = max([len(t) for t in encoded_texts]) + 2
        det_labels = torch.zeros(len(ori_texts), max_len).long()
        for i, (encoded_text, wrong_ids) in enumerate(zip(encoded_texts, wrong_idss)):
            for idx in wrong_ids:
                # wrong_ids are character offsets into the raw text, while the
                # detection labels are per wordpiece token; compute how far the
                # token index lags behind the character index up to this point.
                margins = []
                for word in encoded_text[:idx]:
                    if word == '[UNK]':
                        break
                    if word.startswith('##'):
                        margins.append(len(word) - 3)
                    else:
                        margins.append(len(word) - 1)
                margin = sum(margins)
                move = 0
                while (abs(move) < margin) or (idx + move >= len(encoded_text)) \
                        or encoded_text[idx + move].startswith('##'):
                    move -= 1
                # +1 shifts past the [CLS] token that will occupy position 0.
                det_labels[i, idx + move + 1] = 1
        return ori_texts, cor_texts, det_labels

--------------------------------------------------------------------------------
/bbcm/data/loaders/csc.py:
--------------------------------------------------------------------------------
"""
@Time   : 2021-01-21 14:58:30
@File   : csc.py
@Author : Abtion
@Email  : abtion{at}outlook.com
"""
from torch.utils.data import DataLoader

from bbcm.data.datasets.csc import CscDataset


def get_csc_loader(fp, _collate_fn, **kwargs):
    dataset = CscDataset(fp)
    loader = DataLoader(dataset, collate_fn=_collate_fn, **kwargs)
    return loader

--------------------------------------------------------------------------------
/bbcm/data/processors/__init__.py:
--------------------------------------------------------------------------------
"""
@Time   : 2021-01-29 18:27:07
@File   : __init__.py
@Author : Abtion
@Email  : abtion{at}outlook.com
"""

--------------------------------------------------------------------------------
/bbcm/data/processors/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gitabtion/BertBasedCorrectionModels/f5e1bacf4f4f9b4a9ca434e9d40249285172a06a/bbcm/data/processors/__pycache__/__init__.cpython-37.pyc

--------------------------------------------------------------------------------
/bbcm/data/processors/__pycache__/csc.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gitabtion/BertBasedCorrectionModels/f5e1bacf4f4f9b4a9ca434e9d40249285172a06a/bbcm/data/processors/__pycache__/csc.cpython-37.pyc -------------------------------------------------------------------------------- /bbcm/data/processors/csc.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Time : 2021-01-29 18:27:21 3 | @File : csc.py 4 | @Author : Abtion 5 | @Email : abtion{at}outlook.com 6 | """ 7 | import gc 8 | import os 9 | import random 10 | 11 | import opencc 12 | from lxml import etree 13 | from tqdm import tqdm 14 | 15 | from bbcm.utils import dump_json, get_abs_path 16 | 17 | 18 | def proc_item(item, convertor): 19 | root = etree.XML(item) 20 | passages = dict() 21 | mistakes = [] 22 | for passage in root.xpath('/ESSAY/TEXT/PASSAGE'): 23 | passages[passage.get('id')] = convertor.convert(passage.text) 24 | for mistake in root.xpath('/ESSAY/MISTAKE'): 25 | mistakes.append({'id': mistake.get('id'), 26 | 'location': int(mistake.get('location')) - 1, 27 | 'wrong': convertor.convert(mistake.xpath('./WRONG/text()')[0].strip()), 28 | 'correction': convertor.convert(mistake.xpath('./CORRECTION/text()')[0].strip())}) 29 | 30 | rst_items = dict() 31 | 32 | def get_passages_by_id(pgs, _id): 33 | p = pgs.get(_id) 34 | if p: 35 | return p 36 | _id = _id[:-1] + str(int(_id[-1]) + 1) 37 | p = pgs.get(_id) 38 | if p: 39 | return p 40 | raise ValueError(f'passage not found by {_id}') 41 | 42 | for mistake in mistakes: 43 | if mistake['id'] not in rst_items.keys(): 44 | rst_items[mistake['id']] = {'original_text': get_passages_by_id(passages, mistake['id']), 45 | 'wrong_ids': [], 46 | 'correct_text': get_passages_by_id(passages, mistake['id'])} 47 | 48 | # todo 繁体转简体字符数量或位置发生改变校验 49 | 50 | ori_text = rst_items[mistake['id']]['original_text'] 51 | cor_text = rst_items[mistake['id']]['correct_text'] 52 | if len(ori_text) == len(cor_text): 53 | if ori_text[mistake['location']] in mistake['wrong']: 54 | rst_items[mistake['id']]['wrong_ids'].append(mistake['location']) 55 | wrong_char_idx = mistake['wrong'].index(ori_text[mistake['location']]) 56 | start = mistake['location'] - wrong_char_idx 57 | end = start + len(mistake['wrong']) 58 | rst_items[mistake['id']][ 59 | 'correct_text'] = f'{cor_text[:start]}{mistake["correction"]}{cor_text[end:]}' 60 | else: 61 | print(f'{mistake["id"]}\n{ori_text}\n{cor_text}') 62 | rst = [] 63 | for k in rst_items.keys(): 64 | if len(rst_items[k]['correct_text']) == len(rst_items[k]['original_text']): 65 | rst.append({'id': k, **rst_items[k]}) 66 | else: 67 | text = rst_items[k]['correct_text'] 68 | rst.append({'id': k, 'correct_text': text, 'original_text': text, 'wrong_ids': []}) 69 | return rst 70 | 71 | 72 | def proc_test_set(fp, convertor): 73 | """ 74 | 生成sighan15的测试集 75 | Args: 76 | fp: 77 | convertor: 78 | Returns: 79 | """ 80 | inputs = dict() 81 | with open(os.path.join(fp, 'SIGHAN15_CSC_TestInput.txt'), 'r', encoding='utf8') as f: 82 | for line in f: 83 | pid = line[5:14] 84 | text = line[16:].strip() 85 | inputs[pid] = text 86 | 87 | rst = [] 88 | with open(os.path.join(fp, 'SIGHAN15_CSC_TestTruth.txt'), 'r', encoding='utf8') as f: 89 | for line in f: 90 | pid = line[0:9] 91 | mistakes = line[11:].strip().split(', ') 92 | if len(mistakes) <= 1: 93 | text = convertor.convert(inputs[pid]) 94 | rst.append({'id': pid, 95 | 'original_text': text, 96 | 'wrong_ids': [], 97 | 'correct_text': text}) 98 | else: 99 | wrong_ids = [] 100 | original_text = inputs[pid] 101 | cor_text = inputs[pid] 
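                # A SIGHAN15 truth line has the form "pid, loc_1, char_1, loc_2, char_2, ...";
                # the loop below consumes one (location, correction) pair per iteration.
                # Locations are 1-based in the file, hence the "- 1" before indexing.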
102 | for i in range(len(mistakes) // 2): 103 | idx = int(mistakes[2 * i]) - 1 104 | cor_char = mistakes[2 * i + 1] 105 | wrong_ids.append(idx) 106 | cor_text = f'{cor_text[:idx]}{cor_char}{cor_text[idx + 1:]}' 107 | original_text = convertor.convert(original_text) 108 | cor_text = convertor.convert(cor_text) 109 | if len(original_text) != len(cor_text): 110 | print(pid) 111 | print(original_text) 112 | print(cor_text) 113 | continue 114 | rst.append({'id': pid, 115 | 'original_text': original_text, 116 | 'wrong_ids': wrong_ids, 117 | 'correct_text': cor_text}) 118 | 119 | return rst 120 | 121 | 122 | def read_data(fp): 123 | for fn in os.listdir(fp): 124 | if fn.endswith('ing.sgml'): 125 | with open(os.path.join(fp, fn), 'r', encoding='utf-8', errors='ignore') as f: 126 | item = [] 127 | for line in f: 128 | if line.strip().startswith(' 0: 129 | yield ''.join(item) 130 | item = [line.strip()] 131 | elif line.strip().startswith('<'): 132 | item.append(line.strip()) 133 | 134 | 135 | def read_confusion_data(fp): 136 | fn = os.path.join(fp, 'train.sgml') 137 | with open(fn, 'r', encoding='utf8') as f: 138 | item = [] 139 | for line in tqdm(f): 140 | if line.strip().startswith(' 0: 141 | yield ''.join(item) 142 | item = [line.strip()] 143 | elif line.strip().startswith('<'): 144 | item.append(line.strip()) 145 | 146 | 147 | def proc_confusion_item(item): 148 | """ 149 | 处理confusionset数据集 150 | Args: 151 | item: 152 | Returns: 153 | """ 154 | root = etree.XML(item) 155 | text = root.xpath('/SENTENCE/TEXT/text()')[0] 156 | mistakes = [] 157 | for mistake in root.xpath('/SENTENCE/MISTAKE'): 158 | mistakes.append({'location': int(mistake.xpath('./LOCATION/text()')[0]) - 1, 159 | 'wrong': mistake.xpath('./WRONG/text()')[0].strip(), 160 | 'correction': mistake.xpath('./CORRECTION/text()')[0].strip()}) 161 | 162 | cor_text = text 163 | wrong_ids = [] 164 | 165 | for mis in mistakes: 166 | cor_text = f'{cor_text[:mis["location"]]}{mis["correction"]}{cor_text[mis["location"] + 1:]}' 167 | wrong_ids.append(mis['location']) 168 | 169 | rst = [{ 170 | 'id': '-', 171 | 'original_text': text, 172 | 'wrong_ids': wrong_ids, 173 | 'correct_text': cor_text 174 | }] 175 | if len(text) != len(cor_text): 176 | return [{'id': '--', 177 | 'original_text': cor_text, 178 | 'wrong_ids': [], 179 | 'correct_text': cor_text}] 180 | # 取一定概率保留原文本 181 | if random.random() < 0.01: 182 | rst.append({'id': '--', 183 | 'original_text': cor_text, 184 | 'wrong_ids': [], 185 | 'correct_text': cor_text}) 186 | return rst 187 | 188 | 189 | def preproc(): 190 | rst_items = [] 191 | convertor = opencc.OpenCC('tw2sp.json') 192 | test_items = proc_test_set(get_abs_path('datasets', 'csc'), convertor) 193 | for item in read_data(get_abs_path('datasets', 'csc')): 194 | rst_items += proc_item(item, convertor) 195 | for item in read_confusion_data(get_abs_path('datasets', 'csc')): 196 | rst_items += proc_confusion_item(item) 197 | 198 | # 拆分训练与测试 199 | dev_set_len = len(rst_items) // 10 200 | print(len(rst_items)) 201 | random.seed(666) 202 | random.shuffle(rst_items) 203 | dump_json(rst_items[:dev_set_len], get_abs_path('datasets', 'csc', 'dev.json')) 204 | dump_json(rst_items[dev_set_len:], get_abs_path('datasets', 'csc', 'train.json')) 205 | dump_json(test_items, get_abs_path('datasets', 'csc', 'test.json')) 206 | gc.collect() 207 | -------------------------------------------------------------------------------- /bbcm/engine/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Time : 
2021-01-21 10:55:11 3 | @File : __init__.py.py 4 | @Author : Abtion 5 | @Email : abtion{at}outlook.com 6 | """ 7 | -------------------------------------------------------------------------------- /bbcm/engine/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gitabtion/BertBasedCorrectionModels/f5e1bacf4f4f9b4a9ca434e9d40249285172a06a/bbcm/engine/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /bbcm/engine/__pycache__/bases.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gitabtion/BertBasedCorrectionModels/f5e1bacf4f4f9b4a9ca434e9d40249285172a06a/bbcm/engine/__pycache__/bases.cpython-37.pyc -------------------------------------------------------------------------------- /bbcm/engine/__pycache__/csc_trainer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gitabtion/BertBasedCorrectionModels/f5e1bacf4f4f9b4a9ca434e9d40249285172a06a/bbcm/engine/__pycache__/csc_trainer.cpython-37.pyc -------------------------------------------------------------------------------- /bbcm/engine/bases.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Time : 2021-01-21 11:10:52 3 | @File : bases.py 4 | @Author : Abtion 5 | @Email : abtion{at}outlook.com 6 | """ 7 | import logging 8 | import pytorch_lightning as pl 9 | 10 | from bbcm.solver.build import make_optimizer, build_lr_scheduler 11 | 12 | 13 | class BaseTrainingEngine(pl.LightningModule): 14 | def __init__(self, cfg, *args, **kwargs): 15 | super().__init__(*args, **kwargs) 16 | self.cfg = cfg 17 | self._logger = logging.getLogger(cfg.MODEL.NAME) 18 | 19 | def configure_optimizers(self): 20 | optimizer = make_optimizer(self.cfg, self) 21 | scheduler = build_lr_scheduler(self.cfg, optimizer) 22 | 23 | return [optimizer], [scheduler] 24 | 25 | def on_validation_epoch_start(self) -> None: 26 | self._logger.info('Valid.') 27 | 28 | def on_test_epoch_start(self) -> None: 29 | self._logger.info('Testing...') 30 | -------------------------------------------------------------------------------- /bbcm/engine/csc_trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Time : 2021-01-21 10:57:33 3 | @File : csc_trainer.py 4 | @Author : Abtion 5 | @Email : abtion{at}outlook.com 6 | """ 7 | import operator 8 | import torch 9 | import numpy as np 10 | from bbcm.utils.evaluations import compute_corrector_prf, compute_sentence_level_prf 11 | from .bases import BaseTrainingEngine 12 | 13 | 14 | class CscTrainingModel(BaseTrainingEngine): 15 | """ 16 | 用于CSC的BaseModel, 定义了训练及预测步骤 17 | """ 18 | 19 | def __init__(self, cfg, *args, **kwargs): 20 | super().__init__(cfg, *args, **kwargs) 21 | # loss weight 22 | self.w = cfg.MODEL.HYPER_PARAMS[0] 23 | 24 | def training_step(self, batch, batch_idx): 25 | ori_text, cor_text, det_labels = batch 26 | outputs = self.forward(ori_text, cor_text, det_labels) 27 | loss = self.w * outputs[1] + (1 - self.w) * outputs[0] 28 | return loss 29 | 30 | def validation_step(self, batch, batch_idx): 31 | ori_text, cor_text, det_labels = batch 32 | outputs = self.forward(ori_text, cor_text, det_labels) 33 | loss = self.w * outputs[1] + (1 - self.w) * outputs[0] 34 | det_y_hat = (outputs[2] > 0.5).long() 35 | cor_y_hat = 
torch.argmax((outputs[3]), dim=-1) 36 | encoded_x = self.tokenizer(cor_text, padding=True, return_tensors='pt') 37 | encoded_x.to(self._device) 38 | cor_y = encoded_x['input_ids'] 39 | cor_y_hat *= encoded_x['attention_mask'] 40 | 41 | results = [] 42 | det_acc_labels = [] 43 | cor_acc_labels = [] 44 | for src, tgt, predict, det_predict, det_label in zip(ori_text, cor_y, cor_y_hat, det_y_hat, det_labels): 45 | _src = self.tokenizer(src, add_special_tokens=False)['input_ids'] 46 | _tgt = tgt[1:len(_src) + 1].cpu().numpy().tolist() 47 | _predict = predict[1:len(_src) + 1].cpu().numpy().tolist() 48 | cor_acc_labels.append(1 if operator.eq(_tgt, _predict) else 0) 49 | det_acc_labels.append(det_predict[1:len(_src) + 1].equal(det_label[1:len(_src) + 1])) 50 | results.append((_src, _tgt, _predict,)) 51 | 52 | return loss.cpu().item(), det_acc_labels, cor_acc_labels, results 53 | 54 | def validation_epoch_end(self, outputs) -> None: 55 | det_acc_labels = [] 56 | cor_acc_labels = [] 57 | results = [] 58 | for out in outputs: 59 | det_acc_labels += out[1] 60 | cor_acc_labels += out[2] 61 | results += out[3] 62 | loss = np.mean([out[0] for out in outputs]) 63 | self.log('val_loss', loss) 64 | self._logger.info(f'loss: {loss}') 65 | self._logger.info(f'Detection:\n' 66 | f'acc: {np.mean(det_acc_labels):.4f}') 67 | self._logger.info(f'Correction:\n' 68 | f'acc: {np.mean(cor_acc_labels):.4f}') 69 | compute_corrector_prf(results, self._logger) 70 | compute_sentence_level_prf(results, self._logger) 71 | 72 | def test_step(self, batch, batch_idx): 73 | return self.validation_step(batch, batch_idx) 74 | 75 | def test_epoch_end(self, outputs) -> None: 76 | self._logger.info('Test.') 77 | self.validation_epoch_end(outputs) 78 | 79 | def predict(self, texts): 80 | inputs = self.tokenizer(texts, padding=True, return_tensors='pt') 81 | inputs.to(self.cfg.MODEL.DEVICE) 82 | with torch.no_grad(): 83 | outputs = self.forward(texts) 84 | y_hat = torch.argmax(outputs[1], dim=-1) 85 | expand_text_lens = torch.sum(inputs['attention_mask'], dim=-1) - 1 86 | rst = [] 87 | for t_len, _y_hat in zip(expand_text_lens, y_hat): 88 | rst.append(self.tokenizer.decode(_y_hat[1:t_len]).replace(' ', '')) 89 | return rst 90 | -------------------------------------------------------------------------------- /bbcm/layers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Time : 2021-01-21 10:54:50 3 | @File : __init__.py.py 4 | @Author : Abtion 5 | @Email : abtion{at}outlook.com 6 | """ 7 | -------------------------------------------------------------------------------- /bbcm/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Time : 2021-01-21 10:54:38 3 | @File : __init__.py.py 4 | @Author : Abtion 5 | @Email : abtion{at}outlook.com 6 | """ 7 | -------------------------------------------------------------------------------- /bbcm/modeling/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gitabtion/BertBasedCorrectionModels/f5e1bacf4f4f9b4a9ca434e9d40249285172a06a/bbcm/modeling/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /bbcm/modeling/csc/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Time : 2021-01-21 17:56:11 3 | @File : __init__.py.py 4 | @Author : Abtion 5 | @Email : 
abtion{at}outlook.com 6 | """ 7 | from .modeling_soft_masked_bert import SoftMaskedBertModel 8 | from .modeling_bert4csc import BertForCsc 9 | -------------------------------------------------------------------------------- /bbcm/modeling/csc/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gitabtion/BertBasedCorrectionModels/f5e1bacf4f4f9b4a9ca434e9d40249285172a06a/bbcm/modeling/csc/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /bbcm/modeling/csc/__pycache__/modeling_bert4csc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gitabtion/BertBasedCorrectionModels/f5e1bacf4f4f9b4a9ca434e9d40249285172a06a/bbcm/modeling/csc/__pycache__/modeling_bert4csc.cpython-37.pyc -------------------------------------------------------------------------------- /bbcm/modeling/csc/__pycache__/modeling_soft_masked_bert.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gitabtion/BertBasedCorrectionModels/f5e1bacf4f4f9b4a9ca434e9d40249285172a06a/bbcm/modeling/csc/__pycache__/modeling_soft_masked_bert.cpython-37.pyc -------------------------------------------------------------------------------- /bbcm/modeling/csc/modeling_abert4csc.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import transformers as tfs 3 | import torch 4 | import torch.nn as nn 5 | from transformers.modeling_utils import ModuleUtilsMixin 6 | from transformers.models.bert.modeling_bert import BertEmbeddings, BertEncoder, BertOnlyMLMHead 7 | 8 | from bbcm.engine.csc_trainer import CscTrainingModel 9 | 10 | 11 | class BertCorrectionModel(torch.nn.Module, ModuleUtilsMixin): 12 | def __init__(self, config, tokenizer, device): 13 | super().__init__() 14 | self.config = config 15 | self.tokenizer = tokenizer 16 | self.embeddings = BertEmbeddings(self.config) 17 | self.corrector = BertEncoder(self.config) 18 | self.cls = BertOnlyMLMHead(self.config) 19 | self._device = device 20 | 21 | def forward(self, texts, cor_labels=None, residual_connection=False): 22 | if cor_labels is not None: 23 | text_labels = self.tokenizer(cor_labels, padding=True, return_tensors='pt')['input_ids'] 24 | text_labels = text_labels.to(self._device) 25 | # torch的cross entropy loss 会忽略-100的label 26 | text_labels[text_labels == 0] = -100 27 | else: 28 | text_labels = None 29 | encoded_texts = self.tokenizer(texts, padding=True, return_tensors='pt') 30 | encoded_texts.to(self._device) 31 | embed = self.embeddings(input_ids=encoded_texts['input_ids'], 32 | token_type_ids=encoded_texts['token_type_ids'],) 33 | 34 | input_shape = encoded_texts['input_ids'].size() 35 | device = encoded_texts['input_ids'].device 36 | 37 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(encoded_texts['attention_mask'], 38 | input_shape, device) 39 | head_mask = self.get_head_mask(None, self.config.num_hidden_layers) 40 | encoder_outputs = self.corrector( 41 | embed, 42 | attention_mask=extended_attention_mask, 43 | head_mask=head_mask, 44 | encoder_hidden_states=None, 45 | encoder_attention_mask=None, 46 | return_dict=False, 47 | ) 48 | sequence_output = encoder_outputs[0] 49 | 50 | sequence_output = sequence_output + embed if residual_connection else sequence_output 
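        # Note: when residual_connection is set, the encoder output is summed
        # with the input embeddings before the MLM head; this mirrors the
        # residual connection used by Soft-Masked BERT's correction network.
        # By default the flag is off.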
-------------------------------------------------------------------------------- /bbcm/modeling/csc/modeling_bert4csc.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Time : 2021-01-22 11:42:52 3 | @File : bert4csc.py 4 | @Author : Abtion 5 | @Email : abtion{at}outlook.com 6 | """ 7 | 8 | import torch.nn as nn 9 | from transformers import BertForMaskedLM 10 | 11 | from bbcm.engine.csc_trainer import CscTrainingModel 12 | from bbcm.solver.losses import FocalLoss 13 | 14 | 15 | class BertForCsc(CscTrainingModel): 16 | def __init__(self, cfg, tokenizer): 17 | super().__init__(cfg) 18 | self.cfg = cfg 19 | self.bert = BertForMaskedLM.from_pretrained(cfg.MODEL.BERT_CKPT) 20 | self.detection = nn.Linear(self.bert.config.hidden_size, 1) 21 | self.sigmoid = nn.Sigmoid() 22 | self.tokenizer = tokenizer 23 | 24 | def forward(self, texts, cor_labels=None, det_labels=None): 25 | if cor_labels is not None: 26 | text_labels = self.tokenizer(cor_labels, padding=True, return_tensors='pt', truncation=True)['input_ids'] 27 | text_labels = text_labels.to(self.device) 28 | text_labels[text_labels == 0] = -100 29 | else: 30 | text_labels = None 31 | encoded_text = self.tokenizer(texts, padding=True, return_tensors='pt', truncation=True) 32 | encoded_text.to(self.device) 33 | bert_outputs = self.bert(**encoded_text, labels=text_labels, return_dict=True, output_hidden_states=True) 34 | # error-detection probability 35 | prob = self.detection(bert_outputs.hidden_states[-1]) 36 | 37 | if text_labels is None: 38 | # detection output, correction output 39 | outputs = (prob, bert_outputs.logits) 40 | else: 41 | det_loss_fct = FocalLoss(num_labels=None, activation_type='sigmoid') 42 | # no loss is computed on padded positions 43 | active_loss = encoded_text['attention_mask'].view(-1, prob.shape[1]) == 1 44 | active_probs = prob.view(-1, prob.shape[1])[active_loss] 45 | active_labels = det_labels[active_loss] 46 | det_loss = det_loss_fct(active_probs, active_labels.float()) 47 | # detection loss, correction loss, detection output, correction output 48 | outputs = (det_loss, 49 | bert_outputs.loss, 50 | self.sigmoid(prob).squeeze(-1), 51 | bert_outputs.logits) 52 | return outputs 53 |
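The detection head in BertForCsc above is just a linear probe on the final hidden states, squashed by a sigmoid into a per-token error probability. A shape-level sketch with toy tensors (768 matches bert-base but is otherwise arbitrary):

import torch
import torch.nn as nn

batch, seq_len, hidden_size = 2, 16, 768
detection = nn.Linear(hidden_size, 1)
# stands in for bert_outputs.hidden_states[-1]
hidden_states = torch.randn(batch, seq_len, hidden_size)

prob = detection(hidden_states)             # (2, 16, 1) raw detection logits
det_prob = torch.sigmoid(prob).squeeze(-1)  # (2, 16) per-token error probability
assert det_prob.shape == (batch, seq_len)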
| import torch 12 | from torch import nn 13 | import pytorch_lightning as pl 14 | from torch.optim.lr_scheduler import LambdaLR 15 | from transformers import BertConfig 16 | from transformers.models.bert.modeling_bert import BertEmbeddings, BertEncoder, BertPooler, BertOnlyMLMHead 17 | from transformers.modeling_utils import ModuleUtilsMixin 18 | from bbcm.engine.csc_trainer import CscTrainingModel 19 | import numpy as np 20 | 21 | 22 | class DetectionNetwork(nn.Module): 23 | def __init__(self, config): 24 | super().__init__() 25 | self.config = config 26 | self.gru = nn.GRU( 27 | self.config.hidden_size, 28 | self.config.hidden_size // 2, 29 | num_layers=2, 30 | batch_first=True, 31 | dropout=self.config.hidden_dropout_prob, 32 | bidirectional=True, 33 | ) 34 | self.sigmoid = nn.Sigmoid() 35 | self.linear = nn.Linear(self.config.hidden_size, 1) 36 | 37 | def forward(self, hidden_states): 38 | out, _ = self.gru(hidden_states) 39 | prob = self.linear(out) 40 | prob = self.sigmoid(prob) 41 | return prob 42 | 43 | 44 | class BertCorrectionModel(torch.nn.Module, ModuleUtilsMixin): 45 | def __init__(self, config, tokenizer, device): 46 | super().__init__() 47 | self.config = config 48 | self.tokenizer = tokenizer 49 | self.embeddings = BertEmbeddings(self.config) 50 | self.corrector = BertEncoder(self.config) 51 | self.mask_token_id = self.tokenizer.mask_token_id 52 | self.cls = BertOnlyMLMHead(self.config) 53 | self._device = device 54 | 55 | def forward(self, texts, prob, embed=None, cor_labels=None, residual_connection=False): 56 | if cor_labels is not None: 57 | text_labels = self.tokenizer(cor_labels, padding=True, return_tensors='pt', truncation=True)['input_ids'] 58 | text_labels = text_labels.to(self._device) 59 | # torch的cross entropy loss 会忽略-100的label 60 | text_labels[text_labels == 0] = -100 61 | else: 62 | text_labels = None 63 | encoded_texts = self.tokenizer(texts, padding=True, return_tensors='pt', truncation=True) 64 | encoded_texts.to(self._device) 65 | if embed is None: 66 | embed = self.embeddings(input_ids=encoded_texts['input_ids'], 67 | token_type_ids=encoded_texts['token_type_ids']) 68 | # 此处较原文有一定改动,做此改动意在完整保留type_ids及position_ids的embedding。 69 | mask_embed = self.embeddings(torch.ones_like(prob.squeeze(-1)).long() * self.mask_token_id).detach() 70 | # 此处为原文实现 71 | # mask_embed = self.embeddings(torch.tensor([[self.mask_token_id]], device=self._device)).detach() 72 | cor_embed = prob * mask_embed + (1 - prob) * embed 73 | 74 | input_shape = encoded_texts['input_ids'].size() 75 | device = encoded_texts['input_ids'].device 76 | 77 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(encoded_texts['attention_mask'], 78 | input_shape, device) 79 | head_mask = self.get_head_mask(None, self.config.num_hidden_layers) 80 | encoder_outputs = self.corrector( 81 | cor_embed, 82 | attention_mask=extended_attention_mask, 83 | head_mask=head_mask, 84 | encoder_hidden_states=None, 85 | encoder_attention_mask=None, 86 | return_dict=False, 87 | ) 88 | sequence_output = encoder_outputs[0] 89 | 90 | sequence_output = sequence_output + embed if residual_connection else sequence_output 91 | prediction_scores = self.cls(sequence_output) 92 | out = (prediction_scores, sequence_output) 93 | 94 | # Masked language modeling softmax layer 95 | if text_labels is not None: 96 | loss_fct = nn.CrossEntropyLoss() # -100 index = padding token 97 | cor_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), text_labels.view(-1)) 98 | out = (cor_loss,) + out 99 | return 
100 | 101 | def load_from_transformers_state_dict(self, gen_fp): 102 | state_dict = OrderedDict() 103 | gen_state_dict = tfs.AutoModelForMaskedLM.from_pretrained(gen_fp).state_dict() 104 | for k, v in gen_state_dict.items(): 105 | name = k 106 | if name.startswith('bert'): 107 | name = name[5:] 108 | if name.startswith('encoder'): 109 | name = f'corrector.{name[8:]}' 110 | if 'gamma' in name: 111 | name = name.replace('gamma', 'weight') 112 | if 'beta' in name: 113 | name = name.replace('beta', 'bias') 114 | state_dict[name] = v 115 | self.load_state_dict(state_dict, strict=False) 116 | 117 | 118 | class SoftMaskedBertModel(CscTrainingModel): 119 | def __init__(self, cfg, tokenizer): 120 | super().__init__(cfg) 121 | self.cfg = cfg 122 | self.config = tfs.AutoConfig.from_pretrained(cfg.MODEL.BERT_CKPT) 123 | self.detector = DetectionNetwork(self.config) 124 | self.tokenizer = tokenizer 125 | self.corrector = BertCorrectionModel(self.config, tokenizer, cfg.MODEL.DEVICE) 126 | self.corrector.load_from_transformers_state_dict(self.cfg.MODEL.BERT_CKPT) 127 | self._device = cfg.MODEL.DEVICE 128 | 129 | def forward(self, texts, cor_labels=None, det_labels=None): 130 | encoded_texts = self.tokenizer(texts, padding=True, return_tensors='pt', truncation=True) 131 | encoded_texts.to(self._device) 132 | embed = self.corrector.embeddings(input_ids=encoded_texts['input_ids'], 133 | token_type_ids=encoded_texts['token_type_ids']) 134 | prob = self.detector(embed) 135 | cor_out = self.corrector(texts, prob, embed, cor_labels, residual_connection=True) 136 | 137 | if det_labels is not None: 138 | det_loss_fct = nn.BCELoss() 139 | # no loss is computed on padded positions 140 | active_loss = encoded_texts['attention_mask'].view(-1, prob.shape[1]) == 1 141 | active_probs = prob.view(-1, prob.shape[1])[active_loss] 142 | active_labels = det_labels[active_loss] 143 | det_loss = det_loss_fct(active_probs, active_labels.float()) 144 | outputs = (det_loss, cor_out[0], prob.squeeze(-1)) + cor_out[1:] 145 | else: 146 | outputs = (prob.squeeze(-1),) + cor_out 147 | 148 | return outputs 149 | 150 | def load_from_transformers_state_dict(self, gen_fp): 151 | """ 152 | Load pretrained weights from a transformers checkpoint. 153 | :param gen_fp: 154 | :return: 155 | """ 156 | self.corrector.load_from_transformers_state_dict(gen_fp) 157 | -------------------------------------------------------------------------------- /bbcm/solver/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Time : 2021-01-21 10:51:07 3 | @File : __init__.py.py 4 | @Author : Abtion 5 | @Email : abtion{at}outlook.com 6 | """ 7 | -------------------------------------------------------------------------------- /bbcm/solver/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gitabtion/BertBasedCorrectionModels/f5e1bacf4f4f9b4a9ca434e9d40249285172a06a/bbcm/solver/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /bbcm/solver/__pycache__/build.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gitabtion/BertBasedCorrectionModels/f5e1bacf4f4f9b4a9ca434e9d40249285172a06a/bbcm/solver/__pycache__/build.cpython-37.pyc -------------------------------------------------------------------------------- /bbcm/solver/build.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Time :
2021-01-21 10:51:21 3 | @File : build.py 4 | @Author : Abtion 5 | @Email : abtion{at}outlook.com 6 | """ 7 | import torch 8 | 9 | from bbcm.solver import lr_scheduler 10 | 11 | 12 | def make_optimizer(cfg, model): 13 | params = [] 14 | for key, value in model.named_parameters(): 15 | if not value.requires_grad: 16 | continue 17 | lr = cfg.SOLVER.BASE_LR 18 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 19 | if "bias" in key: 20 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR 21 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 22 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 23 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD': 24 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM) 25 | else: 26 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params) 27 | return optimizer 28 | 29 | 30 | def build_lr_scheduler(cfg, optimizer): 31 | scheduler_args = { 32 | "optimizer": optimizer, 33 | 34 | # warmup options 35 | "warmup_factor": cfg.SOLVER.WARMUP_FACTOR, 36 | "warmup_epochs": cfg.SOLVER.WARMUP_EPOCHS, 37 | "warmup_method": cfg.SOLVER.WARMUP_METHOD, 38 | 39 | # multi-step lr scheduler options 40 | "milestones": cfg.SOLVER.STEPS, 41 | "gamma": cfg.SOLVER.GAMMA, 42 | 43 | # cosine annealing lr scheduler options 44 | "max_iters": cfg.SOLVER.MAX_ITER, 45 | "delay_iters": cfg.SOLVER.DELAY_ITERS, 46 | "eta_min_lr": cfg.SOLVER.ETA_MIN_LR, 47 | 48 | } 49 | scheduler = getattr(lr_scheduler, cfg.SOLVER.SCHED)(**scheduler_args) 50 | return {'scheduler': scheduler, 'interval': cfg.SOLVER.INTERVAL} 51 | -------------------------------------------------------------------------------- /bbcm/solver/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Time : 2021-05-18 15:19:28 3 | @File : losses.py 4 | @Author : Abtion 5 | @Email : abtion{at}outlook.com 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class FocalLoss(nn.Module): 12 | """ 13 | Softmax and sigmoid focal loss. 
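An aside on make_optimizer in bbcm/solver/build.py above: it emits one parameter group per tensor so that bias terms can get a scaled learning rate (BIAS_LR_FACTOR) and their own weight decay. A minimal sketch of the resulting groups, using a hypothetical SimpleNamespace stand-in for the cfg.SOLVER fields (the real values come from the project's config, not shown here):

import torch
import torch.nn as nn
from types import SimpleNamespace

# Hypothetical stand-in for cfg.SOLVER.*; field names mirror make_optimizer.
solver = SimpleNamespace(BASE_LR=1e-4, WEIGHT_DECAY=5e-8, BIAS_LR_FACTOR=2.0,
                         WEIGHT_DECAY_BIAS=0.0, OPTIMIZER_NAME='AdamW')

model = nn.Linear(4, 2)
params = []
for key, value in model.named_parameters():
    lr, weight_decay = solver.BASE_LR, solver.WEIGHT_DECAY
    if 'bias' in key:
        lr = solver.BASE_LR * solver.BIAS_LR_FACTOR
        weight_decay = solver.WEIGHT_DECAY_BIAS
    params.append({'params': [value], 'lr': lr, 'weight_decay': weight_decay})

optimizer = getattr(torch.optim, solver.OPTIMIZER_NAME)(params)
# One group per tensor: the weight keeps the base lr, the bias gets 2x.
print([(g['lr'], g['weight_decay']) for g in optimizer.param_groups])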
14 | copy from https://github.com/lonePatient/TorchBlocks 15 | """ 16 | 17 | def __init__(self, num_labels, activation_type='softmax', gamma=2.0, alpha=0.25, epsilon=1.e-9): 18 | 19 | super(FocalLoss, self).__init__() 20 | self.num_labels = num_labels 21 | self.gamma = gamma 22 | self.alpha = alpha 23 | self.epsilon = epsilon 24 | self.activation_type = activation_type 25 | 26 | def forward(self, input, target): 27 | """ 28 | Args: 29 | input: model's output logits, shape of [batch_size, num_cls] 30 | target: ground truth labels, shape of [batch_size] 31 | Returns: 32 | a scalar loss, averaged over the batch 33 | """ 34 | if self.activation_type == 'softmax': 35 | idx = target.view(-1, 1).long() 36 | one_hot_key = torch.zeros(idx.size(0), self.num_labels, dtype=torch.float32, device=idx.device) 37 | one_hot_key = one_hot_key.scatter_(1, idx, 1) 38 | logits = torch.softmax(input, dim=-1) 39 | loss = -self.alpha * one_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log() 40 | loss = loss.sum(1) 41 | elif self.activation_type == 'sigmoid': 42 | multi_hot_key = target 43 | logits = torch.sigmoid(input) 44 | zero_hot_key = 1 - multi_hot_key 45 | loss = -self.alpha * multi_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log() 46 | loss += -(1 - self.alpha) * zero_hot_key * torch.pow(logits, self.gamma) * (1 - logits + self.epsilon).log() 47 | return loss.mean() 48 | -------------------------------------------------------------------------------- /bbcm/solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Time : 2021-01-21 10:52:47 3 | @File : lr_scheduler.py 4 | @Author : Abtion 5 | @Email : abtion{at}outlook.com 6 | """ 7 | import math 8 | import warnings 9 | from bisect import bisect_right 10 | from typing import List 11 | 12 | import torch 13 | from torch.optim.lr_scheduler import _LRScheduler 14 | 15 | __all__ = ["WarmupMultiStepLR", "WarmupCosineAnnealingLR"] 16 | 17 | 18 | class WarmupMultiStepLR(_LRScheduler): 19 | def __init__( 20 | self, 21 | optimizer: torch.optim.Optimizer, 22 | milestones: List[int], 23 | gamma: float = 0.1, 24 | warmup_factor: float = 0.001, 25 | warmup_epochs: int = 2, 26 | warmup_method: str = "linear", 27 | last_epoch: int = -1, 28 | **kwargs, 29 | ): 30 | if not list(milestones) == sorted(milestones): 31 | raise ValueError( 32 | "Milestones should be a list of increasing integers. Got {}".format(milestones) 33 | ) 34 | self.milestones = milestones 35 | self.gamma = gamma 36 | self.warmup_factor = warmup_factor 37 | self.warmup_epochs = warmup_epochs 38 | self.warmup_method = warmup_method 39 | super().__init__(optimizer, last_epoch) 40 | 41 | def get_lr(self) -> List[float]: 42 | warmup_factor = _get_warmup_factor_at_iter( 43 | self.warmup_method, self.last_epoch, self.warmup_epochs, self.warmup_factor 44 | ) 45 | return [ 46 | base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch) 47 | for base_lr in self.base_lrs 48 | ] 49 | 50 | def _compute_values(self) -> List[float]: 51 | # The new interface 52 | return self.get_lr() 53 | 54 | 55 | class WarmupExponentialLR(_LRScheduler): 56 | """Decays the learning rate of each parameter group by gamma every epoch. 57 | When last_epoch=-1, sets initial lr as lr. 58 | 59 | Args: 60 | optimizer (Optimizer): Wrapped optimizer. 61 | gamma (float): Multiplicative factor of learning rate decay. 62 | last_epoch (int): The index of last epoch. Default: -1.
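On using the FocalLoss defined above: bert4csc instantiates it in sigmoid mode with num_labels=None for binary token-level detection, while softmax mode needs the class count to build one-hot targets. A small usage sketch (assumes the bbcm package is importable):

import torch
from bbcm.solver.losses import FocalLoss

# Sigmoid mode: binary detection labels, one logit per token.
det_loss_fct = FocalLoss(num_labels=None, activation_type='sigmoid')
det_loss = det_loss_fct(torch.randn(8), torch.randint(0, 2, (8,)).float())

# Softmax mode: multi-class, needs num_labels for the one-hot targets.
cls_loss_fct = FocalLoss(num_labels=5, activation_type='softmax')
cls_loss = cls_loss_fct(torch.randn(4, 5), torch.tensor([0, 2, 1, 4]))
print(det_loss.item(), cls_loss.item())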
63 | verbose (bool): If ``True``, prints a message to stdout for 64 | each update. Default: ``False``. 65 | """ 66 | 67 | def __init__(self, optimizer, gamma, last_epoch=-1, warmup_epochs=2, warmup_factor=1.0 / 3, verbose=False, 68 | **kwargs): 69 | self.gamma = gamma 70 | self.warmup_method = 'linear' 71 | self.warmup_epochs = warmup_epochs 72 | self.warmup_factor = warmup_factor 73 | super().__init__(optimizer, last_epoch, verbose) 74 | 75 | def get_lr(self): 76 | if not self._get_lr_called_within_step: 77 | warnings.warn("To get the last learning rate computed by the scheduler, " 78 | "please use `get_last_lr()`.", UserWarning) 79 | warmup_factor = _get_warmup_factor_at_iter( 80 | self.warmup_method, self.last_epoch, self.warmup_epochs, self.warmup_factor 81 | ) 82 | 83 | if self.last_epoch <= self.warmup_epochs: 84 | return [base_lr * warmup_factor 85 | for base_lr in self.base_lrs] 86 | return [group['lr'] * self.gamma 87 | for group in self.optimizer.param_groups] 88 | 89 | def _get_closed_form_lr(self): 90 | return [base_lr * self.gamma ** self.last_epoch 91 | for base_lr in self.base_lrs] 92 | 93 | 94 | class WarmupCosineAnnealingLR(_LRScheduler): 95 | r"""Set the learning rate of each parameter group using a cosine annealing 96 | schedule, where :math:`\eta_{max}` is set to the initial lr and 97 | :math:`T_{cur}` is the number of epochs since the last restart in SGDR: 98 | .. math:: 99 | \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 + 100 | \cos(\frac{T_{cur}}{T_{max}}\pi)) 101 | When last_epoch=-1, sets initial lr as lr. 102 | It has been proposed in 103 | `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only 104 | implements the cosine annealing part of SGDR, and not the restarts. 105 | Args: 106 | optimizer (Optimizer): Wrapped optimizer. 107 | T_max (int): Maximum number of iterations. 108 | eta_min (float): Minimum learning rate. Default: 0. 109 | last_epoch (int): The index of last epoch. Default: -1. 110 | .. 
_SGDR\: Stochastic Gradient Descent with Warm Restarts: 111 | https://arxiv.org/abs/1608.03983 112 | """ 113 | 114 | def __init__( 115 | self, 116 | optimizer: torch.optim.Optimizer, 117 | max_iters: int, 118 | delay_iters: int = 0, 119 | eta_min_lr: int = 0, 120 | warmup_factor: float = 0.001, 121 | warmup_epochs: int = 2, 122 | warmup_method: str = "linear", 123 | last_epoch=-1, 124 | **kwargs 125 | ): 126 | self.max_iters = max_iters 127 | self.delay_iters = delay_iters 128 | self.eta_min_lr = eta_min_lr 129 | self.warmup_factor = warmup_factor 130 | self.warmup_epochs = warmup_epochs 131 | self.warmup_method = warmup_method 132 | assert self.delay_iters >= self.warmup_epochs, "Scheduler delay iters must be larger than warmup iters" 133 | super(WarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch) 134 | 135 | def get_lr(self) -> List[float]: 136 | if self.last_epoch <= self.warmup_epochs: 137 | warmup_factor = _get_warmup_factor_at_iter( 138 | self.warmup_method, self.last_epoch, self.warmup_epochs, self.warmup_factor, 139 | ) 140 | return [ 141 | base_lr * warmup_factor for base_lr in self.base_lrs 142 | ] 143 | elif self.last_epoch <= self.delay_iters: 144 | return self.base_lrs 145 | 146 | else: 147 | return [ 148 | self.eta_min_lr + (base_lr - self.eta_min_lr) * 149 | (1 + math.cos( 150 | math.pi * (self.last_epoch - self.delay_iters) / (self.max_iters - self.delay_iters))) / 2 151 | for base_lr in self.base_lrs] 152 | 153 | 154 | def _get_warmup_factor_at_iter( 155 | method: str, iter: int, warmup_iters: int, warmup_factor: float 156 | ) -> float: 157 | """ 158 | Return the learning rate warmup factor at a specific iteration. 159 | See https://arxiv.org/abs/1706.02677 for more details. 160 | Args: 161 | method (str): warmup method; either "constant" or "linear". 162 | iter (int): iteration at which to calculate the warmup factor. 163 | warmup_iters (int): the number of warmup iterations. 164 | warmup_factor (float): the base warmup factor (the meaning changes according 165 | to the method used). 166 | Returns: 167 | float: the effective warmup factor at the given iteration. 
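To make the warmup arithmetic concrete before the implementation below: with the linear method the factor interpolates from warmup_factor up to 1 across warmup_iters, then stays at 1. A tiny worked sketch mirroring _get_warmup_factor_at_iter with warmup_factor=0.001 and warmup_iters=2:

# iter 0: 0.001 * (1 - 0/2) + 0/2 = 0.001
# iter 1: 0.001 * (1 - 1/2) + 1/2 = 0.5005
# iter >= 2: 1.0 (warmup finished)
for it in range(4):
    if it >= 2:
        factor = 1.0
    else:
        alpha = it / 2
        factor = 0.001 * (1 - alpha) + alpha
    print(it, factor)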
168 | """ 169 | if iter >= warmup_iters: 170 | return 1.0 171 | 172 | if method == "constant": 173 | return warmup_factor 174 | elif method == "linear": 175 | alpha = iter / warmup_iters 176 | return warmup_factor * (1 - alpha) + alpha 177 | else: 178 | raise ValueError("Unknown warmup method: {}".format(method)) 179 | -------------------------------------------------------------------------------- /bbcm/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Time : 2021-01-21 10:38:37 3 | @File : __init__.py.py 4 | @Author : Abtion 5 | @Email : abtion{at}outlook.com 6 | """ 7 | from .file_io import * 8 | -------------------------------------------------------------------------------- /bbcm/utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gitabtion/BertBasedCorrectionModels/f5e1bacf4f4f9b4a9ca434e9d40249285172a06a/bbcm/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /bbcm/utils/__pycache__/evaluations.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gitabtion/BertBasedCorrectionModels/f5e1bacf4f4f9b4a9ca434e9d40249285172a06a/bbcm/utils/__pycache__/evaluations.cpython-37.pyc -------------------------------------------------------------------------------- /bbcm/utils/__pycache__/file_io.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gitabtion/BertBasedCorrectionModels/f5e1bacf4f4f9b4a9ca434e9d40249285172a06a/bbcm/utils/__pycache__/file_io.cpython-37.pyc -------------------------------------------------------------------------------- /bbcm/utils/__pycache__/logger.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gitabtion/BertBasedCorrectionModels/f5e1bacf4f4f9b4a9ca434e9d40249285172a06a/bbcm/utils/__pycache__/logger.cpython-37.pyc -------------------------------------------------------------------------------- /bbcm/utils/evaluations.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Time : 2021-01-21 12:01:32 3 | @File : evaluations.py 4 | @Author : Abtion 5 | @Email : abtion{at}outlook.com 6 | """ 7 | 8 | 9 | def compute_corrector_prf(results, logger): 10 | """ 11 | copy from https://github.com/sunnyqiny/Confusionset-guided-Pointer-Networks-for-Chinese-Spelling-Check/blob/master/utils/evaluation_metrics.py 12 | """ 13 | TP = 0 14 | FP = 0 15 | FN = 0 16 | all_predict_true_index = [] 17 | all_gold_index = [] 18 | for item in results: 19 | src, tgt, predict = item 20 | gold_index = [] 21 | each_true_index = [] 22 | for i in range(len(list(src))): 23 | if src[i] == tgt[i]: 24 | continue 25 | else: 26 | gold_index.append(i) 27 | all_gold_index.append(gold_index) 28 | predict_index = [] 29 | for i in range(len(list(src))): 30 | if src[i] == predict[i]: 31 | continue 32 | else: 33 | predict_index.append(i) 34 | 35 | for i in predict_index: 36 | if i in gold_index: 37 | TP += 1 38 | each_true_index.append(i) 39 | else: 40 | FP += 1 41 | for i in gold_index: 42 | if i in predict_index: 43 | continue 44 | else: 45 | FN += 1 46 | all_predict_true_index.append(each_true_index) 47 | 48 | # For the detection Precision, Recall and F1 49 | detection_precision = TP / (TP + FP) if 
(TP + FP) > 0 else 0 50 | detection_recall = TP / (TP + FN) if (TP + FN) > 0 else 0 51 | if detection_precision + detection_recall == 0: 52 | detection_f1 = 0 53 | else: 54 | detection_f1 = 2 * (detection_precision * detection_recall) / (detection_precision + detection_recall) 55 | logger.info( 56 | "The detection result is precision={}, recall={} and F1={}".format(detection_precision, detection_recall, 57 | detection_f1)) 58 | 59 | TP = 0 60 | FP = 0 61 | FN = 0 62 | 63 | for i in range(len(all_predict_true_index)): 64 | # we only evaluate the correctly detected locations, which differs from the common metrics, since 65 | # we want to see the precision improvement gained by using the confusionset 66 | if len(all_predict_true_index[i]) > 0: 67 | predict_words = [] 68 | for j in all_predict_true_index[i]: 69 | predict_words.append(results[i][2][j]) 70 | if results[i][1][j] == results[i][2][j]: 71 | TP += 1 72 | else: 73 | FP += 1 74 | for j in all_gold_index[i]: 75 | if results[i][1][j] in predict_words: 76 | continue 77 | else: 78 | FN += 1 79 | 80 | # For the correction Precision, Recall and F1 81 | correction_precision = TP / (TP + FP) if (TP + FP) > 0 else 0 82 | correction_recall = TP / (TP + FN) if (TP + FN) > 0 else 0 83 | if correction_precision + correction_recall == 0: 84 | correction_f1 = 0 85 | else: 86 | correction_f1 = 2 * (correction_precision * correction_recall) / (correction_precision + correction_recall) 87 | logger.info("The correction result is precision={}, recall={} and F1={}".format(correction_precision, 88 | correction_recall, 89 | correction_f1)) 90 | 91 | return detection_f1, correction_f1 92 | 93 | 94 | def compute_sentence_level_prf(results, logger): 95 | """ 96 | Custom sentence-level PRF: sentences that need correction count as positive samples, sentences that need none count as negative samples. 97 | :param results: 98 | :return: 99 | """ 100 | 101 | TP = 0.0 102 | FP = 0.0 103 | FN = 0.0 104 | TN = 0.0 105 | total_num = len(results) 106 | 107 | for item in results: 108 | src, tgt, predict = item 109 | 110 | # negative sample 111 | if src == tgt: 112 | # predicted negative as well 113 | if tgt == predict: 114 | TN += 1 115 | # predicted positive 116 | else: 117 | FP += 1 118 | # positive sample 119 | else: 120 | # predicted positive as well 121 | if tgt == predict: 122 | TP += 1 123 | # predicted negative 124 | else: 125 | FN += 1 126 | 127 | acc = (TP + TN) / total_num 128 | precision = TP / (TP + FP) if TP > 0 else 0.0 129 | recall = TP / (TP + FN) if TP > 0 else 0.0 130 | f1 = 2 * precision * recall / (precision + recall) if precision + recall != 0 else 0 131 | 132 | logger.info(f'Sentence Level: acc:{acc:.6f}, precision:{precision:.6f}, recall:{recall:.6f}, f1:{f1:.6f}') 133 | return acc, precision, recall, f1 134 | 135 | 136 | def report_prf(tp, fp, fn, phase, logger=None, return_dict=False): 137 | # Precision, Recall and F1 for the given phase 138 | precision = tp / (tp + fp) if (tp + fp) > 0 else 0 139 | recall = tp / (tp + fn) if (tp + fn) > 0 else 0 140 | if precision + recall == 0: 141 | f1_score = 0 142 | else: 143 | f1_score = 2 * (precision * recall) / (precision + recall) 144 | 145 | if phase and logger: 146 | logger.info(f"The {phase} result is: " 147 | f"{precision:.4f}/{recall:.4f}/{f1_score:.4f} -->\n" 148 | # f"precision={precision:.6f}, recall={recall:.6f} and F1={f1_score:.6f}\n" 149 | f"support: TP={tp}, FP={fp}, FN={fn}") 150 | if return_dict: 151 | ret_dict = { 152 | f'{phase}_p': precision, 153 | f'{phase}_r': recall, 154 | f'{phase}_f1': f1_score} 155 | return ret_dict 156 | return precision, recall, f1_score 157 | 158 |
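A tiny worked example of the sentence-level metric above, with three hypothetical (src, tgt, predict) triples; the expected numbers follow directly from the counting rules:

import logging
from bbcm.utils.evaluations import compute_sentence_level_prf

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('demo')

results = [
    ('我爱吃平果', '我爱吃苹果', '我爱吃苹果'),  # needs correction, corrected -> TP
    ('今天天气好', '今天天气好', '今天天气好'),  # clean, left untouched -> TN
    ('他再家里', '他在家里', '他再家里'),        # needs correction, missed -> FN
]
acc, p, r, f1 = compute_sentence_level_prf(results, logger)
# acc = 2/3, precision = 1.0, recall = 0.5, f1 = 2/3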
159 | def compute_corrector_prf_faspell(results, logger=None, strict=True): 160 | """ 161 | All-in-one measure function, 162 | based on FASpell's measure script. 163 | :param results: a list of (wrong, correct, predict, ...) 164 | both token_ids or characters are fine for the script. 165 | :param logger: the logger used to print logs. 166 | :param strict: a stricter evaluation mode (all-char-detected/corrected) 167 | References: 168 | sentence-level PRF: https://github.com/iqiyi/ 169 | FASPell/blob/master/faspell.py 170 | """ 171 | 172 | corrected_char, wrong_char = 0, 0 173 | corrected_sent, wrong_sent = 0, 0 174 | true_corrected_char = 0 175 | true_corrected_sent = 0 176 | true_detected_char = 0 177 | true_detected_sent = 0 178 | accurate_detected_sent = 0 179 | accurate_corrected_sent = 0 180 | all_sent = 0 181 | 182 | for item in results: 183 | # wrong, correct, predict, d_tgt, d_predict = item 184 | wrong, correct, predict = item[:3] 185 | 186 | all_sent += 1 187 | wrong_num = 0 188 | corrected_num = 0 189 | original_wrong_num = 0 190 | true_detected_char_in_sentence = 0 191 | 192 | for c, w, p in zip(correct, wrong, predict): 193 | if c != p: 194 | wrong_num += 1 195 | if w != p: 196 | corrected_num += 1 197 | if c == p: 198 | true_corrected_char += 1 199 | if w != c: 200 | true_detected_char += 1 201 | true_detected_char_in_sentence += 1 202 | if c != w: 203 | original_wrong_num += 1 204 | 205 | corrected_char += corrected_num 206 | wrong_char += original_wrong_num 207 | if original_wrong_num != 0: 208 | wrong_sent += 1 209 | if corrected_num != 0 and wrong_num == 0: 210 | true_corrected_sent += 1 211 | 212 | if corrected_num != 0: 213 | corrected_sent += 1 214 | 215 | if strict: # find all faulty wordings' positions 216 | true_detected_flag = (true_detected_char_in_sentence == original_wrong_num \ 217 | and original_wrong_num != 0 \ 218 | and corrected_num == true_detected_char_in_sentence) 219 | else: # treat the sentence as having faulty wordings 220 | true_detected_flag = (corrected_num != 0 and original_wrong_num != 0) 221 | 222 | # if corrected_num != 0 and original_wrong_num != 0: 223 | if true_detected_flag: 224 | true_detected_sent += 1 225 | if correct == predict: 226 | accurate_corrected_sent += 1 227 | if correct == predict or true_detected_flag: 228 | accurate_detected_sent += 1 229 | 230 | counts = { # TP, FP, FN for each level 231 | 'det_char_counts': [true_detected_char, 232 | corrected_char-true_detected_char, 233 | wrong_char-true_detected_char], 234 | 'cor_char_counts': [true_corrected_char, 235 | corrected_char-true_corrected_char, 236 | wrong_char-true_corrected_char], 237 | 'det_sent_counts': [true_detected_sent, 238 | corrected_sent-true_detected_sent, 239 | wrong_sent-true_detected_sent], 240 | 'cor_sent_counts': [true_corrected_sent, 241 | corrected_sent-true_corrected_sent, 242 | wrong_sent-true_corrected_sent], 243 | 'det_sent_acc': accurate_detected_sent / all_sent, 244 | 'cor_sent_acc': accurate_corrected_sent / all_sent, 245 | 'all_sent_count': all_sent, 246 | } 247 | 248 | details = {} 249 | for phase in ['det_char', 'cor_char', 'det_sent', 'cor_sent']: 250 | dic = report_prf( 251 | *counts[f'{phase}_counts'], 252 | phase=phase, logger=logger, 253 | return_dict=True) 254 | details.update(dic) 255 | details.update(counts) 256 | return details 257 | -------------------------------------------------------------------------------- /bbcm/utils/file_io.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Time : 2021-01-21 10:39:23 3 | @File : file_io.py 4 | @Author : Abtion 5 | @Email : abtion{at}outlook.com 6 | """ 7 | import
errno 8 | import json 9 | import logging 10 | import os 11 | import shutil 12 | import sys 13 | from collections import OrderedDict 14 | from typing import ( 15 | IO, 16 | Any, 17 | Callable, 18 | Dict, 19 | List, 20 | MutableMapping, 21 | Optional, 22 | Union, 23 | ) 24 | 25 | __all__ = ["PathManager", 26 | "get_cache_dir", 27 | "get_abs_path", 28 | "dump_json", 29 | "load_json"] 30 | 31 | 32 | def get_cache_dir(cache_dir: Optional[str] = None) -> str: 33 | """ 34 | Returns a default directory to cache static files 35 | (usually downloaded from Internet), if None is provided. 36 | Args: 37 | cache_dir (None or str): if not None, will be returned as is. 38 | If None, returns the default cache directory as: 39 | 1) $DLCORE_CACHE, if set 40 | 2) otherwise ~/.torch/dlcore_cache 41 | """ 42 | if cache_dir is None: 43 | cache_dir = os.path.expanduser( 44 | os.getenv("DLCORE_CACHE", "~/.torch/dlcore_cache") 45 | ) 46 | return cache_dir 47 | 48 | 49 | class PathHandler: 50 | """ 51 | PathHandler is a base class that defines common I/O functionality for a URI 52 | protocol. It routes I/O for a generic URI which may look like "protocol://*" 53 | or a canonical filepath "/foo/bar/baz". 54 | """ 55 | 56 | _strict_kwargs_check = True 57 | 58 | def _check_kwargs(self, kwargs: Dict[str, Any]) -> None: 59 | """ 60 | Checks if the given arguments are empty. Throws a ValueError if strict 61 | kwargs checking is enabled and args are non-empty. If strict kwargs 62 | checking is disabled, only a warning is logged. 63 | Args: 64 | kwargs (Dict[str, Any]) 65 | """ 66 | if self._strict_kwargs_check: 67 | if len(kwargs) > 0: 68 | raise ValueError("Unused arguments: {}".format(kwargs)) 69 | else: 70 | logger = logging.getLogger(__name__) 71 | for k, v in kwargs.items(): 72 | logger.warning( 73 | "[PathManager] {}={} argument ignored".format(k, v) 74 | ) 75 | 76 | def _get_supported_prefixes(self) -> List[str]: 77 | """ 78 | Returns: 79 | List[str]: the list of URI prefixes this PathHandler can support 80 | """ 81 | raise NotImplementedError() 82 | 83 | def _get_local_path(self, path: str, **kwargs: Any) -> str: 84 | """ 85 | Get a filepath which is compatible with native Python I/O such as `open` 86 | and `os.path`. 87 | If URI points to a remote resource, this function may download and cache 88 | the resource to local disk. In this case, this function is meant to be 89 | used with read-only resources. 90 | Args: 91 | path (str): A URI supported by this PathHandler 92 | Returns: 93 | local_path (str): a file path which exists on the local file system 94 | """ 95 | raise NotImplementedError() 96 | 97 | def _open( 98 | self, path: str, mode: str = "r", buffering: int = -1, **kwargs: Any 99 | ) -> Union[IO[str], IO[bytes]]: 100 | """ 101 | Open a stream to a URI, similar to the built-in `open`. 102 | Args: 103 | path (str): A URI supported by this PathHandler 104 | mode (str): Specifies the mode in which the file is opened. It defaults 105 | to 'r'. 106 | buffering (int): An optional integer used to set the buffering policy. 107 | Pass 0 to switch buffering off and an integer >= 1 to indicate the 108 | size in bytes of a fixed-size chunk buffer. When no buffering 109 | argument is given, the default buffering policy depends on the 110 | underlying I/O implementation. 111 | Returns: 112 | file: a file-like object. 
113 | """ 114 | raise NotImplementedError() 115 | 116 | def _copy( 117 | self, 118 | src_path: str, 119 | dst_path: str, 120 | overwrite: bool = False, 121 | **kwargs: Any, 122 | ) -> bool: 123 | """ 124 | Copies a source path to a destination path. 125 | Args: 126 | src_path (str): A URI supported by this PathHandler 127 | dst_path (str): A URI supported by this PathHandler 128 | overwrite (bool): Bool flag for forcing overwrite of existing file 129 | Returns: 130 | status (bool): True on success 131 | """ 132 | raise NotImplementedError() 133 | 134 | def _exists(self, path: str, **kwargs: Any) -> bool: 135 | """ 136 | Checks if there is a resource at the given URI. 137 | Args: 138 | path (str): A URI supported by this PathHandler 139 | Returns: 140 | bool: true if the path exists 141 | """ 142 | raise NotImplementedError() 143 | 144 | def _isfile(self, path: str, **kwargs: Any) -> bool: 145 | """ 146 | Checks if the resource at the given URI is a file. 147 | Args: 148 | path (str): A URI supported by this PathHandler 149 | Returns: 150 | bool: true if the path is a file 151 | """ 152 | raise NotImplementedError() 153 | 154 | def _isdir(self, path: str, **kwargs: Any) -> bool: 155 | """ 156 | Checks if the resource at the given URI is a directory. 157 | Args: 158 | path (str): A URI supported by this PathHandler 159 | Returns: 160 | bool: true if the path is a directory 161 | """ 162 | raise NotImplementedError() 163 | 164 | def _ls(self, path: str, **kwargs: Any) -> List[str]: 165 | """ 166 | List the contents of the directory at the provided URI. 167 | Args: 168 | path (str): A URI supported by this PathHandler 169 | Returns: 170 | List[str]: list of contents in given path 171 | """ 172 | raise NotImplementedError() 173 | 174 | def _mkdirs(self, path: str, **kwargs: Any) -> None: 175 | """ 176 | Recursive directory creation function. Like mkdir(), but makes all 177 | intermediate-level directories needed to contain the leaf directory. 178 | Similar to the native `os.makedirs`. 179 | Args: 180 | path (str): A URI supported by this PathHandler 181 | """ 182 | raise NotImplementedError() 183 | 184 | def _rm(self, path: str, **kwargs: Any) -> None: 185 | """ 186 | Remove the file (not directory) at the provided URI. 187 | Args: 188 | path (str): A URI supported by this PathHandler 189 | """ 190 | raise NotImplementedError() 191 | 192 | 193 | class NativePathHandler(PathHandler): 194 | """ 195 | Handles paths that can be accessed using Python native system calls. This 196 | handler uses `open()` and `os.*` calls on the given path. 197 | """ 198 | 199 | def _get_local_path(self, path: str, **kwargs: Any) -> str: 200 | self._check_kwargs(kwargs) 201 | return path 202 | 203 | def _open( 204 | self, 205 | path: str, 206 | mode: str = "r", 207 | buffering: int = -1, 208 | encoding: Optional[str] = None, 209 | errors: Optional[str] = None, 210 | newline: Optional[str] = None, 211 | closefd: bool = True, 212 | opener: Optional[Callable] = None, 213 | **kwargs: Any, 214 | ) -> Union[IO[str], IO[bytes]]: 215 | """ 216 | Open a path. 217 | Args: 218 | path (str): A URI supported by this PathHandler 219 | mode (str): Specifies the mode in which the file is opened. It defaults 220 | to 'r'. 221 | buffering (int): An optional integer used to set the buffering policy. 222 | Pass 0 to switch buffering off and an integer >= 1 to indicate the 223 | size in bytes of a fixed-size chunk buffer. 
When no buffering 224 | argument is given, the default buffering policy works as follows: 225 | * Binary files are buffered in fixed-size chunks; the size of 226 | the buffer is chosen using a heuristic trying to determine the 227 | underlying device’s “block size” and falling back on 228 | io.DEFAULT_BUFFER_SIZE. On many systems, the buffer will 229 | typically be 4096 or 8192 bytes long. 230 | encoding (Optional[str]): the name of the encoding used to decode or 231 | encode the file. This should only be used in text mode. 232 | errors (Optional[str]): an optional string that specifies how encoding 233 | and decoding errors are to be handled. This cannot be used in binary 234 | mode. 235 | newline (Optional[str]): controls how universal newlines mode works 236 | (it only applies to text mode). It can be None, '', '\n', '\r', 237 | and '\r\n'. 238 | closefd (bool): If closefd is False and a file descriptor rather than 239 | a filename was given, the underlying file descriptor will be kept 240 | open when the file is closed. If a filename is given closefd must 241 | be True (the default) otherwise an error will be raised. 242 | opener (Optional[Callable]): A custom opener can be used by passing 243 | a callable as opener. The underlying file descriptor for the file 244 | object is then obtained by calling opener with (file, flags). 245 | opener must return an open file descriptor (passing os.open as opener 246 | results in functionality similar to passing None). 247 | See https://docs.python.org/3/library/functions.html#open for details. 248 | Returns: 249 | file: a file-like object. 250 | """ 251 | self._check_kwargs(kwargs) 252 | return open( # type: ignore 253 | path, 254 | mode, 255 | buffering=buffering, 256 | encoding=encoding, 257 | errors=errors, 258 | newline=newline, 259 | closefd=closefd, 260 | opener=opener, 261 | ) 262 | 263 | def _copy( 264 | self, 265 | src_path: str, 266 | dst_path: str, 267 | overwrite: bool = False, 268 | **kwargs: Any, 269 | ) -> bool: 270 | """ 271 | Copies a source path to a destination path. 
272 | Args: 273 | src_path (str): A URI supported by this PathHandler 274 | dst_path (str): A URI supported by this PathHandler 275 | overwrite (bool): Bool flag for forcing overwrite of existing file 276 | Returns: 277 | status (bool): True on success 278 | """ 279 | self._check_kwargs(kwargs) 280 | 281 | if os.path.exists(dst_path) and not overwrite: 282 | logger = logging.getLogger(__name__) 283 | logger.error("Destination file {} already exists.".format(dst_path)) 284 | return False 285 | 286 | try: 287 | shutil.copyfile(src_path, dst_path) 288 | return True 289 | except Exception as e: 290 | logger = logging.getLogger(__name__) 291 | logger.error("Error in file copy - {}".format(str(e))) 292 | return False 293 | 294 | def _exists(self, path: str, **kwargs: Any) -> bool: 295 | self._check_kwargs(kwargs) 296 | return os.path.exists(path) 297 | 298 | def _isfile(self, path: str, **kwargs: Any) -> bool: 299 | self._check_kwargs(kwargs) 300 | return os.path.isfile(path) 301 | 302 | def _isdir(self, path: str, **kwargs: Any) -> bool: 303 | self._check_kwargs(kwargs) 304 | return os.path.isdir(path) 305 | 306 | def _ls(self, path: str, **kwargs: Any) -> List[str]: 307 | self._check_kwargs(kwargs) 308 | return os.listdir(path) 309 | 310 | def _mkdirs(self, path: str, **kwargs: Any) -> None: 311 | self._check_kwargs(kwargs) 312 | try: 313 | os.makedirs(path, exist_ok=True) 314 | except OSError as e: 315 | # EEXIST it can still happen if multiple processes are creating the dir 316 | if e.errno != errno.EEXIST: 317 | raise 318 | 319 | def _rm(self, path: str, **kwargs: Any) -> None: 320 | self._check_kwargs(kwargs) 321 | os.remove(path) 322 | 323 | 324 | class PathManager: 325 | """ 326 | A class for users to open generic paths or translate generic paths to file names. 327 | """ 328 | 329 | _PATH_HANDLERS: MutableMapping[str, PathHandler] = OrderedDict() 330 | _NATIVE_PATH_HANDLER = NativePathHandler() 331 | 332 | @staticmethod 333 | def __get_path_handler(path: str) -> PathHandler: 334 | """ 335 | Finds a PathHandler that supports the given path. Falls back to the native 336 | PathHandler if no other handler is found. 337 | Args: 338 | path (str): URI path to resource 339 | Returns: 340 | handler (PathHandler) 341 | """ 342 | for p in PathManager._PATH_HANDLERS.keys(): 343 | if path.startswith(p): 344 | return PathManager._PATH_HANDLERS[p] 345 | return PathManager._NATIVE_PATH_HANDLER 346 | 347 | @staticmethod 348 | def open( 349 | path: str, mode: str = "r", buffering: int = -1, **kwargs: Any 350 | ) -> Union[IO[str], IO[bytes]]: 351 | """ 352 | Open a stream to a URI, similar to the built-in `open`. 353 | Args: 354 | path (str): A URI supported by this PathHandler 355 | mode (str): Specifies the mode in which the file is opened. It defaults 356 | to 'r'. 357 | buffering (int): An optional integer used to set the buffering policy. 358 | Pass 0 to switch buffering off and an integer >= 1 to indicate the 359 | size in bytes of a fixed-size chunk buffer. When no buffering 360 | argument is given, the default buffering policy depends on the 361 | underlying I/O implementation. 362 | Returns: 363 | file: a file-like object. 364 | """ 365 | return PathManager.__get_path_handler(path)._open( # type: ignore 366 | path, mode, buffering=buffering, **kwargs 367 | ) 368 | 369 | @staticmethod 370 | def copy( 371 | src_path: str, dst_path: str, overwrite: bool = False, **kwargs: Any 372 | ) -> bool: 373 | """ 374 | Copies a source path to a destination path. 
375 | Args: 376 | src_path (str): A URI supported by this PathHandler 377 | dst_path (str): A URI supported by this PathHandler 378 | overwrite (bool): Bool flag for forcing overwrite of existing file 379 | Returns: 380 | status (bool): True on success 381 | """ 382 | 383 | # Copying across handlers is not supported. 384 | assert PathManager.__get_path_handler( # type: ignore 385 | src_path 386 | ) == PathManager.__get_path_handler(dst_path) 387 | return PathManager.__get_path_handler(src_path)._copy( 388 | src_path, dst_path, overwrite, **kwargs 389 | ) 390 | 391 | @staticmethod 392 | def get_local_path(path: str, **kwargs: Any) -> str: 393 | """ 394 | Get a filepath which is compatible with native Python I/O such as `open` 395 | and `os.path`. 396 | If URI points to a remote resource, this function may download and cache 397 | the resource to local disk. 398 | Args: 399 | path (str): A URI supported by this PathHandler 400 | Returns: 401 | local_path (str): a file path which exists on the local file system 402 | """ 403 | return PathManager.__get_path_handler( # type: ignore 404 | path 405 | )._get_local_path(path, **kwargs) 406 | 407 | @staticmethod 408 | def exists(path: str, **kwargs: Any) -> bool: 409 | """ 410 | Checks if there is a resource at the given URI. 411 | Args: 412 | path (str): A URI supported by this PathHandler 413 | Returns: 414 | bool: true if the path exists 415 | """ 416 | return PathManager.__get_path_handler(path)._exists( # type: ignore 417 | path, **kwargs 418 | ) 419 | 420 | @staticmethod 421 | def isfile(path: str, **kwargs: Any) -> bool: 422 | """ 423 | Checks if the resource at the given URI is a file. 424 | Args: 425 | path (str): A URI supported by this PathHandler 426 | Returns: 427 | bool: true if the path is a file 428 | """ 429 | return PathManager.__get_path_handler(path)._isfile( # type: ignore 430 | path, **kwargs 431 | ) 432 | 433 | @staticmethod 434 | def isdir(path: str, **kwargs: Any) -> bool: 435 | """ 436 | Checks if the resource at the given URI is a directory. 437 | Args: 438 | path (str): A URI supported by this PathHandler 439 | Returns: 440 | bool: true if the path is a directory 441 | """ 442 | return PathManager.__get_path_handler(path)._isdir( # type: ignore 443 | path, **kwargs 444 | ) 445 | 446 | @staticmethod 447 | def ls(path: str, **kwargs: Any) -> List[str]: 448 | """ 449 | List the contents of the directory at the provided URI. 450 | Args: 451 | path (str): A URI supported by this PathHandler 452 | Returns: 453 | List[str]: list of contents in given path 454 | """ 455 | return PathManager.__get_path_handler(path)._ls( # type: ignore 456 | path, **kwargs 457 | ) 458 | 459 | @staticmethod 460 | def mkdirs(path: str, **kwargs: Any) -> None: 461 | """ 462 | Recursive directory creation function. Like mkdir(), but makes all 463 | intermediate-level directories needed to contain the leaf directory. 464 | Similar to the native `os.makedirs`. 465 | Args: 466 | path (str): A URI supported by this PathHandler 467 | """ 468 | return PathManager.__get_path_handler(path)._mkdirs( # type: ignore 469 | path, **kwargs 470 | ) 471 | 472 | @staticmethod 473 | def rm(path: str, **kwargs: Any) -> None: 474 | """ 475 | Remove the file (not directory) at the provided URI.
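Because __get_path_handler above falls back to NativePathHandler whenever no registered prefix matches, PathManager works on ordinary filesystem paths with no setup at all. A quick sketch (POSIX-style temp path; assumes the bbcm package is importable):

import os
from bbcm.utils.file_io import PathManager

root = '/tmp/bbcm_demo'
PathManager.mkdirs(root)
with PathManager.open(os.path.join(root, 'hello.txt'), 'w') as f:
    f.write('hi')

assert PathManager.exists(os.path.join(root, 'hello.txt'))
assert PathManager.isfile(os.path.join(root, 'hello.txt'))
print(PathManager.ls(root))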
476 | Args: 477 | path (str): A URI supported by this PathHandler 478 | """ 479 | return PathManager.__get_path_handler(path)._rm( # type: ignore 480 | path, **kwargs 481 | ) 482 | 483 | @staticmethod 484 | def register_handler(handler: PathHandler) -> None: 485 | """ 486 | Register a path handler associated with `handler._get_supported_prefixes` 487 | URI prefixes. 488 | Args: 489 | handler (PathHandler) 490 | """ 491 | assert isinstance(handler, PathHandler), handler 492 | for prefix in handler._get_supported_prefixes(): 493 | assert prefix not in PathManager._PATH_HANDLERS 494 | PathManager._PATH_HANDLERS[prefix] = handler 495 | 496 | # Sort path handlers in reverse order so longer prefixes take priority, 497 | # eg: http://foo/bar before http://foo 498 | PathManager._PATH_HANDLERS = OrderedDict( 499 | sorted( 500 | PathManager._PATH_HANDLERS.items(), 501 | key=lambda t: t[0], 502 | reverse=True, 503 | ) 504 | ) 505 | 506 | @staticmethod 507 | def set_strict_kwargs_checking(enable: bool) -> None: 508 | """ 509 | Toggles strict kwargs checking. If enabled, a ValueError is thrown if any 510 | unused parameters are passed to a PathHandler function. If disabled, only 511 | a warning is given. 512 | With a centralized file API, there's a tradeoff of convenience and 513 | correctness delegating arguments to the proper I/O layers. An underlying 514 | `PathHandler` may support custom arguments which should not be statically 515 | exposed on the `PathManager` function. For example, a custom `HTTPURLHandler` 516 | may want to expose a `cache_timeout` argument for `open()` which specifies 517 | how old a locally cached resource can be before it's refetched from the 518 | remote server. This argument would not make sense for a `NativePathHandler`. 519 | If strict kwargs checking is disabled, `cache_timeout` can be passed to 520 | `PathManager.open` which will forward the arguments to the underlying 521 | handler. By default, checking is enabled since it is innately unsafe: 522 | multiple `PathHandler`s could reuse arguments with different semantic 523 | meanings or types. 
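register_handler above is the extension point for this URI dispatch. A sketch with a hypothetical handler (FakeS3Handler is illustrative only, not part of the repo) showing the minimal overrides and the effect of registration:

from typing import Any, List
from bbcm.utils.file_io import PathHandler, PathManager

class FakeS3Handler(PathHandler):
    """Hypothetical: pretends to mirror s3:// objects into a local cache."""

    def _get_supported_prefixes(self) -> List[str]:
        return ["s3://"]

    def _get_local_path(self, path: str, **kwargs: Any) -> str:
        self._check_kwargs(kwargs)
        # A real handler would download and cache the object here.
        return "/tmp/s3_cache/" + path[len("s3://"):]

PathManager.register_handler(FakeS3Handler())

# s3:// URIs now dispatch to the custom handler; plain paths still
# fall through to NativePathHandler.
print(PathManager.get_local_path('s3://bucket/model.bin'))
print(PathManager.get_local_path('/etc/hosts'))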
524 | Args: 525 | enable (bool) 526 | """ 527 | PathManager._NATIVE_PATH_HANDLER._strict_kwargs_check = enable 528 | for handler in PathManager._PATH_HANDLERS.values(): 529 | handler._strict_kwargs_check = enable 530 | 531 | 532 | def load_json(fp): 533 | if not os.path.exists(fp): 534 | return dict() 535 | 536 | with open(fp, 'r', encoding='utf8') as f: 537 | return json.load(f) 538 | 539 | 540 | def dump_json(obj, fp): 541 | try: 542 | fp = os.path.abspath(fp) 543 | if not os.path.exists(os.path.dirname(fp)): 544 | os.makedirs(os.path.dirname(fp)) 545 | with open(fp, 'w', encoding='utf8') as f: 546 | json.dump(obj, f, ensure_ascii=False, indent=4, separators=(',', ':')) 547 | print(f'json file saved successfully: {fp}') 548 | return True 549 | except Exception as e: 550 | print(f'failed to save json file {obj}: {e}') 551 | return False 552 | 553 | 554 | def get_main_dir(): 555 | # For a pyinstaller-packaged executable, locate the directory of the executable 556 | if hasattr(sys, 'frozen'): 557 | return os.path.join(os.path.dirname(sys.executable)) 558 | # Otherwise, locate the project root directory 559 | return os.path.join(os.path.dirname(__file__), '..', '..') 560 | 561 | 562 | def get_abs_path(*name): 563 | fn = os.path.join(*name) 564 | if os.path.isabs(fn): 565 | return fn 566 | return os.path.abspath(os.path.join(get_main_dir(), fn)) 567 | -------------------------------------------------------------------------------- /bbcm/utils/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Time : 2021-01-21 11:50:55 3 | @File : logger.py 4 | @Author : Abtion 5 | @Email : abtion{at}outlook.com 6 | """ 7 | import logging 8 | import os 9 | import sys 10 | 11 | 12 | def setup_logger(name, save_dir, distributed_rank): 13 | logger = logging.getLogger(name) 14 | logger.setLevel(logging.DEBUG) 15 | # don't log results for the non-master process 16 | if distributed_rank > 0: 17 | return logger 18 | ch = logging.StreamHandler(stream=sys.stdout) 19 | ch.setLevel(logging.DEBUG) 20 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 21 | ch.setFormatter(formatter) 22 | logger.addHandler(ch) 23 | 24 | if save_dir: 25 | if not os.path.exists(save_dir): 26 | os.makedirs(save_dir) 27 | fh = logging.FileHandler(os.path.join(save_dir, "log.txt"), encoding='utf8') 28 | fh.setLevel(logging.DEBUG) 29 | fh.setFormatter(formatter) 30 | logger.addHandler(fh) 31 | 32 | return logger 33 | -------------------------------------------------------------------------------- /checkpoints/.gitignore: -------------------------------------------------------------------------------- 1 | * -------------------------------------------------------------------------------- /configs/csc/train_SoftMaskedBert.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | BERT_CKPT: "bert-base-chinese" 3 | DEVICE: "cuda" 4 | NAME: "SoftMaskedBertModel" 5 | GPU_IDS: [1] 6 | # [loss_coefficient] 7 | HYPER_PARAMS: [0.8] 8 | 9 | DATASETS: 10 | TRAIN: "datasets/csc/train.json" 11 | VALID: "datasets/csc/dev.json" 12 | TEST: "datasets/csc/test.json" 13 | 14 | SOLVER: 15 | BASE_LR: 0.0001 16 | WEIGHT_DECAY: 5e-8 17 | BATCH_SIZE: 32 18 | MAX_EPOCHS: 10 19 | ACCUMULATE_GRAD_BATCHES: 4 20 | 21 | 22 | TEST: 23 | BATCH_SIZE: 16 24 | 25 | TASK: 26 | NAME: "csc" 27 | 28 | OUTPUT_DIR: "checkpoints/SoftMaskedBert" 29 | -------------------------------------------------------------------------------- /configs/csc/train_bert4csc.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 |
BERT_CKPT: "bert-base-chinese" 3 | DEVICE: "cuda" 4 | NAME: "bert4csc" 5 | # [loss_coefficient] 6 | HYPER_PARAMS: [ 0.5 ] 7 | GPU_IDS: [0] 8 | 9 | DATASETS: 10 | TRAIN: "datasets/csc/train.json" 11 | VALID: "datasets/csc/dev.json" 12 | TEST: "datasets/csc/test.json" 13 | 14 | SOLVER: 15 | BASE_LR: 1e-4 16 | WEIGHT_DECAY: 5e-8 17 | BATCH_SIZE: 32 18 | MAX_EPOCHS: 10 19 | ACCUMULATE_GRAD_BATCHES: 4 20 | 21 | DATALOADER: 22 | NUM_WORKERS: 4 23 | 24 | TEST: 25 | BATCH_SIZE: 16 26 | 27 | TASK: 28 | NAME: "csc" 29 | 30 | OUTPUT_DIR: "checkpoints/bert4csc" 31 | -------------------------------------------------------------------------------- /configs/csc/train_macbert4csc.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | BERT_CKPT: "hfl/chinese-macbert-base" 3 | DEVICE: "cuda:0" 4 | NAME: "macbert4csc" 5 | # [loss_coefficient] 6 | HYPER_PARAMS: [ 0.3 ] 7 | GPU_IDS: [ 0 ] 8 | # WEIGHTS: "epoch=6-val_loss=0.07.ckpt" 9 | 10 | DATASETS: 11 | TRAIN: "datasets/csc/train.json" 12 | VALID: "datasets/csc/dev.json" 13 | TEST: "datasets/csc/test.json" 14 | 15 | SOLVER: 16 | BASE_LR: 5e-5 17 | WEIGHT_DECAY: 0.01 18 | BATCH_SIZE: 32 19 | MAX_EPOCHS: 10 20 | ACCUMULATE_GRAD_BATCHES: 4 21 | 22 | 23 | TEST: 24 | BATCH_SIZE: 8 25 | 26 | TASK: 27 | NAME: "csc" 28 | 29 | OUTPUT_DIR: "checkpoints/macbert4csc" 30 | MODE: [ 'train', "test" ] 31 | -------------------------------------------------------------------------------- /configs/dict/white_name_list.json: -------------------------------------------------------------------------------- 1 | { 2 | "权": [ 3 | "债权", 4 | "量为权数" 5 | ], 6 | "出": [ 7 | "一次出", 8 | "存出保", 9 | "方面出" 10 | ], 11 | "止": [ 12 | "截止" 13 | ], 14 | "围": [ 15 | "围、禁止" 16 | ], 17 | "见": [ 18 | "见1", 19 | "见2", 20 | "万元见" 21 | ], 22 | "他": [ 23 | "其他基", 24 | "其他行", 25 | "其他法", 26 | "其他活", 27 | "其他业", 28 | "其他信", 29 | "其他文", 30 | "其他账", 31 | "其他受", 32 | "为他所", 33 | "其他方", 34 | "其他相", 35 | "其他事", 36 | "和其他", 37 | "的其他", 38 | "及其他" 39 | ], 40 | "指": [ 41 | "人:指", 42 | "日:指", 43 | "值:指", 44 | "金:指", 45 | "方:指", 46 | "费:指" 47 | ], 48 | "余": [ 49 | "结余", 50 | "按摊余成", 51 | "余额", 52 | "对其余", 53 | "剩余", 54 | "州市余杭", 55 | "余宝", 56 | "余洁", 57 | "余磊" 58 | ], 59 | "日": [ 60 | "交易日", 61 | "估值日", 62 | "开放日", 63 | "工作日" 64 | ], 65 | "主": [ 66 | "称为主袋", 67 | "主袋" 68 | ], 69 | "凡": [ 70 | "理,凡和" 71 | ], 72 | "甘": [ 73 | "上海甘证" 74 | ], 75 | "加": [ 76 | "任中加基", 77 | "厚利加混", 78 | "告增加次", 79 | "刘加海" 80 | ], 81 | "力": [ 82 | "王乃力", 83 | "等权力", 84 | "的权力", 85 | "使权力" 86 | ], 87 | "覆": [ 88 | "覆盖" 89 | ], 90 | "民": [ 91 | "李民吉", 92 | "缪建民" 93 | ], 94 | "音": [ 95 | "音像", 96 | "备录音" 97 | ], 98 | "及": [ 99 | "后应及", 100 | "方式及有", 101 | "自己及任", 102 | "理人及基", 103 | "披露及报", 104 | "国家及监" 105 | ], 106 | "需": [ 107 | "要求需向", 108 | "需按", 109 | "需依", 110 | "需遵", 111 | "需承", 112 | "需签", 113 | "机构需在", 114 | "理人需取", 115 | "需符" 116 | ], 117 | "通": [ 118 | "海市通力", 119 | "但对通", 120 | "银行通过" 121 | ], 122 | "俞": [ 123 | "俞卫", 124 | "俞洋" 125 | ], 126 | "锋": [ 127 | "俞卫锋", 128 | "谭广锋", 129 | "董一锋", 130 | "于海锋", 131 | "李剑锋" 132 | ], 133 | "睿": [ 134 | "孙睿", 135 | "长信睿进", 136 | "汇金睿选" 137 | ], 138 | "红": [ 139 | "吕红" 140 | ], 141 | "换": [ 142 | "转换" 143 | ], 144 | "接": [ 145 | "下,接" 146 | ], 147 | "属": [ 148 | "各类属信" 149 | ], 150 | "收": [ 151 | "收益", 152 | "接收", 153 | "收到", 154 | "收付", 155 | "公司收益" 156 | ], 157 | "金": [ 158 | "金额", 159 | "基金", 160 | "型基金中", 161 | "黄金琳", 162 | "马金" 163 | ], 164 | "伟": [ 165 | "杜伟", 166 | "张跃伟", 167 | "章伟东", 168 | "陈达伟" 169 | ], 170 | "作": [ 171 | "不作为", 172 | "部分作自", 
173 | "也不作为", 174 | "程序作出" 175 | ], 176 | "终": [ 177 | "终止" 178 | ], 179 | "袋": [ 180 | "侧袋账", 181 | "按主袋账", 182 | "以主袋账" 183 | ], 184 | "交": [ 185 | "交纳" 186 | ], 187 | "寄": [ 188 | "寄交" 189 | ], 190 | "拒": [ 191 | "拒派" 192 | ], 193 | "与": [ 194 | "生效与公" 195 | ], 196 | "复": [ 197 | "银复", 198 | "答复" 199 | ], 200 | "转": [ 201 | "转人工" 202 | ], 203 | "阳": [ 204 | "海长阳路", 205 | "区星阳街" 206 | ], 207 | "观": [ 208 | "浙江观合" 209 | ], 210 | "基": [ 211 | "证监基字", 212 | "大成基金", 213 | "照《基", 214 | "后,基金", 215 | "公司基金", 216 | "通知基" 217 | ], 218 | "份": [ 219 | "每份基", 220 | "算的份额", 221 | "数成份股", 222 | "于每份基" 223 | ], 224 | "题": [ 225 | "造主题", 226 | "车主题", 227 | "创主题" 228 | ], 229 | "扼": [ 230 | "扼制" 231 | ], 232 | "部": [ 233 | "研究部", 234 | "业务部", 235 | "纯债部", 236 | "投资部", 237 | "合规部", 238 | "控制部" 239 | ], 240 | "体": [ 241 | "化、体" 242 | ], 243 | "末": [ 244 | "期末" 245 | ], 246 | "值": [ 247 | "定估值" 248 | ], 249 | "至": [ 250 | "截至日", 251 | "截至基", 252 | "通知至" 253 | ], 254 | "信": [ 255 | "的资信控", 256 | "信用", 257 | "长信稳", 258 | "择资信状" 259 | ], 260 | "嘉": [ 261 | "指百嘉基" 262 | ], 263 | "业": [ 264 | "业务", 265 | "国农业银" 266 | ], 267 | "务": [ 268 | "业务部" 269 | ], 270 | "总": [ 271 | "总监" 272 | ], 273 | "监": [ 274 | "部总监" 275 | ], 276 | "大": [ 277 | "大成", 278 | "林保大基" 279 | ], 280 | "成": [ 281 | "大成基", 282 | "汇成基" 283 | ], 284 | "管": [ 285 | "基金管理" 286 | ], 287 | "理": [ 288 | "金管理有" 289 | ], 290 | "有": [ 291 | "管理有限", 292 | "基金有面", 293 | "产或有其" 294 | ], 295 | "公": [ 296 | "有限公司" 297 | ], 298 | "首": [ 299 | "公司首席" 300 | ], 301 | "席": [ 302 | "首席" 303 | ], 304 | "莹": [ 305 | "郑南莹女" 306 | ], 307 | "史": [ 308 | "会计史专" 309 | ], 310 | "支": [ 311 | "30支", 312 | "8支银行理财" 313 | ], 314 | "琳": [ 315 | "古琳花" 316 | ], 317 | "寅": [ 318 | "高寅初" 319 | ], 320 | "柏": [ 321 | "邓柏涛" 322 | ], 323 | "兵": [ 324 | "欧阳兵", 325 | "王兵", 326 | "戎兵" 327 | ], 328 | "全": [ 329 | "全价", 330 | "全债" 331 | ], 332 | "侧": [ 333 | "侧袋" 334 | ], 335 | "电": [ 336 | "基金电" 337 | ], 338 | "恬": [ 339 | "上海恬淡" 340 | ], 341 | "淡": [ 342 | "海恬淡资" 343 | ], 344 | "淳": [ 345 | "曾任淳大", 346 | "淳厚" 347 | ], 348 | "上": [ 349 | "上会", 350 | "媒介上公" 351 | ], 352 | "炜": [ 353 | "江嘉炜" 354 | ], 355 | "天": [ 356 | "7天" 357 | ], 358 | "行": [ 359 | "行权", 360 | "银行", 361 | "价值行" 362 | ], 363 | "长": [ 364 | "长待", 365 | "长信" 366 | ], 367 | "槛": [ 368 | "资门槛" 369 | ], 370 | "时": [ 371 | "博时" 372 | ], 373 | "江": [ 374 | "江向", 375 | "浙江", 376 | "江恩" 377 | ], 378 | "后": [ 379 | "投后管" 380 | ], 381 | "鑫": [ 382 | "利华鑫基", 383 | "韩鑫普", 384 | "博远鑫享", 385 | "信利鑫债", 386 | "在华鑫证" 387 | ], 388 | "季": [ 389 | "按季向", 390 | "季平" 391 | ], 392 | "置": [ 393 | "同时置备", 394 | "产配置混" 395 | ], 396 | "因": [ 397 | "因向", 398 | "认为因" 399 | ], 400 | "息": [ 401 | "息税" 402 | ], 403 | "久": [ 404 | "久期" 405 | ], 406 | "轧": [ 407 | "轧差", 408 | "摩根轧机" 409 | ], 410 | "在": [ 411 | "应对在投", 412 | "时,在确", 413 | "有人在按", 414 | "人将在启", 415 | "可以在报", 416 | "值,在", 417 | "理人在履", 418 | "标,在加", 419 | "理人在代" 420 | ], 421 | "征": [ 422 | "交易征费" 423 | ], 424 | "鉴": [ 425 | "留印鉴由" 426 | ], 427 | "录": [ 428 | "登录" 429 | ], 430 | "头": [ 431 | "空头", 432 | "头部" 433 | ], 434 | "曜": [ 435 | "类承曜先" 436 | ], 437 | "恋": [ 438 | "刘恋" 439 | ], 440 | "楠": [ 441 | "石楠", 442 | "张楠", 443 | "秦一楠", 444 | "张冠楠" 445 | ], 446 | "返": [ 447 | "买入返" 448 | ], 449 | "冲": [ 450 | "可冲抵" 451 | ], 452 | "含": [ 453 | "含权", 454 | "如下含义" 455 | ], 456 | "待": [ 457 | "待偿" 458 | ], 459 | "记": [ 460 | "同一记" 461 | ], 462 | "样": [ 463 | "选样" 464 | ], 465 | "应": [ 466 | "反应" 467 | ], 468 | "视": [ 469 | "货盯视结" 470 | ], 471 | "一": [ 472 | "当前一估", 473 | "道文一西", 474 | "天津一德", 475 | 
"一路", 476 | "方纯一大", 477 | "一经" 478 | ], 479 | "服": [ 480 | "客服部" 481 | ], 482 | "振": [ 483 | "威华振会", 484 | "金合振投" 485 | ], 486 | "级": [ 487 | "AA级", 488 | "评级", 489 | "A+级" 490 | ], 491 | "的": [ 492 | "A级的信", 493 | "资产的2", 494 | "+级的信", 495 | "资产的5", 496 | "总监的安", 497 | "显著的表" 498 | ], 499 | "用": [ 500 | "信用" 501 | ], 502 | "债": [ 503 | "信用债比", 504 | "信用债资" 505 | ], 506 | "比": [ 507 | "用债比例" 508 | ], 509 | "例": [ 510 | "比例" 511 | ], 512 | "合": [ 513 | "合计", 514 | "混合型", 515 | "合计", 516 | "基金合" 517 | ], 518 | "计": [ 519 | "合计" 520 | ], 521 | "不": [ 522 | "合计不超" 523 | ], 524 | "超": [ 525 | "不超过" 526 | ], 527 | "过": [ 528 | "不超过", 529 | "通过了" 530 | ], 531 | "资": [ 532 | "资产", 533 | "投资" 534 | ], 535 | "产": [ 536 | "资产" 537 | ], 538 | "投": [ 539 | "投资", 540 | "生时投" 541 | ], 542 | "于": [ 543 | "投资于", 544 | "公告于", 545 | "不能于" 546 | ], 547 | "评": [ 548 | "信用评级" 549 | ], 550 | "为": [ 551 | "评级为A" 552 | ], 553 | "做": [ 554 | "做出", 555 | "做套" 556 | ], 557 | "研": [ 558 | "身投研优" 559 | ], 560 | "就": [ 561 | "发生就基" 562 | ], 563 | "共": [ 564 | "方,共" 565 | ], 566 | "静": [ 567 | "陈静满", 568 | "胡静华" 569 | ], 570 | "志": [ 571 | "王连志", 572 | "麻众志", 573 | "王连志", 574 | "杨志涌" 575 | ], 576 | "领": [ 577 | "领先" 578 | ], 579 | "乾": [ 580 | "安信乾盛", 581 | "安信乾宏", 582 | "金在乾道" 583 | ], 584 | "直": [ 585 | "直连" 586 | ], 587 | "鞍": [ 588 | "毛鞍宁" 589 | ], 590 | "当": [ 591 | "当日", 592 | "不得当用" 593 | ], 594 | "汇": [ 595 | "经汇" 596 | ], 597 | "持": [ 598 | "佑瑞持债" 599 | ], 600 | "适": [ 601 | "适用" 602 | ], 603 | "集": [ 604 | "集合" 605 | ], 606 | "目": [ 607 | "付,目", 608 | "投资项目", 609 | "目1个" 610 | ], 611 | "人": [ 612 | "管理人不" 613 | ], 614 | "进": [ 615 | "行为进", 616 | "信稳进资" 617 | ], 618 | "届": [ 619 | "据其届" 620 | ], 621 | "事": [ 622 | "通知事", 623 | "新增事" 624 | ], 625 | "燕": [ 626 | "王海燕", 627 | "李海燕" 628 | ], 629 | "壁": [ 630 | "花照壁西" 631 | ], 632 | "得": [ 633 | "号万得大", 634 | "海万得基" 635 | ], 636 | "尾": [ 637 | "能有尾差" 638 | ], 639 | "辐": [ 640 | "生大辐波" 641 | ], 642 | "和": [ 643 | "部分和", 644 | "海国和现", 645 | "海泰和经", 646 | "人:和志" 647 | ], 648 | "请": [ 649 | "提请投" 650 | ], 651 | "弛": [ 652 | "张弛先" 653 | ], 654 | "衍": [ 655 | "深圳衍界" 656 | ], 657 | "尽": [ 658 | "涉税尽职" 659 | ], 660 | "即": [ 661 | "事人即权" 662 | ], 663 | "圈": [ 664 | "经济圈成" 665 | ], 666 | "严": [ 667 | "严峰", 668 | "严娅", 669 | "严亦" 670 | ], 671 | "晚": [ 672 | "在不晚" 673 | ], 674 | "且": [ 675 | "交易且港" 676 | ], 677 | "中": [ 678 | "交易中", 679 | "基金中", 680 | "中国" 681 | ], 682 | "申": [ 683 | "交易申请" 684 | ], 685 | "兴": [ 686 | "信稳兴", 687 | "李兴春", 688 | "金合兴投" 689 | ], 690 | "胜": [ 691 | "海彤胜投" 692 | ], 693 | "男": [ 694 | "事长男", 695 | "监事男", 696 | "经理男" 697 | ], 698 | "晓": [ 699 | "公任晓威", 700 | "李晓皙" 701 | ], 702 | "师": [ 703 | "管理师" 704 | ], 705 | "营": [ 706 | "立董营财", 707 | "高振营", 708 | "主营产" 709 | ], 710 | "实": [ 711 | "地上实产", 712 | "乐实" 713 | ], 714 | "女": [ 715 | "监事女三" 716 | ], 717 | "战": [ 718 | "公司战" 719 | ], 720 | "家": [ 721 | "上海家宝" 722 | ], 723 | "稳": [ 724 | "长信稳进" 725 | ], 726 | "配": [ 727 | "资产配置" 728 | ], 729 | "混": [ 730 | "配置混合" 731 | ], 732 | "型": [ 733 | "混合型" 734 | ], 735 | "新": [ 736 | "长信新利", 737 | "冯恩新" 738 | ], 739 | "彬": [ 740 | "尹彬彬", 741 | "陈祎彬" 742 | ], 743 | "宫": [ 744 | "王宫" 745 | ], 746 | "意": [ 747 | "沈茹意联", 748 | "沈茹意" 749 | ], 750 | "受": [ 751 | "认接受的", 752 | "时不受" 753 | ], 754 | "小": [ 755 | "减小基", 756 | "葛小波", 757 | "吴小静" 758 | ], 759 | "手": [ 760 | "更改手" 761 | ], 762 | "情": [ 763 | "人的情" 764 | ], 765 | "坐": [ 766 | "人工坐席" 767 | ], 768 | "年": [ 769 | "年中" 770 | ], 771 | "国": [ 772 | "中国", 773 | "美国" 774 | ], 775 | "农": [ 776 | "农业" 777 | ], 778 | "银": [ 
779 | "银行" 780 | ], 781 | "了": [ 782 | "通过了" 783 | ], 784 | "美": [ 785 | "美国" 786 | ], 787 | "珺": [ 788 | "王珺" 789 | ], 790 | "名": [ 791 | "东大名路", 792 | "前十名" 793 | ], 794 | "序": [ 795 | "当程序后" 796 | ], 797 | "以": [ 798 | "以双", 799 | "的,以基", 800 | "司(以" 801 | ], 802 | "米": [ 803 | "增盈米基", 804 | "加盈米财" 805 | ], 806 | "远": [ 807 | "道卓远混" 808 | ], 809 | "舍": [ 810 | "道如舍投" 811 | ], 812 | "其": [ 813 | "其实" 814 | ], 815 | "洋": [ 816 | "屠彦洋" 817 | ], 818 | "云": [ 819 | "冉云" 820 | ], 821 | "浙": [ 822 | "监会浙江" 823 | ], 824 | "同": [ 825 | "牌的同" 826 | ], 827 | "现": [ 828 | "现价", 829 | "现基" 830 | ], 831 | "慧": [ 832 | "丁慧", 833 | "张佳慧", 834 | "郑慧", 835 | "许慧琳", 836 | "骆文慧" 837 | ], 838 | "源": [ 839 | "高源", 840 | "源数", 841 | "马源" 842 | ], 843 | "兰": [ 844 | "兰显" 845 | ], 846 | "佳": [ 847 | "张佳慧", 848 | "时佳达" 849 | ], 850 | "凸": [ 851 | "凸度" 852 | ], 853 | "卖": [ 854 | "的超卖或" 855 | ], 856 | "周": [ 857 | "每周", 858 | "按周" 859 | ], 860 | "生": [ 861 | "合同生", 862 | "刘玉生任", 863 | "杨明生" 864 | ], 865 | "渤": [ 866 | "加入渤" 867 | ], 868 | "晗": [ 869 | "史亦晗女" 870 | ], 871 | "巾": [ 872 | "常卜巾" 873 | ], 874 | "阈": [ 875 | "风控阈值" 876 | ], 877 | "昱": [ 878 | "吴昱", 879 | "张昱" 880 | ], 881 | "溢": [ 882 | "立溢股" 883 | ], 884 | "欣": [ 885 | "黄欣", 886 | "李欣", 887 | "陆欣" 888 | ], 889 | "丽": [ 890 | "巩巧丽", 891 | "熊丽" 892 | ], 893 | "昊": [ 894 | "钱昊旻", 895 | "路昊" 896 | ], 897 | "健": [ 898 | "申健", 899 | "王健", 900 | "叶健", 901 | "梁健佩" 902 | ], 903 | "军": [ 904 | "王军辉" 905 | ], 906 | "粟": [ 907 | "粟旭" 908 | ], 909 | "舰": [ 910 | "王舰正" 911 | ], 912 | "明": [ 913 | "%)明" 914 | ], 915 | "对": [ 916 | "港币对人" 917 | ], 918 | "座": [ 919 | "人工座席" 920 | ], 921 | "享": [ 922 | "海陆享基" 923 | ], 924 | "诚": [ 925 | "停泰诚财" 926 | ], 927 | "相": [ 928 | "于天相投" 929 | ], 930 | "杨": [ 931 | "郑杨", 932 | "赵杨", 933 | "周杨" 934 | ], 935 | "技": [ 936 | "胡技勋" 937 | ], 938 | "政": [ 939 | "辛国政" 940 | ], 941 | "耀": [ 942 | "何耀" 943 | ], 944 | "微": [ 945 | "梁微" 946 | ], 947 | "地": [ 948 | "新天地大" 949 | ], 950 | "虹": [ 951 | "王虹", 952 | "龙江虹" 953 | ], 954 | "王": [ 955 | "王达", 956 | "王连" 957 | ], 958 | "峰": [ 959 | "薛峰", 960 | "张峰" 961 | ], 962 | "卷": [ 963 | "京蛋卷基" 964 | ], 965 | "度": [ 966 | "度小" 967 | ], 968 | "正": [ 969 | "海中正达", 970 | "诺亚正行" 971 | ], 972 | "植": [ 973 | "北京植信" 974 | ], 975 | "赢": [ 976 | "海攀赢基" 977 | ], 978 | "艳": [ 979 | "周艳琼" 980 | ], 981 | "叁": [ 982 | "博道叁佰" 983 | ], 984 | "甬": [ 985 | "定(甬银" 986 | ], 987 | "备": [ 988 | "监会备" 989 | ], 990 | "章": [ 991 | "章知", 992 | "章宏", 993 | "章伟" 994 | ], 995 | "轩": [ 996 | "廖庆轩" 997 | ], 998 | "承": [ 999 | "吴承根" 1000 | ], 1001 | "韬": [ 1002 | "章宏韬" 1003 | ], 1004 | "证": [ 1005 | "弄证大" 1006 | ], 1007 | "冬": [ 1008 | "安冬" 1009 | ], 1010 | "起": [ 1011 | "起至" 1012 | ], 1013 | "陈": [ 1014 | "由黄陈代" 1015 | ], 1016 | "风": [ 1017 | "加天风证" 1018 | ], 1019 | "量": [ 1020 | "加长量基" 1021 | ], 1022 | "奕": [ 1023 | "金在奕丰" 1024 | ], 1025 | "浦": [ 1026 | "金在浦领" 1027 | ], 1028 | "卫": [ 1029 | "吕春卫", 1030 | "吴卫卫" 1031 | ], 1032 | "张": [ 1033 | "张海" 1034 | ], 1035 | "淮": [ 1036 | "毛淮平" 1037 | ], 1038 | "芯": [ 1039 | "赵芯蕊" 1040 | ], 1041 | "贲": [ 1042 | "贲惠" 1043 | ], 1044 | "期": [ 1045 | "额当期分" 1046 | ], 1047 | "盯": [ 1048 | "期货盯视" 1049 | ], 1050 | "此": [ 1051 | "以此类" 1052 | ], 1053 | "办": [ 1054 | "延期办" 1055 | ], 1056 | "性": [ 1057 | "提示性" 1058 | ], 1059 | "冰": [ 1060 | "周慕冰" 1061 | ], 1062 | "放": [ 1063 | "京京放投" 1064 | ], 1065 | "岷": [ 1066 | "黄越岷先" 1067 | ], 1068 | "休": [ 1069 | "政府休改" 1070 | ], 1071 | "智": [ 1072 | "产业智选" 1073 | ], 1074 | "察": [ 1075 | "对监察稽" 1076 | ], 1077 | "类": [ 1078 | "C类" 1079 | ], 1080 | "采": [ 1081 | 
"析。采" 1082 | ], 1083 | "唐": [ 1084 | "路海唐商" 1085 | ], 1086 | "镇": [ 1087 | "李镇西" 1088 | ], 1089 | "轶": [ 1090 | "冯轶明" 1091 | ], 1092 | "可": [ 1093 | "城区可园" 1094 | ], 1095 | "岗": [ 1096 | "岗厦" 1097 | ], 1098 | "剑": [ 1099 | "齐剑辉", 1100 | "李剑锋", 1101 | "周剑秋" 1102 | ], 1103 | "辉": [ 1104 | "齐剑辉" 1105 | ], 1106 | "才": [ 1107 | "人:才殿" 1108 | ], 1109 | "瑀": [ 1110 | "平国瑀资" 1111 | ], 1112 | "星": [ 1113 | "王星月" 1114 | ], 1115 | "波": [ 1116 | "梁云波", 1117 | "孙燕波", 1118 | "赵洪波", 1119 | "市场波" 1120 | ], 1121 | "峥": [ 1122 | "徐海峥" 1123 | ], 1124 | "元": [ 1125 | "蚂蚁元空" 1126 | ], 1127 | "林": [ 1128 | "张林" 1129 | ], 1130 | "宣": [ 1131 | "董宣" 1132 | ], 1133 | "利": [ 1134 | "王利刚" 1135 | ], 1136 | "雯": [ 1137 | "周雯", 1138 | "王雯" 1139 | ], 1140 | "盛": [ 1141 | "姚盛盛" 1142 | ], 1143 | "皙": [ 1144 | "李晓皙" 1145 | ], 1146 | "洁": [ 1147 | "马梦洁", 1148 | "虞洁" 1149 | ], 1150 | "滨": [ 1151 | "哈尔滨市" 1152 | ], 1153 | "境": [ 1154 | "胡芷境" 1155 | ], 1156 | "园": [ 1157 | "许梦园" 1158 | ], 1159 | "路": [ 1160 | "纪路", 1161 | "汇路存" 1162 | ], 1163 | "翊": [ 1164 | "海瀛翊投" 1165 | ], 1166 | "博": [ 1167 | "孙博文" 1168 | ], 1169 | "晖": [ 1170 | "刘晖" 1171 | ], 1172 | "竞": [ 1173 | "张竞妍" 1174 | ], 1175 | "妍": [ 1176 | "张竞妍" 1177 | ], 1178 | "崟": [ 1179 | "吕祥崟" 1180 | ], 1181 | "市": [ 1182 | "活跃市" 1183 | ], 1184 | "向": [ 1185 | "能力向好" 1186 | ], 1187 | "统": [ 1188 | "华统股" 1189 | ], 1190 | "泺": [ 1191 | "下区泺源" 1192 | ], 1193 | "菲": [ 1194 | "潘菲" 1195 | ], 1196 | "越": [ 1197 | "越权" 1198 | ], 1199 | "水": [ 1200 | "笔高水位" 1201 | ], 1202 | "须": [ 1203 | "还须支" 1204 | ], 1205 | "洼": [ 1206 | "洼地" 1207 | ] 1208 | } -------------------------------------------------------------------------------- /datasets/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gitabtion/BertBasedCorrectionModels/f5e1bacf4f4f9b4a9ca434e9d40249285172a06a/datasets/.gitignore -------------------------------------------------------------------------------- /datasets/csc/.gitignore: -------------------------------------------------------------------------------- 1 | * -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | OpenCC>=1.1.0,<=1.1.1 2 | pytorch-lightning==1.1.2 3 | six==1.14.0 4 | tensorboard==2.4.0 5 | tensorboard-plugin-wit==1.7.0 6 | threadpoolctl==2.1.0 7 | tokenizers==0.9.4 8 | torch==1.7.0 9 | transformers==4.1.1 10 | yacs 11 | lxml 12 | -------------------------------------------------------------------------------- /tests/test_embedding.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import transformers as tfs 4 | 5 | 6 | class TestEmbedding(unittest.TestCase): 7 | def test_embedding(self): 8 | model = tfs.AutoModel.from_pretrained('bert-base-chinese') 9 | tokenizer = tfs.AutoTokenizer.from_pretrained('bert-base-chinese') 10 | -------------------------------------------------------------------------------- /tools/bases.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Time : 2021-01-21 11:17:25 3 | @File : bases.py 4 | @Author : Abtion 5 | @Email : abtion{at}outlook.com 6 | """ 7 | import argparse 8 | import os 9 | 10 | import logging 11 | 12 | import torch 13 | from pytorch_lightning.callbacks import ModelCheckpoint 14 | 15 | from bbcm.utils import get_abs_path 16 | from bbcm.utils.logger import setup_logger 17 | from bbcm.config import cfg 18 | 
import pytorch_lightning as pl


def args_parse(config_file=''):
    parser = argparse.ArgumentParser(description="bbcm")
    parser.add_argument(
        "--config_file", default="", help="path to config file", type=str
    )
    parser.add_argument("--opts", help="Modify config options using the command-line key value", default=[],
                        nargs=argparse.REMAINDER)

    args = parser.parse_args()

    config_file = args.config_file or config_file

    if config_file != "":
        cfg.merge_from_file(get_abs_path('configs', config_file))
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    name = cfg.MODEL.NAME

    output_dir = cfg.OUTPUT_DIR

    logger = setup_logger(name, get_abs_path(output_dir), 0)
    logger.info(args)

    if config_file != '':
        logger.info("Loaded configuration file {}".format(config_file))
        with open(get_abs_path('configs', config_file), 'r') as cf:
            config_str = "\n" + cf.read()
            logger.info(config_str)

    logger.info("Running with config:\n{}".format(cfg))
    return cfg


def train(config, model, loaders, ckpt_callback=None):
    """
    Run training (and, depending on config.MODE, testing).
    Args:
        config: the configuration object
        model: the model to train
        loaders: the data loaders, containing train, valid and test
        ckpt_callback: checkpoint callback; if None, Lightning falls back to saving a checkpoint every epoch.
    Returns:
        None
    """
    train_loader, valid_loader, test_loader = loaders
    trainer = pl.Trainer(max_epochs=config.SOLVER.MAX_EPOCHS,
                         gpus=None if config.MODEL.DEVICE == 'cpu' else config.MODEL.GPU_IDS,
                         accumulate_grad_batches=config.SOLVER.ACCUMULATE_GRAD_BATCHES,
                         # only pass a callback list when a callback was actually provided
                         callbacks=[ckpt_callback] if ckpt_callback is not None else None)
    # Training runs only when all of the following hold:
    # 1. the config asks for training,
    # 2. train_loader is not None,
    # 3. train_loader is non-empty.
    if 'train' in config.MODE and train_loader and len(train_loader) > 0:
        if valid_loader and len(valid_loader) > 0:
            trainer.fit(model, train_loader, valid_loader)
        else:
            trainer.fit(model, train_loader)
    # Testing follows the same conditions as training.
    if 'test' in config.MODE and test_loader and len(test_loader) > 0:
        if ckpt_callback and len(ckpt_callback.best_model_path) > 0:
            ckpt_path = ckpt_callback.best_model_path
        elif len(config.MODEL.WEIGHTS) > 0:
            ckpt_path = get_abs_path(config.OUTPUT_DIR, config.MODEL.WEIGHTS)
        else:
            ckpt_path = None
        print(ckpt_path)
        if (ckpt_path is not None) and os.path.exists(ckpt_path):
            model.load_state_dict(torch.load(ckpt_path)['state_dict'])
        trainer.test(model, test_loader)
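Since args_parse forwards the --opts remainder to yacs' merge_from_list, every key in the YAML configs above can be overridden from the command line without editing the file. A self-contained sketch of that mechanic, using a throwaway CfgNode instead of the project's real bbcm.config.cfg and illustrative values:

from yacs.config import CfgNode as CN

cfg = CN()
cfg.SOLVER = CN()
cfg.SOLVER.BASE_LR = 1e-4
cfg.SOLVER.MAX_EPOCHS = 10

# equivalent to launching with: --opts SOLVER.BASE_LR 5e-5 SOLVER.MAX_EPOCHS 20
cfg.merge_from_list(["SOLVER.BASE_LR", 5e-5, "SOLVER.MAX_EPOCHS", 20])
cfg.freeze()  # freeze so any later accidental mutation raises
print(cfg.SOLVER.BASE_LR)  # 5e-05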
-------------------------------------------------------------------------------- /tools/convert_to_pure_state_dict.py: --------------------------------------------------------------------------------

"""
@Time   : 2021-02-03 20:58:46
@File   : convert_to_pure_state_dict.py
@Author : Abtion
@Email  : abtion{at}outlook.com
"""
import sys
import argparse
import os
import torch
from collections import OrderedDict

sys.path.append('..')
from bbcm.utils import get_abs_path


def convert(fn, model_name):
    """
    Pull the plain state_dict out of a saved ckpt file so the weights can be reused elsewhere.
    Args:
        fn: file name of the ckpt
        model_name: model name; must match the one in the yml config.

    Returns:
        None
    """
    file_dir = get_abs_path("checkpoints", model_name)
    state_dict = torch.load(os.path.join(file_dir, fn))['state_dict']
    new_state_dict = OrderedDict()
    if model_name in ['bert4csc', 'macbert4csc']:
        for k, v in state_dict.items():
            # drop the first five characters of every key, i.e. the wrapping module prefix
            new_state_dict[k[5:]] = v
    else:
        new_state_dict = state_dict
    torch.save(new_state_dict, os.path.join(file_dir, 'pytorch_model.bin'))


def parse_args():
    parser = argparse.ArgumentParser(description="fast-bbdl")
    parser.add_argument(
        "--ckpt_fn", default="", help="checkpoint file name", type=str
    )
    parser.add_argument(
        "--model_name", default="bert4csc", help="model name, candidates: bert4csc, macbert4csc, SoftMaskedBert", type=str
    )

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    arguments = parse_args()
    convert(arguments.ckpt_fn, arguments.model_name)
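The converted pytorch_model.bin is a plain state_dict with no Lightning wrapper, so it can be inspected or loaded with torch alone. A minimal sketch, assuming a macbert4csc checkpoint has already been converted into checkpoints/macbert4csc/:

import torch

# map_location='cpu' keeps the inspection independent of available GPUs
state_dict = torch.load('checkpoints/macbert4csc/pytorch_model.bin', map_location='cpu')
print(list(state_dict.keys())[:5])  # the de-prefixed parameter names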
-------------------------------------------------------------------------------- /tools/inference.py: --------------------------------------------------------------------------------

"""
@Time   : 2021-02-05 15:33:55
@File   : inference.py
@Author : Abtion
@Email  : abtion{at}outlook.com
"""
import sys
import argparse
import os
import torch
from transformers import BertTokenizer

sys.path.append('..')
from bbcm.modeling.csc import BertForCsc, SoftMaskedBertModel
from bbcm.utils import get_abs_path
import json
import codecs
import re


def parse_args():
    parser = argparse.ArgumentParser(description="bbcm")
    parser.add_argument(
        "--config_file", default="csc/train_bert4csc.yml", help="config file", type=str
    )
    parser.add_argument(
        "--ckpt_fn", default="epoch=2-val_loss=0.02.ckpt", help="checkpoint file name", type=str
    )
    parser.add_argument("--texts", default=["马上要过年了,提前祝大家心年快乐!"], nargs=argparse.REMAINDER)
    parser.add_argument("--text_file", default='')

    args = parser.parse_args()
    return args


def load_model_directly(ckpt_file, config_file):
    # Example:
    # ckpt_file = 'SoftMaskedBert/epoch=02-val_loss=0.02904.ckpt' (found under checkpoints)
    # config_file = 'csc/train_SoftMaskedBert.yml' (found under configs)
    from bbcm.config import cfg
    cp = get_abs_path('checkpoints', ckpt_file)
    cfg.merge_from_file(get_abs_path('configs', config_file))
    tokenizer = BertTokenizer.from_pretrained(cfg.MODEL.BERT_CKPT)

    if cfg.MODEL.NAME in ['bert4csc', 'macbert4csc']:
        model = BertForCsc.load_from_checkpoint(cp,
                                                cfg=cfg,
                                                tokenizer=tokenizer)
    else:
        model = SoftMaskedBertModel.load_from_checkpoint(cp,
                                                         cfg=cfg,
                                                         tokenizer=tokenizer)
    model.eval()
    model.to(cfg.MODEL.DEVICE)
    return model


def load_model(args):
    from bbcm.config import cfg
    cfg.merge_from_file(get_abs_path('configs', args.config_file))
    tokenizer = BertTokenizer.from_pretrained(cfg.MODEL.BERT_CKPT)
    file_dir = get_abs_path("checkpoints", cfg.MODEL.NAME)
    if cfg.MODEL.NAME in ['bert4csc', 'macbert4csc']:
        model = BertForCsc.load_from_checkpoint(os.path.join(file_dir, args.ckpt_fn),
                                                cfg=cfg,
                                                tokenizer=tokenizer)
    else:
        model = SoftMaskedBertModel.load_from_checkpoint(os.path.join(file_dir, args.ckpt_fn),
                                                         cfg=cfg,
                                                         tokenizer=tokenizer)
    model.eval()
    model.to(cfg.MODEL.DEVICE)

    return model


def inference(args):
    model = load_model(args)
    texts = []
    if os.path.exists(args.text_file):
        with open(args.text_file, 'r', encoding='utf-8') as f:
            for line in f:
                texts.append(line.strip())
    else:
        texts = args.texts
    print("Input source texts: {}".format(texts))
    corrected_texts = model.predict(texts)  # input is a list and output is a list
    print("Corrected texts from the model: {}".format(corrected_texts))
    # post-process the raw model output
    corrected_info = output_result(corrected_texts, sources=texts)
    print("Correction fragment info: {}".format(corrected_info))
    return corrected_texts


def parse_args_test():
    parser = argparse.ArgumentParser(description="bbcm")
    parser.add_argument(
        "--config_file", default="csc/train_SoftMaskedBert.yml", help="config file", type=str
    )
    parser.add_argument(
        "--ckpt_fn", default="epoch=09-val_loss=0.03032.ckpt", help="checkpoint file name", type=str
    )
    args = parser.parse_args()
    return args


def inference_test(texts):
    """input is a list of texts"""
    # build the inference arguments
    args = parse_args_test()
    # load the model weights
    model = load_model(args)
    # print("Input source texts: {}".format(texts))
    corrected_texts = model.predict(texts)  # input is a list and output is a list
    # print("Corrected texts from the model: {}".format(corrected_texts))
    # post-process the raw model output
    corrected_info = output_result(corrected_texts, sources=texts)
    # print("Correction fragment info: {}".format(corrected_info))
    return corrected_texts, corrected_info


def load_json(filename, encoding="utf-8"):
    """Load a json file."""
    if not os.path.exists(filename):
        return None
    with codecs.open(filename, mode='r', encoding=encoding) as fr:
        return json.load(fr)


# Preload the white list; curate it for the target application scenario before shipping it with this inference code.
white_dict = load_json("../configs/dict/white_name_list.json")  # mind the relative path, otherwise white_dict is None
# compiled pattern matching runs of Chinese characters
re_han = re.compile("[\u4E00-\u9FA5]+")


def load_white_dict():
    default_lens = 4  # fallback span length for configured over-correction fragments; adjust as needed
    lens_list = list()
    if not white_dict:
        # the dictionary failed to load, so the filter degrades to a no-op
        return {}, default_lens
    for src in white_dict.keys():
        for name in white_dict[src]:
            lens_list.append(len(name))
    max_lens = max(lens_list) if lens_list else default_lens
    return white_dict, max_lens


def output_result(results, sources):
    """
    Wrap the corrections into the output format.
    :param results: list of corrected texts from the model
    :param sources: list of input texts
    :return: list of dicts, one per sentence
    """
    default_data = [
        {
            "src_sentence": "",
            "tgt_sentence": "",
            "fragments": []
        }
    ]
    if not results:
        return default_data
    data = []
    # each result produces one dict
    for idx, result in enumerate(results):
        # the source text
        source = sources[idx]
        # locate the positions where source and result differ
        fragments_lst = generate_diff_info(source, result)
        dict_res = {
            "src_sentence": source,
            "tgt_sentence": result,
            "fragments": fragments_lst
        }
        data.append(dict_res)
    return data


def generate_diff_info(source, result):
    """
    Build the diff info from the original input text and the corrected text.
    :param source: original input text, string
    :param result: corrected output text, string
    :return: fragments, a list [dict_1, dict_2, ...] with one dict per corrected character
    """
    fragments = list()
    # Post-processing rule 1: only supported when input and output have the same length;
    # otherwise fragments stays empty.
    if len(source) != len(result):
        return fragments
    # Post-processing rule 2: skip the input when it contains no Chinese,
    # or when its first Chinese run is only a single character.
    res_hans = re_han.findall(source)
    if not res_hans:
        return fragments
    if res_hans and len(res_hans[0]) < 2:
        return fragments
    # Post-processing rule 3: compare character by character and record the differing positions.
    for idx in range(len(source)):
        # original character
        src = source[idx]
        # character produced by the model
        tgt = result[idx]
        # a character that did not change is treated as correct
        if src == tgt:
            continue
        # skip non-Chinese characters
        if not re_han.findall(src):
            continue
        # drop over-corrections (false positives) via the white list
        if model_white_list_filter(source, src, idx):
            continue

        # record the index of the differing character
        fragment = {
            "error_init_id": idx,  # start index of the wrong character
            "error_end_id": idx + 1,  # end index
            "src_fragment": src,  # original character
            "tgt_fragment": tgt  # corrected character
        }
        fragments.append(fragment)
    return fragments


def model_white_list_filter(source, src, src_idx):
    """
    Filter the model output against the white list.
    source: the original sentence; src: the flagged character; src_idx: its index in source.
    Returns True when the correction is judged a false positive (over-correction).
    """
    is_correct = False
    # load the white list
    wh_texts, span_w = load_white_dict()
    source_lens = len(source)
    if src in wh_texts.keys():
        for src_span in wh_texts[src]:
            # if a configured fragment src_span occurs inside the window
            # source[span_start:span_end] around the flagged character,
            # the correction is treated as a false positive
            span_start = max(src_idx - span_w, 0)
            span_end = min(src_idx + span_w, source_lens)
            if src_span in source[span_start:span_end]:
                is_correct = True
                break
    return is_correct


if __name__ == '__main__':
    # original inference entry point:
    # arguments = parse_args()
    # inference(arguments)
    # test code added on top of it:
    texts = [
        '真麻烦你了。希望你们好好的跳无',
        '少先队员因该为老人让坐',
        '机七学习是人工智能领遇最能体现智能的一个分知',
        '今天心情很好',
        '汽车新式在这条路上',
        '中国人工只能布局很不错'
    ]
    corrected_texts, corrected_info = inference_test(texts)
    for info in corrected_info:
        print("----------------------")
        print("info:{}".format(info))
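The fragment format is easiest to see on a short, equal-length pair. A minimal sketch, run from the repository root so tools is importable (if the relative white-list path does not resolve, load_json returns None and the filter degrades to a no-op):

from tools.inference import generate_diff_info

# both strings have the same length, so rule 1 passes; the characters
# at indices 4 and 10 differ and survive the other filters
print(generate_diff_info("少先队员因该为老人让坐", "少先队员应该为老人让座"))
# [{'error_init_id': 4, 'error_end_id': 5, 'src_fragment': '因', 'tgt_fragment': '应'},
#  {'error_init_id': 10, 'error_end_id': 11, 'src_fragment': '坐', 'tgt_fragment': '座'}]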
-------------------------------------------------------------------------------- /tools/train_csc.py: --------------------------------------------------------------------------------

"""
@Time   : 2021-01-21 11:47:09
@File   : train_csc.py
@Author : Abtion
@Email  : abtion{at}outlook.com
"""
import sys
sys.path.append('..')

from bbcm.data.loaders.collator import DataCollatorForCsc
from pytorch_lightning.callbacks import ModelCheckpoint
from bbcm.data.build import make_loaders
from bbcm.data.loaders import get_csc_loader
from bbcm.modeling.csc import SoftMaskedBertModel
from bbcm.modeling.csc.modeling_bert4csc import BertForCsc
from transformers import BertTokenizer
from bases import args_parse, train
from bbcm.utils import get_abs_path
from bbcm.data.processors.csc import preproc
import os


def main():
    cfg = args_parse("csc/train_bert4csc.yml")

    # preprocess the data first if the training file does not exist yet
    if not os.path.exists(get_abs_path(cfg.DATASETS.TRAIN)):
        preproc()
    tokenizer = BertTokenizer.from_pretrained(cfg.MODEL.BERT_CKPT)
    collator = DataCollatorForCsc(tokenizer=tokenizer)
    if cfg.MODEL.NAME in ["bert4csc", "macbert4csc"]:
        model = BertForCsc(cfg, tokenizer)
    else:
        model = SoftMaskedBertModel(cfg, tokenizer)

    if len(cfg.MODEL.WEIGHTS) > 0:
        ckpt_path = get_abs_path(cfg.OUTPUT_DIR, cfg.MODEL.WEIGHTS)
        # load_from_checkpoint is a classmethod that returns a new instance,
        # so its result must be assigned back to model
        model = model.load_from_checkpoint(ckpt_path, cfg=cfg, tokenizer=tokenizer)

    loaders = make_loaders(cfg, get_csc_loader, _collate_fn=collator)
    ckpt_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath=get_abs_path(cfg.OUTPUT_DIR),
        filename='{epoch:02d}-{val_loss:.5f}',
        save_top_k=1,
        mode='min'
    )
    train(cfg, model, loaders, ckpt_callback)


if __name__ == '__main__':
    main()
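Once training has produced a checkpoint, load_model_directly skips argparse entirely, which is convenient in notebooks or services. A minimal sketch; the checkpoint and config names below are illustrative and must exist locally:

from tools.inference import load_model_directly, output_result

model = load_model_directly('SoftMaskedBert/epoch=02-val_loss=0.02904.ckpt',
                            'csc/train_SoftMaskedBert.yml')
texts = ['真麻烦你了。希望你们好好的跳无']
corrected = model.predict(texts)
print(output_result(corrected, sources=texts))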
-------------------------------------------------------------------------------- /tools/train_csc.sh: -------------------------------------------------------------------------------- 1 | python train_csc.py --opts MODE '["test"]' MODEL.WEIGHTS "epoch=1-val_loss=0.05.ckpt" --------------------------------------------------------------------------------
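The shipped script above runs a test-only pass against saved weights. The same entry point trains from scratch when MODE includes 'train'; a hypothetical variant with an explicit config:

python train_csc.py --config_file csc/train_SoftMaskedBert.yml --opts MODE '["train","test"]'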