├── .gitignore ├── LICENSE ├── README.md ├── README_myexample.md ├── benchmark ├── download_fewrel.sh ├── download_nyt10.sh ├── download_semeval.sh └── download_wiki80.sh ├── example ├── gen_brand_data.py ├── gen_chinese_data.py ├── infer.py ├── train_bag_pcnn_att.py ├── train_supervised_bert.py └── train_supervised_cnn.py ├── opennre ├── __init__.py ├── encoder │ ├── __init__.py │ ├── base_encoder.py │ ├── bert_encoder.py │ ├── cnn_encoder.py │ └── pcnn_encoder.py ├── framework │ ├── __init__.py │ ├── bag_re.py │ ├── data_loader.py │ ├── sentence_re.py │ └── utils.py ├── model │ ├── __init__.py │ ├── bag_attention.py │ ├── bag_average.py │ ├── base_model.py │ └── softmax_nn.py ├── module │ ├── __init__.py │ ├── nn │ │ ├── __init__.py │ │ ├── cnn.py │ │ ├── lstm.py │ │ └── rnn.py │ └── pool │ │ ├── __init__.py │ │ ├── avg_pool.py │ │ └── max_pool.py ├── pretrain.py └── tokenization │ ├── __init__.py │ ├── basic_tokenizer.py │ ├── bert_tokenizer.py │ ├── utils.py │ ├── word_piece_tokenizer.py │ └── word_tokenizer.py ├── pretrain ├── download_bert.sh └── download_glove.sh ├── requirements.txt ├── setup.py └── tests └── test_inference.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # test 107 | test.py 108 | 109 | # ckpt 110 | ckpt 111 | 112 | # vscode 113 | .vscode 114 | 115 | # tacred 116 | benchmark/tacred 117 | *.swp 118 | 119 | # data and pretrain 120 | pretrain 121 | benchmark 122 | !benchmark/*.sh 123 | !pretrain/*.sh 124 | 125 | # test env 126 | .test 127 | 128 | # package 129 | opennre-egg.info 130 | ======= 131 | # debug 132 | benchmark/nyt10-ori 133 | train_nyt10_pcnn_att_ori.py 134 | *.log 135 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Tianyu Gao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OpenNRE 2 | 3 | **在线的Demo网站([http://opennre.thunlp.ai/](http://opennre.thunlp.ai/)). 
Try it out!**
4 |
5 |
6 | OpenNRE is an open-source and extensible toolkit that provides a unified framework to implement relation extraction models. This package is designed for the following groups:
7 |
8 | * **Newcomers to relation extraction**: we have hands-on tutorials and detailed documentation that not only let you use the relation extraction tools, but also help you better understand the research progress of this field.
9 | * **Developers**: our easy-to-use interface and high-performance implementation make deployment in real applications faster. Besides, we provide several pretrained models that can be put into production without any training.
10 | * **Researchers**: with our modular design, various task settings and metric tools, you can easily experiment with your own models with only minor modifications. We also provide the most widely used benchmarks for relation extraction under different settings.
11 | * **Anyone who needs to submit an NLP homework to impress a professor**: our package has the state-of-the-art models and can definitely help you stand out among your classmates!
12 |
13 |
14 | ## What is Relation Extraction
15 |
16 | Relation extraction is a natural language processing (NLP) task that aims to extract the relation (e.g., *founder of*) between entities (e.g., **Bill Gates** and **Microsoft**). For example, from the sentence *Bill Gates founded Microsoft*, we can extract the relation triple (**Bill Gates**, *founder of*, **Microsoft**).
17 |
18 | Relation extraction is an important technique for the automatic construction of knowledge graphs. By using relation extraction, we can accumulate newly extracted relational facts and expand the knowledge graph, which, as a way for machines to understand the human world, has many downstream applications such as question answering, recommender systems and search engines.
19 |
20 | ## How to Cite
21 |
22 | A good research work is always accompanied by a thorough and faithful reference. If you use or extend our work, please cite the following paper:
23 |
24 | ```
25 | @inproceedings{han-etal-2019-opennre,
26 |     title = "{O}pen{NRE}: An Open and Extensible Toolkit for Neural Relation Extraction",
27 |     author = "Han, Xu and Gao, Tianyu and Yao, Yuan and Ye, Deming and Liu, Zhiyuan and Sun, Maosong",
28 |     booktitle = "Proceedings of EMNLP-IJCNLP: System Demonstrations",
29 |     year = "2019",
30 |     url = "https://www.aclweb.org/anthology/D19-3029",
31 |     doi = "10.18653/v1/D19-3029",
32 |     pages = "169--174"
33 | }
34 | ```
35 |
36 | It's our honor to help you better explore relation extraction with our OpenNRE toolkit!
37 |
38 | ## Papers and Document
39 |
40 | If you want to learn more about neural relation extraction, visit another project of ours ([NREPapers](https://github.com/thunlp/NREPapers)).
41 |
42 | You can refer to our [document](https://opennre-docs.readthedocs.io/en/latest/) for more details about this project.
43 |
44 | ## Installation
45 |
46 | ### Install as A Python Package
47 |
48 | We are now working on deploying OpenNRE as a Python package. Coming soon!
49 |
50 | ### Using Git Repository
51 |
52 | Clone the repository from our github page (don't forget to star us!)
53 |
54 | ```bash
55 | git clone https://github.com/thunlp/OpenNRE.git
56 | ```
57 |
58 | If it is too slow, you can try
59 | ```
60 | git clone https://github.com/thunlp/OpenNRE.git --depth 1
61 | ```
62 |
63 | Then install all the requirements:
64 |
65 | ```
66 | pip install -r requirements.txt
67 | ```
68 |
69 | **Note**: Please choose appropriate PyTorch version based on your machine (related to your CUDA version). For details, refer to https://pytorch.org/.
70 |
71 | Then install the package with
72 | ```
73 | python setup.py install
74 | ```
75 |
76 | If you also want to modify the code, run this:
77 | ```
78 | python setup.py develop
79 | ```
80 | ### Downloading Datasets
81 | Note that to keep the package light for deployment, we have removed all data and pretrained files. You can download them manually by running the scripts in the ``benchmark`` and ``pretrain`` folders. For example, if you want to download the FewRel dataset, you can run
82 |
83 | ```bash
84 | bash benchmark/download_fewrel.sh
85 | ```
86 |
87 | ## Easy Start
88 |
89 | Make sure you have installed OpenNRE as described above. Then import our package and load a pretrained model.
90 |
91 | ```python
92 | >>> import opennre
93 | >>> model = opennre.get_model('wiki80_cnn_softmax')
94 | ```
95 |
96 | Note that it may take a few minutes to download the checkpoint and data for the first time. Then use `infer` to do sentence-level relation extraction
97 |
98 | ```python
99 | >>> model.infer({'text': 'He was the son of Máel Dúin mac Máele Fithrich, and grandson of the high king Áed Uaridnach (died 612).', 'h': {'pos': (18, 46)}, 't': {'pos': (78, 91)}})
100 | ('father', 0.5108704566955566)
101 | ```
102 |
103 | and you will get the relation result and its confidence score.
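`infer` takes a plain dict: the raw `text` plus character-level `pos` spans for the head (`h`) and tail (`t`) entities, and returns a `(relation, score)` tuple. A minimal sketch of running the loaded model over a few sentences in a loop; the first sentence and its spans are a made-up illustration, not taken from the benchmark data:

```python
>>> examples = [
...     {'text': 'Bill Gates founded Microsoft.',
...      'h': {'pos': (0, 10)}, 't': {'pos': (19, 28)}},
...     {'text': 'He was the son of Máel Dúin mac Máele Fithrich, and grandson of the high king Áed Uaridnach (died 612).',
...      'h': {'pos': (18, 46)}, 't': {'pos': (78, 91)}},
... ]
>>> for item in examples:
...     print(model.infer(item))  # each call returns a (relation, confidence) tuple
```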
104 | Currently, the following pretrained models are available:
105 |
106 | * `wiki80_cnn_softmax`: trained on `wiki80` dataset with a CNN encoder.
107 | * `wiki80_bert_softmax`: trained on `wiki80` dataset with a BERT encoder.
108 | * `wiki80_bertentity_softmax`: trained on `wiki80` dataset with a BERT encoder (using entity representation concatenation).
109 | * `tacred_bert_softmax`: trained on `TACRED` dataset with a BERT encoder.
110 | * `tacred_bertentity_softmax`: trained on `TACRED` dataset with a BERT encoder (using entity representation concatenation).
111 |
112 | ## Training
113 |
114 | You can train your own models on your own data with OpenNRE. In the `example` folder we provide sample training code for supervised RE models and bag-level RE models, and you can use either the datasets we provide or your own dataset.
115 |
116 | ```buildoutcfg
117 | example/
118 | ├── train_bag_pcnn_att.py     # uses PCNN; according to the paper, BERT performs better
119 | ├── train_supervised_bert.py  # trains and tests a BERT model; an option switches between the CLS vector and the entity vectors as the sentence representation
120 | └── train_supervised_cnn.py   # uses CNN
121 |
122 | cd OpenNRE;
123 | # make sure the model files have been downloaded into pretrain/bert-base-uncased
124 | python example/train_supervised_bert.py --pretrain_path pretrain/bert-base-uncased --dataset wiki80 --pooler entity --do_train --do_test
125 |
126 | # if they have not been downloaded, pass the default HuggingFace model name instead
127 | python example/train_supervised_bert.py --pretrain_path bert-base-uncased --dataset wiki80 --pooler entity --do_train --do_test
128 |
129 | 2021-03-15 08:07:32,414 - root - WARNING - Test file ./benchmark/wiki80/wiki80_test.txt does not exist! Use val file instead
130 | 2021-03-15 08:07:32,414 - root - INFO - 参数:
131 | 2021-03-15 08:07:32,414 - root - INFO - pretrain_path: bert-base-uncased
132 | 2021-03-15 08:07:32,414 - root - INFO - ckpt: wiki80_bert-base-uncased_entity
133 | 2021-03-15 08:07:32,414 - root - INFO - pooler: entity
134 | 2021-03-15 08:07:32,414 - root - INFO - only_test: False
135 | 2021-03-15 08:07:32,414 - root - INFO - mask_entity: False
136 | 2021-03-15 08:07:32,414 - root - INFO - metric: acc
137 | 2021-03-15 08:07:32,414 - root - INFO - dataset: wiki80
138 | 2021-03-15 08:07:32,415 - root - INFO - train_file: ./benchmark/wiki80/wiki80_train.txt
139 | 2021-03-15 08:07:32,415 - root - INFO - val_file: ./benchmark/wiki80/wiki80_val.txt
140 | 2021-03-15 08:07:32,415 - root - INFO - test_file: ./benchmark/wiki80/wiki80_val.txt
141 | 2021-03-15 08:07:32,415 - root - INFO - rel2id_file: ./benchmark/wiki80/wiki80_rel2id.json
142 | 2021-03-15 08:07:32,415 - root - INFO - batch_size: 16
143 | 2021-03-15 08:07:32,415 - root - INFO - lr: 2e-05
144 | 2021-03-15 08:07:32,415 - root - INFO - max_length: 128
145 | 2021-03-15 08:07:32,415 - root - INFO - max_epoch: 3
146 | 2021-03-15 08:07:32,415 - root - INFO - 加载 BERT pre-trained checkpoint.
147 | 2021-03-15 08:07:32,630 - filelock - INFO - Lock 139806039272528 acquired on /root/.cache/huggingface/transformers/3c61d016573b14f7f008c02c4e51a366c67ab274726fe2910691e2a761acf43e.637c6035640bacb831febcc2b7f7bee0a96f9b30c2d7e9ef84082d9f252f3170.lock 148 | Downloading: 100% 433/433 [00:00<00:00, 553kB/s] 149 | 2021-03-15 08:07:32,845 - filelock - INFO - Lock 139806039272528 released on /root/.cache/huggingface/transformers/3c61d016573b14f7f008c02c4e51a366c67ab274726fe2910691e2a761acf43e.637c6035640bacb831febcc2b7f7bee0a96f9b30c2d7e9ef84082d9f252f3170.lock 150 | 2021-03-15 08:07:33,050 - filelock - INFO - Lock 139806178812176 acquired on /root/.cache/huggingface/transformers/a8041bf617d7f94ea26d15e218abd04afc2004805632abc0ed2066aa16d50d04.faf6ea826ae9c5867d12b22257f9877e6b8367890837bd60f7c54a29633f7f2f.lock 151 | Downloading: 100% 440M/440M [00:06<00:00, 65.1MB/s] 152 | 2021-03-15 08:07:40,025 - filelock - INFO - Lock 139806178812176 released on /root/.cache/huggingface/transformers/a8041bf617d7f94ea26d15e218abd04afc2004805632abc0ed2066aa16d50d04.faf6ea826ae9c5867d12b22257f9877e6b8367890837bd60f7c54a29633f7f2f.lock 153 | 2021-03-15 08:07:42,691 - filelock - INFO - Lock 139806022911376 acquired on /root/.cache/huggingface/transformers/45c3f7a79a80e1cf0a489e5c62b43f173c15db47864303a55d623bb3c96f72a5.d789d64ebfe299b0e416afc4a169632f903f693095b4629a7ea271d5a0cf2c99.lock 154 | Downloading: 100% 232k/232k [00:00<00:00, 921kB/s] 155 | 2021-03-15 08:07:43,149 - filelock - INFO - Lock 139806022911376 released on /root/.cache/huggingface/transformers/45c3f7a79a80e1cf0a489e5c62b43f173c15db47864303a55d623bb3c96f72a5.d789d64ebfe299b0e416afc4a169632f903f693095b4629a7ea271d5a0cf2c99.lock 156 | 2021-03-15 08:07:46,914 - root - INFO - 加载 RE 数据集 ./benchmark/wiki80/wiki80_train.txt with 50400 行和80 个关系. 157 | 2021-03-15 08:07:47,282 - root - INFO - 加载 RE 数据集 ./benchmark/wiki80/wiki80_val.txt with 5600 行和80 个关系. 158 | 2021-03-15 08:07:47,882 - root - INFO - 加载 RE 数据集 ./benchmark/wiki80/wiki80_val.txt with 5600 行和80 个关系. 159 | 2021-03-15 08:07:47,885 - root - INFO - 检测到GPU可用,使用GPU 160 | 2021-03-15 08:07:50,794 - root - INFO - === Epoch 0 train === 161 | 100% 3150/3150 [21:27<00:00, 2.45it/s, acc=0.806, loss=0.788] 162 | 2021-03-15 08:29:18,109 - root - INFO - === Epoch 0 val === 163 | 100% 350/350 [00:51<00:00, 6.76it/s, acc=0.858] 164 | 2021-03-15 08:30:09,911 - root - INFO - Evaluation result: {'acc': 0.8576785714285714, 'micro_p': 0.8576785714285714, 'micro_r': 0.8576785714285714, 'micro_f1': 0.8576785714285715}. 165 | 2021-03-15 08:30:09,911 - root - INFO - Metric micro_f1 current / best: 0.8576785714285715 / 0 166 | 2021-03-15 08:30:09,911 - root - INFO - Best ckpt and saved. 167 | 2021-03-15 08:30:11,301 - root - INFO - === Epoch 1 train === 168 | 100% 3150/3150 [21:37<00:00, 2.43it/s, acc=0.928, loss=0.245] 169 | 2021-03-15 08:51:48,797 - root - INFO - === Epoch 1 val === 170 | 100% 350/350 [00:51<00:00, 6.78it/s, acc=0.867] 171 | 2021-03-15 08:52:40,398 - root - INFO - Evaluation result: {'acc': 0.8669642857142857, 'micro_p': 0.8669642857142857, 'micro_r': 0.8669642857142857, 'micro_f1': 0.8669642857142857}. 172 | 2021-03-15 08:52:40,398 - root - INFO - Metric micro_f1 current / best: 0.8669642857142857 / 0.8576785714285715 173 | 2021-03-15 08:52:40,398 - root - INFO - Best ckpt and saved. 
174 | 2021-03-15 08:52:41,666 - root - INFO - === Epoch 2 train === 175 | 100% 3150/3150 [21:18<00:00, 2.46it/s, acc=0.957, loss=0.149] 176 | 2021-03-15 09:14:00,638 - root - INFO - === Epoch 2 val === 177 | 100% 350/350 [00:51<00:00, 6.77it/s, acc=0.87] 178 | 2021-03-15 09:14:52,355 - root - INFO - Evaluation result: {'acc': 0.8703571428571428, 'micro_p': 0.8703571428571428, 'micro_r': 0.8703571428571428, 'micro_f1': 0.8703571428571429}. 179 | 2021-03-15 09:14:52,355 - root - INFO - Metric micro_f1 current / best: 0.8703571428571429 / 0.8669642857142857 180 | 2021-03-15 09:14:52,355 - root - INFO - Best ckpt and saved. 181 | 2021-03-15 09:14:53,672 - root - INFO - Best micro_f1 on val set: 0.870357 182 | 100% 350/350 [00:51<00:00, 6.76it/s, acc=0.87] 183 | 2021-03-15 09:15:45,670 - root - INFO - Evaluation result: {'acc': 0.8703571428571428, 'micro_p': 0.8703571428571428, 'micro_r': 0.8703571428571428, 'micro_f1': 0.8703571428571429}. 184 | 2021-03-15 09:15:45,670 - root - INFO - Test set results: 185 | 2021-03-15 09:15:45,670 - root - INFO - Accuracy: 0.8703571428571428 186 | 2021-03-15 09:15:45,670 - root - INFO - Micro precision: 0.8703571428571428 187 | 2021-03-15 09:15:45,670 - root - INFO - Micro recall: 0.8703571428571428 188 | 2021-03-15 09:15:45,670 - root - INFO - Micro F1: 0.8703571428571429 189 | 190 | #使用cls的方式, 效果比实体的方式差 191 | python example/train_supervised_bert.py --pretrain_path pretrain/bert-base-uncased --dataset wiki80 --pooler cls --do_train --do_test 192 | 193 | 2021-03-16 01:11:58,524 - root - INFO - === Epoch 0 train === 194 | 100% 3150/3150 [19:11<00:00, 2.73it/s, acc=0.668, loss=1.44] 195 | 2021-03-16 01:31:10,290 - root - INFO - === Epoch 0 val === 196 | 100% 350/350 [00:45<00:00, 7.74it/s, acc=0.798] 197 | 2021-03-16 01:31:55,504 - root - INFO - Evaluation result: {'acc': 0.7983928571428571, 'micro_p': 0.7983928571428571, 'micro_r': 0.7983928571428571, 'micro_f1': 0.7983928571428571}. 198 | 2021-03-16 01:31:55,504 - root - INFO - Metric micro_f1 current / best: 0.7983928571428571 / 0 199 | 2021-03-16 01:31:55,504 - root - INFO - Best ckpt and saved. 200 | 2021-03-16 01:31:56,855 - root - INFO - === Epoch 1 train === 201 | 100% 3150/3150 [19:20<00:00, 2.72it/s, acc=0.876, loss=0.446] 202 | 2021-03-16 01:51:17,022 - root - INFO - === Epoch 1 val === 203 | 100% 350/350 [00:45<00:00, 7.73it/s, acc=0.838] 204 | 2021-03-16 01:52:02,291 - root - INFO - Evaluation result: {'acc': 0.8382142857142857, 'micro_p': 0.8382142857142857, 'micro_r': 0.8382142857142857, 'micro_f1': 0.8382142857142857}. 205 | 2021-03-16 01:52:02,291 - root - INFO - Metric micro_f1 current / best: 0.8382142857142857 / 0.7983928571428571 206 | 2021-03-16 01:52:02,292 - root - INFO - Best ckpt and saved. 207 | 2021-03-16 01:52:03,477 - root - INFO - === Epoch 2 train === 208 | 100% 3150/3150 [19:19<00:00, 2.72it/s, acc=0.923, loss=0.278] 209 | 2021-03-16 02:11:22,763 - root - INFO - === Epoch 2 val === 210 | 100% 350/350 [00:45<00:00, 7.75it/s, acc=0.85] 211 | 2021-03-16 02:12:07,930 - root - INFO - Evaluation result: {'acc': 0.8498214285714286, 'micro_p': 0.8498214285714286, 'micro_r': 0.8498214285714286, 'micro_f1': 0.8498214285714286}. 212 | 2021-03-16 02:12:07,930 - root - INFO - Metric micro_f1 current / best: 0.8498214285714286 / 0.8382142857142857 213 | 2021-03-16 02:12:07,930 - root - INFO - Best ckpt and saved. 
214 | 2021-03-16 02:12:09,198 - root - INFO - Best micro_f1 on val set: 0.849821
215 | 100% 350/350 [00:45<00:00, 7.72it/s, acc=0.85]
216 | 2021-03-16 02:12:54,762 - root - INFO - Evaluation result: {'acc': 0.8498214285714286, 'micro_p': 0.8498214285714286, 'micro_r': 0.8498214285714286, 'micro_f1': 0.8498214285714286}.
217 | 2021-03-16 02:12:54,762 - root - INFO - Test set results:
218 | 2021-03-16 02:12:54,763 - root - INFO - Accuracy: 0.8498214285714286
219 | 2021-03-16 02:12:54,763 - root - INFO - Micro precision: 0.8498214285714286
220 | 2021-03-16 02:12:54,763 - root - INFO - Micro recall: 0.8498214285714286
221 | 2021-03-16 02:12:54,763 - root - INFO - Micro F1: 0.8498214285714286
222 | ```
223 |
224 | ```
225 | # Test on the SemEval dataset
226 | python example/train_supervised_bert.py --pretrain_path bert-base-uncased --dataset semeval --pooler entity --do_train --do_test
227 | 2021-08-12 16:27:15,142 - root - INFO - 加载 RE 数据集 ./benchmark/semeval/semeval_train.txt with 6507 行和19 个关系.
228 | 2021-08-12 16:27:15,200 - root - INFO - 加载 RE 数据集 ./benchmark/semeval/semeval_val.txt with 1493 行和19 个关系.
229 | 2021-08-12 16:27:15,306 - root - INFO - 加载 RE 数据集 ./benchmark/semeval/semeval_test.txt with 2717 行和19 个关系.
230 | 2021-08-12 16:27:15,334 - root - INFO - 检测到GPU可用,使用GPU
231 | 2021-08-12 16:27:17,546 - root - INFO - === Epoch 0 train ===
232 | 100%|███████████████████████████| 407/407 [00:32<00:00, 12.70it/s, acc=0.574, loss=1.42]
233 | 2021-08-12 16:27:49,597 - root - INFO - === Epoch 0 val ===
234 | 评估: 100%|██████████████████████████████████| 94/94 [00:02<00:00, 36.45it/s, acc=0.801]
235 | 2021-08-12 16:27:52,177 - root - INFO - 评估结果 : {'acc': 0.8010716677829873, 'micro_p': 0.8159695817490494, 'micro_r': 0.8730675345809601, 'micro_f1': 0.8435534591194969}.
236 | 2021-08-12 16:27:52,177 - root - INFO - Metric micro_f1 current / best: 0.8435534591194969 / 0
237 | 2021-08-12 16:27:52,177 - root - INFO - 获得了更好的metric 0.8435534591194969,保存模型
238 | 2021-08-12 16:27:52,611 - root - INFO - === Epoch 1 train ===
239 | 100%|███████████████████████████| 407/407 [00:31<00:00, 12.73it/s, acc=0.88, loss=0.401]
240 | 2021-08-12 16:28:24,576 - root - INFO - === Epoch 1 val ===
241 | 评估: 100%|███████████████████████████████████| 94/94 [00:02<00:00, 36.34it/s, acc=0.85]
242 | 2021-08-12 16:28:27,163 - root - INFO - 评估结果 : {'acc': 0.8499665103817816, 'micro_p': 0.8696682464454977, 'micro_r': 0.8958502847843776, 'micro_f1': 0.882565130260521}.
243 | 2021-08-12 16:28:27,163 - root - INFO - Metric micro_f1 current / best: 0.882565130260521 / 0.8435534591194969
244 | 2021-08-12 16:28:27,164 - root - INFO - 获得了更好的metric 0.882565130260521,保存模型
245 | 2021-08-12 16:28:27,901 - root - INFO - === Epoch 2 train ===
246 | 100%|██████████████████████████| 407/407 [00:32<00:00, 12.68it/s, acc=0.956, loss=0.177]
247 | 2021-08-12 16:29:00,008 - root - INFO - === Epoch 2 val ===
248 | 评估: 100%|██████████████████████████████████| 94/94 [00:02<00:00, 36.15it/s, acc=0.848]
249 | 2021-08-12 16:29:02,610 - root - INFO - 评估结果 : {'acc': 0.8479571332886805, 'micro_p': 0.8715083798882681, 'micro_r': 0.8885272579332791, 'micro_f1': 0.8799355358581789}.
250 | 2021-08-12 16:29:02,610 - root - INFO - Metric micro_f1 current / best: 0.8799355358581789 / 0.882565130260521
251 | 2021-08-12 16:29:02,610 - root - INFO - Best micro_f1 on val set: 0.882565
252 | 评估: 100%|████████████████████████████████| 170/170 [00:04<00:00, 35.91it/s, acc=0.842]
253 | 2021-08-12 16:29:07,473 - root - INFO - 评估结果 : {'acc': 0.8421052631578947, 'micro_p': 0.8738156761412575, 'micro_r': 0.8965974370304906, 'micro_f1': 0.8850599781897494}.
254 | 2021-08-12 16:29:07,473 - root - INFO - Test set results:
255 | 2021-08-12 16:29:07,473 - root - INFO - Accuracy: 0.8421052631578947
256 | 2021-08-12 16:29:07,473 - root - INFO - Micro precision: 0.8738156761412575
257 | 2021-08-12 16:29:07,473 - root - INFO - Micro recall: 0.8965974370304906
258 | 2021-08-12 16:29:07,473 - root - INFO - Micro F1: 0.8850599781897494
259 | ```
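After training finishes, the script keeps the best checkpoint on disk (its name is printed near the top of the log, e.g. `wiki80_bert-base-uncased_entity` in the run above). Below is a minimal sketch of reloading such a checkpoint for inference with the toolkit's own classes; the `ckpt/...pth.tar` path, `max_length` and `pretrain_path` values are assumptions based on the example logs, so adjust them to your own run:

```python
import json
import torch
import opennre

# Rebuild the same encoder/classifier combination that was trained above.
rel2id = json.load(open('benchmark/wiki80/wiki80_rel2id.json'))
sentence_encoder = opennre.encoder.BERTEntityEncoder(max_length=128, pretrain_path='bert-base-uncased')
model = opennre.model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id)

# Assumed checkpoint location -- match it to the ckpt name printed in your training log.
ckpt = 'ckpt/wiki80_bert-base-uncased_entity.pth.tar'
model.load_state_dict(torch.load(ckpt, map_location='cpu')['state_dict'])

# Same input format as in Easy Start.
print(model.infer({'text': 'He was the son of Máel Dúin mac Máele Fithrich, and grandson of the high king Áed Uaridnach (died 612).',
                   'h': {'pos': (18, 46)}, 't': {'pos': (78, 91)}}))
```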
260 | ## Google Group
261 |
262 | If you want to receive our update news or take part in discussions, please join our [Google Group](https://groups.google.com/forum/#!forum/opennre/join)
263 |
--------------------------------------------------------------------------------
/README_myexample.md:
--------------------------------------------------------------------------------
1 | # Already integrated into the multi-task model; this code is not in use for now
2 |
3 | # Generating the Chinese data and its format
4 | ## Example: benchmark/liter
5 | Label file
6 | ```
7 | {"否": 0, "是": 1}
8 | ```
9 | A sample record
10 | ```
11 | {
12 |     "text":"出门回来、一定要彻底卸妆
13 | lirosa水霜、冻膜、卸妆啫喱还有香奈儿山茶花洗面奶都是无限回购。",
14 |     "h":{
15 |         "name":"lirosa水霜",
16 |         "id":"VKZH9J5DW8",
17 |         "pos":[
18 |             13,
19 |             21
20 |         ]
21 |     },
22 |     "t":{
23 |         "name":"啫喱",
24 |         "id":"U7G1VDPYTG",
25 |         "pos":[
26 |             27,
27 |             29
28 |         ]
29 |     },
30 |     "relation":"否"
31 | }
32 | ```
33 |
34 | # The models were downloaded from HuggingFace beforehand
35 | # Train the model using the MacBERT checkpoint
36 | python train_supervised_bert.py --pretrain_path pretrain/mac_bert_model --dataset brand --pooler entity --do_train --do_test --batch_size 32 --max_length 256 --max_epoch 10
37 |
38 | # Using the Chinese BERT model
39 | ## Entity pooler
40 | 共收集到总的数据条目: 13248, 跳过的空的数据: 153, 非空reuslt的条数3347, 标签为空的数据的条数2,标签的个数统计为Counter({'否': 9051, '是': 4197})
41 | 训练集数量10598, 测试集数量1326,开发集数量1324
42 | python train_supervised_bert.py --pretrain_path pretrain/bert_model --dataset brand --pooler entity --do_train --do_test --batch_size 32 --max_length 256 --max_epoch 10
43 | 2021-08-16 15:08:52,443 - root - INFO - 参数:
44 | 2021-08-16 15:08:52,443 - root - INFO - pretrain_path: pretrain/bert_model
45 | 2021-08-16 15:08:52,443 - root - INFO - ckpt: brand_pretrain/bert_model_entity
46 | 2021-08-16 15:08:52,444 - root - INFO - pooler: entity
47 | 2021-08-16 15:08:52,444 - root - INFO - do_train: True
48 | 2021-08-16 15:08:52,444 - root - INFO - do_test: True
49 | 2021-08-16 15:08:52,444 - root - INFO - mask_entity: False
50 | 2021-08-16 15:08:52,444 - root - INFO - metric: micro_f1
51 | 2021-08-16 15:08:52,444 - root - INFO - dataset: brand
52 | 2021-08-16 15:08:52,444 - root - INFO - train_file: ./benchmark/brand/brand_train.txt
53 | 2021-08-16 15:08:52,444 - root - INFO - val_file: ./benchmark/brand/brand_val.txt
54 | 2021-08-16 15:08:52,444 - root - INFO - test_file: ./benchmark/brand/brand_test.txt
55 | 2021-08-16 15:08:52,444 - root - INFO - rel2id_file: ./benchmark/brand/brand_rel2id.json
56 | 2021-08-16 15:08:52,444 - root - INFO - batch_size: 32
57 | 2021-08-16 15:08:52,444 - root - INFO - lr: 2e-05
58 | 2021-08-16 15:08:52,444 - root - INFO - max_length: 128
59 | 2021-08-16 15:08:52,444 - root - INFO - max_epoch: 10
60 | 2021-08-16 15:08:52,444 - root - INFO - 加载预训练的 BERT pre-trained checkpoint: pretrain/bert_model
61 | Some weights of
the model checkpoint at pretrain/bert_model were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight'] 62 | - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). 63 | - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). 64 | 2021-08-16 15:08:54,180 - root - INFO - 加载 RE 数据集 ./benchmark/brand/brand_train.txt with 10598 行和2 个关系. 65 | 2021-08-16 15:08:54,298 - root - INFO - 加载 RE 数据集 ./benchmark/brand/brand_val.txt with 1326 行和2 个关系. 66 | 2021-08-16 15:08:54,416 - root - INFO - 加载 RE 数据集 ./benchmark/brand/brand_test.txt with 1324 行和2 个关系. 67 | 2021-08-16 15:08:54,447 - root - INFO - 检测到GPU可用,使用GPU 68 | 2021-08-16 15:08:56,674 - root - INFO - === Epoch 0 train === 69 | 100%|██████████| 332/332 [01:07<00:00, 4.95it/s, acc=0.75, loss=0.528] 70 | 2021-08-16 15:10:03,678 - root - INFO - === Epoch 0 val === 71 | 评估: 100%|██████████| 42/42 [00:04<00:00, 8.71it/s, acc=0.759] 72 | 2021-08-16 15:10:08,500 - root - INFO - 评估结果 : {'acc': 0.7594268476621417, 'micro_p': 0.7594268476621417, 'micro_r': 0.7594268476621417, 'micro_f1': 0.7594268476621419}. 73 | 2021-08-16 15:10:08,500 - root - INFO - Metric micro_f1 current / best: 0.7594268476621419 / 0 74 | 2021-08-16 15:10:08,500 - root - INFO - 获得了更好的metric 0.7594268476621419,保存模型 75 | 2021-08-16 15:10:09,188 - root - INFO - === Epoch 1 train === 76 | 100%|██████████| 332/332 [01:06<00:00, 4.98it/s, acc=0.809, loss=0.429] 77 | 2021-08-16 15:11:15,790 - root - INFO - === Epoch 1 val === 78 | 评估: 100%|██████████| 42/42 [00:04<00:00, 8.73it/s, acc=0.81] 79 | 2021-08-16 15:11:20,604 - root - INFO - 评估结果 : {'acc': 0.8099547511312217, 'micro_p': 0.8099547511312217, 'micro_r': 0.8099547511312217, 'micro_f1': 0.8099547511312217}. 80 | 2021-08-16 15:11:20,604 - root - INFO - Metric micro_f1 current / best: 0.8099547511312217 / 0.7594268476621419 81 | 2021-08-16 15:11:20,604 - root - INFO - 获得了更好的metric 0.8099547511312217,保存模型 82 | 2021-08-16 15:11:21,302 - root - INFO - === Epoch 2 train === 83 | 100%|██████████| 332/332 [01:06<00:00, 4.97it/s, acc=0.84, loss=0.357] 84 | 2021-08-16 15:12:28,068 - root - INFO - === Epoch 2 val === 85 | 评估: 100%|██████████| 42/42 [00:04<00:00, 8.71it/s, acc=0.825] 86 | 2021-08-16 15:12:32,894 - root - INFO - 评估结果 : {'acc': 0.8250377073906485, 'micro_p': 0.8250377073906485, 'micro_r': 0.8250377073906485, 'micro_f1': 0.8250377073906485}. 
87 | 2021-08-16 15:12:32,894 - root - INFO - Metric micro_f1 current / best: 0.8250377073906485 / 0.8099547511312217 88 | 2021-08-16 15:12:32,894 - root - INFO - 获得了更好的metric 0.8250377073906485,保存模型 89 | 2021-08-16 15:12:33,554 - root - INFO - === Epoch 3 train === 90 | 100%|██████████| 332/332 [01:06<00:00, 4.96it/s, acc=0.857, loss=0.318] 91 | 2021-08-16 15:13:40,473 - root - INFO - === Epoch 3 val === 92 | 评估: 100%|██████████| 42/42 [00:04<00:00, 8.69it/s, acc=0.807] 93 | 2021-08-16 15:13:45,309 - root - INFO - 评估结果 : {'acc': 0.8069381598793364, 'micro_p': 0.8069381598793364, 'micro_r': 0.8069381598793364, 'micro_f1': 0.8069381598793365}. 94 | 2021-08-16 15:13:45,309 - root - INFO - Metric micro_f1 current / best: 0.8069381598793365 / 0.8250377073906485 95 | 2021-08-16 15:13:45,310 - root - INFO - === Epoch 4 train === 96 | 100%|██████████| 332/332 [01:07<00:00, 4.94it/s, acc=0.865, loss=0.296] 97 | 2021-08-16 15:14:52,566 - root - INFO - === Epoch 4 val === 98 | 评估: 100%|██████████| 42/42 [00:04<00:00, 8.69it/s, acc=0.821] 99 | 2021-08-16 15:14:57,398 - root - INFO - 评估结果 : {'acc': 0.8205128205128205, 'micro_p': 0.8205128205128205, 'micro_r': 0.8205128205128205, 'micro_f1': 0.8205128205128205}. 100 | 2021-08-16 15:14:57,398 - root - INFO - Metric micro_f1 current / best: 0.8205128205128205 / 0.8250377073906485 101 | 2021-08-16 15:14:57,399 - root - INFO - === Epoch 5 train === 102 | 100%|██████████| 332/332 [01:07<00:00, 4.95it/s, acc=0.872, loss=0.28] 103 | 2021-08-16 15:16:04,527 - root - INFO - === Epoch 5 val === 104 | 评估: 100%|██████████| 42/42 [00:04<00:00, 8.70it/s, acc=0.825] 105 | 2021-08-16 15:16:09,357 - root - INFO - 评估结果 : {'acc': 0.8250377073906485, 'micro_p': 0.8250377073906485, 'micro_r': 0.8250377073906485, 'micro_f1': 0.8250377073906485}. 106 | 2021-08-16 15:16:09,357 - root - INFO - Metric micro_f1 current / best: 0.8250377073906485 / 0.8250377073906485 107 | 2021-08-16 15:16:09,358 - root - INFO - === Epoch 6 train === 108 | 100%|██████████| 332/332 [01:07<00:00, 4.93it/s, acc=0.875, loss=0.264] 109 | 2021-08-16 15:17:16,684 - root - INFO - === Epoch 6 val === 110 | 评估: 100%|██████████| 42/42 [00:04<00:00, 8.69it/s, acc=0.811] 111 | 2021-08-16 15:17:21,516 - root - INFO - 评估结果 : {'acc': 0.8107088989441931, 'micro_p': 0.8107088989441931, 'micro_r': 0.8107088989441931, 'micro_f1': 0.8107088989441931}. 112 | 2021-08-16 15:17:21,517 - root - INFO - Metric micro_f1 current / best: 0.8107088989441931 / 0.8250377073906485 113 | 2021-08-16 15:17:21,517 - root - INFO - === Epoch 7 train === 114 | 100%|██████████| 332/332 [01:07<00:00, 4.93it/s, acc=0.878, loss=0.253] 115 | 2021-08-16 15:18:28,924 - root - INFO - === Epoch 7 val === 116 | 评估: 100%|██████████| 42/42 [00:04<00:00, 8.70it/s, acc=0.814] 117 | 2021-08-16 15:18:33,751 - root - INFO - 评估结果 : {'acc': 0.8144796380090498, 'micro_p': 0.8144796380090498, 'micro_r': 0.8144796380090498, 'micro_f1': 0.8144796380090498}. 118 | 2021-08-16 15:18:33,752 - root - INFO - Metric micro_f1 current / best: 0.8144796380090498 / 0.8250377073906485 119 | 2021-08-16 15:18:33,752 - root - INFO - === Epoch 8 train === 120 | 100%|██████████| 332/332 [01:07<00:00, 4.93it/s, acc=0.883, loss=0.242] 121 | 2021-08-16 15:19:41,049 - root - INFO - === Epoch 8 val === 122 | 评估: 100%|██████████| 42/42 [00:04<00:00, 8.70it/s, acc=0.819] 123 | 2021-08-16 15:19:45,879 - root - INFO - 评估结果 : {'acc': 0.8190045248868778, 'micro_p': 0.8190045248868778, 'micro_r': 0.8190045248868778, 'micro_f1': 0.8190045248868778}. 
124 | 2021-08-16 15:19:45,879 - root - INFO - Metric micro_f1 current / best: 0.8190045248868778 / 0.8250377073906485 125 | 2021-08-16 15:19:45,879 - root - INFO - === Epoch 9 train === 126 | 100%|██████████| 332/332 [01:07<00:00, 4.93it/s, acc=0.888, loss=0.236] 127 | 2021-08-16 15:20:53,230 - root - INFO - === Epoch 9 val === 128 | 评估: 100%|██████████| 42/42 [00:04<00:00, 8.68it/s, acc=0.808] 129 | 2021-08-16 15:20:58,069 - root - INFO - 评估结果 : {'acc': 0.808446455505279, 'micro_p': 0.808446455505279, 'micro_r': 0.808446455505279, 'micro_f1': 0.808446455505279}. 130 | 2021-08-16 15:20:58,069 - root - INFO - Metric micro_f1 current / best: 0.808446455505279 / 0.8250377073906485 131 | 2021-08-16 15:20:58,069 - root - INFO - Best micro_f1 on val set: 0.825038 132 | 评估: 100%|██████████| 42/42 [00:04<00:00, 8.66it/s, acc=0.826] 133 | 2021-08-16 15:21:03,038 - root - INFO - 评估结果 : {'acc': 0.8262839879154078, 'micro_p': 0.8262839879154078, 'micro_r': 0.8262839879154078, 'micro_f1': 0.8262839879154078}. 134 | 2021-08-16 15:21:03,038 - root - INFO - Test set results: 135 | 2021-08-16 15:21:03,038 - root - INFO - Accuracy: 0.8262839879154078 136 | 2021-08-16 15:21:03,038 - root - INFO - Micro precision: 0.8262839879154078 137 | 2021-08-16 15:21:03,038 - root - INFO - Micro recall: 0.8262839879154078 138 | 2021-08-16 15:21:03,038 - root - INFO - Micro F1: 0.8262839879154078 139 | 运行成功! Step3: 训练并测试BERT模型 140 | 141 | 142 | ## CLS形式 143 | python train_supervised_bert.py --pretrain_path pretrain/bert_model --dataset brand --pooler cls --do_train --do_test --batch_size 32 --max_length 256 --max_epoch 10 144 | 共收集到总的数据条目: 13248, 跳过的空的数据: 153, 非空reuslt的条数3347, 标签为空的数据的条数2,标签的个数统计为Counter({'否': 9051, '是': 4197}) 145 | 训练集数量10598, 测试集数量1326,开发集数量1324 146 | 2021-08-12 17:54:41,666 - root - INFO - 参数: 147 | 2021-08-12 17:54:41,666 - root - INFO - pretrain_path: pretrain/bert_model 148 | 2021-08-12 17:54:41,666 - root - INFO - ckpt: brand_pretrain/bert_model_cls 149 | 2021-08-12 17:54:41,666 - root - INFO - pooler: cls 150 | 2021-08-12 17:54:41,666 - root - INFO - do_train: True 151 | 2021-08-12 17:54:41,666 - root - INFO - do_test: True 152 | 2021-08-12 17:54:41,666 - root - INFO - mask_entity: False 153 | 2021-08-12 17:54:41,666 - root - INFO - metric: micro_f1 154 | 2021-08-12 17:54:41,666 - root - INFO - dataset: brand 155 | 2021-08-12 17:54:41,666 - root - INFO - train_file: ./benchmark/brand/brand_train.txt 156 | 2021-08-12 17:54:41,666 - root - INFO - val_file: ./benchmark/brand/brand_val.txt 157 | 2021-08-12 17:54:41,666 - root - INFO - test_file: ./benchmark/brand/brand_test.txt 158 | 2021-08-12 17:54:41,666 - root - INFO - rel2id_file: ./benchmark/brand/brand_rel2id.json 159 | 2021-08-12 17:54:41,666 - root - INFO - batch_size: 32 160 | 2021-08-12 17:54:41,666 - root - INFO - lr: 2e-05 161 | 2021-08-12 17:54:41,666 - root - INFO - max_length: 128 162 | 2021-08-12 17:54:41,666 - root - INFO - max_epoch: 10 163 | 2021-08-12 17:54:41,666 - root - INFO - 加载预训练的 BERT pre-trained checkpoint: pretrain/bert_model 164 | Some weights of the model checkpoint at pretrain/bert_model were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias'] 165 | - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another 
architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). 166 | - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). 167 | 2021-08-12 17:54:43,538 - root - INFO - 加载 RE 数据集 ./benchmark/brand/brand_train.txt with 10598 行和2 个关系. 168 | 2021-08-12 17:54:43,656 - root - INFO - 加载 RE 数据集 ./benchmark/brand/brand_val.txt with 1326 行和2 个关系. 169 | 2021-08-12 17:54:43,772 - root - INFO - 加载 RE 数据集 ./benchmark/brand/brand_test.txt with 1324 行和2 个关系. 170 | 2021-08-12 17:54:43,802 - root - INFO - 检测到GPU可用,使用GPU 171 | 2021-08-12 17:54:45,984 - root - INFO - === Epoch 0 train === 172 | 100%|██████████| 332/332 [01:07<00:00, 4.94it/s, acc=0.72, loss=0.581] 173 | 2021-08-12 17:55:53,174 - root - INFO - === Epoch 0 val === 174 | 评估: 100%|██████████| 42/42 [00:04<00:00, 8.74it/s, acc=0.788] 175 | 2021-08-12 17:55:57,981 - root - INFO - 评估结果 : {'acc': 0.7880844645550528, 'micro_p': 0.7880844645550528, 'micro_r': 0.7880844645550528, 'micro_f1': 0.7880844645550528}. 176 | 2021-08-12 17:55:57,981 - root - INFO - Metric micro_f1 current / best: 0.7880844645550528 / 0 177 | 2021-08-12 17:55:57,981 - root - INFO - 获得了更好的metric 0.7880844645550528,保存模型 178 | 2021-08-12 17:55:58,365 - root - INFO - === Epoch 1 train === 179 | 86%|████████▋ | 287/332 [00:58<00:09, 4.81it/s, acc=0.791, loss=0.488] 180 | 2021-08-12 17:58:20,257 - root - INFO - === Epoch 2 val === 181 | 评估: 100%|██████████| 42/42 [00:04<00:00, 8.51it/s, acc=0.818] 182 | 2021-08-12 17:58:25,196 - root - INFO - 评估结果 : {'acc': 0.8182503770739065, 'micro_p': 0.8182503770739065, 'micro_r': 0.8182503770739065, 'micro_f1': 0.8182503770739065}. 183 | 2021-08-12 17:58:25,196 - root - INFO - Metric micro_f1 current / best: 0.8182503770739065 / 0.8107088989441931 184 | 2021-08-12 17:58:25,196 - root - INFO - 获得了更好的metric 0.8182503770739065,保存模型 185 | 2021-08-12 17:58:25,885 - root - INFO - === Epoch 3 train === 186 | 100%|██████████| 332/332 [01:08<00:00, 4.82it/s, acc=0.841, loss=0.371] 187 | 2021-08-12 17:59:34,769 - root - INFO - === Epoch 3 val === 188 | 评估: 100%|██████████| 42/42 [00:05<00:00, 8.39it/s, acc=0.802] 189 | 2021-08-12 17:59:39,776 - root - INFO - 评估结果 : {'acc': 0.801659125188537, 'micro_p': 0.801659125188537, 'micro_r': 0.801659125188537, 'micro_f1': 0.8016591251885369}. 190 | 2021-08-12 17:59:39,776 - root - INFO - Metric micro_f1 current / best: 0.8016591251885369 / 0.8182503770739065 191 | 2021-08-12 17:59:39,777 - root - INFO - === Epoch 4 train === 192 | 100%|██████████| 332/332 [01:09<00:00, 4.81it/s, acc=0.852, loss=0.34] 193 | 2021-08-12 18:00:48,805 - root - INFO - === Epoch 4 val === 194 | 评估: 100%|██████████| 42/42 [00:05<00:00, 7.71it/s, acc=0.822] 195 | 2021-08-12 18:00:54,255 - root - INFO - 评估结果 : {'acc': 0.8220211161387632, 'micro_p': 0.8220211161387632, 'micro_r': 0.8220211161387632, 'micro_f1': 0.8220211161387632}. 
196 | 2021-08-12 18:00:54,255 - root - INFO - Metric micro_f1 current / best: 0.8220211161387632 / 0.8182503770739065 197 | 2021-08-12 18:00:54,256 - root - INFO - 获得了更好的metric 0.8220211161387632,保存模型 198 | 2021-08-12 18:00:54,986 - root - INFO - === Epoch 5 train === 199 | 100%|██████████| 332/332 [01:11<00:00, 4.62it/s, acc=0.858, loss=0.311] 200 | 2021-08-12 18:02:06,912 - root - INFO - === Epoch 5 val === 201 | 评估: 100%|██████████| 42/42 [00:04<00:00, 8.71it/s, acc=0.834] 202 | 2021-08-12 18:02:11,735 - root - INFO - 评估结果 : {'acc': 0.8340874811463047, 'micro_p': 0.8340874811463047, 'micro_r': 0.8340874811463047, 'micro_f1': 0.8340874811463046}. 203 | 2021-08-12 18:02:11,735 - root - INFO - Metric micro_f1 current / best: 0.8340874811463046 / 0.8220211161387632 204 | 2021-08-12 18:02:11,735 - root - INFO - 获得了更好的metric 0.8340874811463046,保存模型 205 | 2021-08-12 18:02:12,439 - root - INFO - === Epoch 6 train === 206 | 100%|██████████| 332/332 [01:09<00:00, 4.78it/s, acc=0.863, loss=0.297] 207 | 2021-08-12 18:03:21,880 - root - INFO - === Epoch 6 val === 208 | 评估: 100%|██████████| 42/42 [00:05<00:00, 7.94it/s, acc=0.83] 209 | 2021-08-12 18:03:27,174 - root - INFO - 评估结果 : {'acc': 0.8295625942684767, 'micro_p': 0.8295625942684767, 'micro_r': 0.8295625942684767, 'micro_f1': 0.8295625942684767}. 210 | 2021-08-12 18:03:27,174 - root - INFO - Metric micro_f1 current / best: 0.8295625942684767 / 0.8340874811463046 211 | 2021-08-12 18:03:27,174 - root - INFO - === Epoch 7 train === 212 | 100%|██████████| 332/332 [01:13<00:00, 4.51it/s, acc=0.869, loss=0.283] 213 | 2021-08-12 18:04:40,825 - root - INFO - === Epoch 7 val === 214 | 评估: 100%|██████████| 42/42 [00:04<00:00, 8.56it/s, acc=0.818] 215 | 2021-08-12 18:04:45,732 - root - INFO - 评估结果 : {'acc': 0.8182503770739065, 'micro_p': 0.8182503770739065, 'micro_r': 0.8182503770739065, 'micro_f1': 0.8182503770739065}. 216 | 2021-08-12 18:04:45,732 - root - INFO - Metric micro_f1 current / best: 0.8182503770739065 / 0.8340874811463046 217 | 2021-08-12 18:04:45,733 - root - INFO - === Epoch 8 train === 218 | 100%|██████████| 332/332 [01:09<00:00, 4.79it/s, acc=0.873, loss=0.264] 219 | 2021-08-12 18:05:54,999 - root - INFO - === Epoch 8 val === 220 | 评估: 100%|██████████| 42/42 [00:05<00:00, 8.35it/s, acc=0.824] 221 | 2021-08-12 18:06:00,030 - root - INFO - 评估结果 : {'acc': 0.8242835595776772, 'micro_p': 0.8242835595776772, 'micro_r': 0.8242835595776772, 'micro_f1': 0.8242835595776772}. 222 | 2021-08-12 18:06:00,030 - root - INFO - Metric micro_f1 current / best: 0.8242835595776772 / 0.8340874811463046 223 | 2021-08-12 18:06:00,031 - root - INFO - === Epoch 9 train === 224 | 100%|██████████| 332/332 [01:09<00:00, 4.78it/s, acc=0.878, loss=0.256] 225 | 2021-08-12 18:07:09,461 - root - INFO - === Epoch 9 val === 226 | 评估: 100%|██████████| 42/42 [00:05<00:00, 7.84it/s, acc=0.82] 227 | 2021-08-12 18:07:14,817 - root - INFO - 评估结果 : {'acc': 0.8197586726998491, 'micro_p': 0.8197586726998491, 'micro_r': 0.8197586726998491, 'micro_f1': 0.8197586726998491}. 228 | 2021-08-12 18:07:14,817 - root - INFO - Metric micro_f1 current / best: 0.8197586726998491 / 0.8340874811463046 229 | 2021-08-12 18:07:14,817 - root - INFO - Best micro_f1 on val set: 0.834087 230 | 评估: 100%|██████████| 42/42 [00:05<00:00, 7.61it/s, acc=0.825] 231 | 2021-08-12 18:07:20,456 - root - INFO - 评估结果 : {'acc': 0.824773413897281, 'micro_p': 0.824773413897281, 'micro_r': 0.824773413897281, 'micro_f1': 0.824773413897281}. 
232 | 2021-08-12 18:07:20,456 - root - INFO - Test set results: 233 | 2021-08-12 18:07:20,456 - root - INFO - Accuracy: 0.824773413897281 234 | 2021-08-12 18:07:20,456 - root - INFO - Micro precision: 0.824773413897281 235 | 2021-08-12 18:07:20,456 - root - INFO - Micro recall: 0.824773413897281 236 | 2021-08-12 18:07:20,456 - root - INFO - Micro F1: 0.824773413897281 237 | 运行成功! Step3: 训练并测试BERT模型 238 | 239 | 240 | # 模型处理cls类型的训练时的方法 241 | ``` 242 | sent0 = self.tokenizer.tokenize(sentence[:pos_min[0]]) 243 | ent0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]]) 244 | sent1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]]) 245 | ent1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]]) 246 | sent2 = self.tokenizer.tokenize(sentence[pos_max[1]:]) 247 | ent0 = ['[unused0]'] + ent0 + ['[unused1]'] 248 | ent1 = ['[unused2]'] + ent1 + ['[unused3]'] 249 | re_tokens = ['[CLS]'] + sent0 + ent0 + sent1 + ent1 + sent2 + ['[SEP]'] 250 | indexed_tokens = self.tokenizer.convert_tokens_to_ids(re_tokens) 251 | ``` 252 | -------------------------------------------------------------------------------- /benchmark/download_fewrel.sh: -------------------------------------------------------------------------------- 1 | mkdir fewrel 2 | wget -P fewrel https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/fewrel/fewrel_train.txt 3 | wget -P fewrel https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/fewrel/fewrel_train_rel2id.json 4 | wget -P fewrel https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/fewrel/fewrel_val.txt 5 | wget -P fewrel https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/fewrel/fewrel_val_rel2id.json 6 | -------------------------------------------------------------------------------- /benchmark/download_nyt10.sh: -------------------------------------------------------------------------------- 1 | mkdir nyt10 2 | wget -P nyt10 https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/nyt10/nyt10_rel2id.json 3 | wget -P nyt10 https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/nyt10/nyt10_train.txt 4 | wget -P nyt10 https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/nyt10/nyt10_test.txt 5 | -------------------------------------------------------------------------------- /benchmark/download_semeval.sh: -------------------------------------------------------------------------------- 1 | mkdir semeval 2 | wget -P semeval https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/semeval/semeval_rel2id.json 3 | wget -P semeval https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/semeval/semeval_train.txt 4 | wget -P semeval https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/semeval/semeval_val.txt 5 | wget -P semeval https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/semeval/semeval_test.txt 6 | -------------------------------------------------------------------------------- /benchmark/download_wiki80.sh: -------------------------------------------------------------------------------- 1 | mkdir wiki80 2 | wget -P wiki80 https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/wiki80/wiki80_rel2id.json 3 | wget -P wiki80 https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/wiki80/wiki80_train.txt 4 | wget -P wiki80 https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/wiki80/wiki80_val.txt 5 | -------------------------------------------------------------------------------- /example/gen_brand_data.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2021/3/19 10:58 上午 4 | # @File : gen_chinese_data.py 5 | # @Author: johnson 6 | # @Contact : github: johnson7788 7 | # @Desc : 根据label-studio标注的数据, 品牌和需求的关系判断 8 | """ 9 | { 10 | "id": 25, 11 | "data": { 12 | "brand": "修丽可age面霜", 13 | "channel": "redbook", 14 | "requirement": "保湿,抗老", 15 | "text": "我和大家一样都是普通上班族,每天面对屏幕的上班族的姐妹们皮肤总是难免的暗黄,搞得好像一下就老了好几岁的样子。工作 车子 房子各种的压力,每天的保养也不是总有时间和精力去做的。\n直到同事给我推荐了这款修丽可age面霜,每天晚上涂抹一点,半个月就感觉自己的肌肤不像平时那么松弛干燥了,效 果可以说是很不错了。\n正直秋冬,天干物燥,作为一款面霜,它的保湿效果还是很明 显的,细小的皱 纹也能有效的抚平,感觉用了之后整个人会明显觉得更 显年 轻,肤质回到了十八九的样子 甚至会恍然大悟原来之前脸干不单单是因为天气,还因为皮肤老化!因为现在整个底子好了,每天化妆都很服帖,不会再有各种卡粉啊暗沉之类的困扰了 \n真心推荐像我一样的上班族入手这款面霜,毕竟抗老真的很重要,钱要花在刀刃上嘛 " 16 | }, 17 | "completions": [ 18 | { 19 | "created_at": 1627890313, 20 | "id": 25001, 21 | "lead_time": 4.6, 22 | "result": [ 23 | { 24 | "from_name": "label", 25 | "id": "BJPV4AQ0ML", 26 | "to_name": "text", 27 | "type": "labels", 28 | "value": { 29 | "end": 107, 30 | "labels": [ 31 | "品牌" 32 | ], 33 | "start": 99, 34 | "text": "修丽可age面霜" 35 | } 36 | }, 37 | { 38 | "from_name": "label", 39 | "id": "9GU824R5UZ", 40 | "to_name": "text", 41 | "type": "labels", 42 | "value": { 43 | "end": 174, 44 | "labels": [ 45 | "需求" 46 | ], 47 | "start": 172, 48 | "text": "保湿" 49 | } 50 | }, 51 | { 52 | "from_name": "label", 53 | "id": "1G3PAO1G0C", 54 | "to_name": "text", 55 | "type": "labels", 56 | "value": { 57 | "end": 323, 58 | "labels": [ 59 | "需求" 60 | ], 61 | "start": 321, 62 | "text": "抗老" 63 | } 64 | }, 65 | { 66 | "direction": "right", 67 | "from_id": "BJPV4AQ0ML", 68 | "labels": [ 69 | "是" 70 | ], 71 | "to_id": "9GU824R5UZ", 72 | "type": "relation" 73 | } 74 | ] 75 | } 76 | ], 77 | "predictions": [] 78 | } 79 | """ 80 | import collections 81 | import os 82 | import json 83 | import re 84 | import random 85 | 86 | def gen_rel2id(destination='/Users/admin/git/OpenNRE/benchmark/brand/brand_rel2id.json'): 87 | """ 88 | 直接输出关系到id的映射 89 | :param destination: 输出的目标json文件 90 | :return: 91 | """ 92 | rel2id = {"否": 0, "是": 1} 93 | with open(destination, 'w', encoding='utf-8') as f: 94 | json.dump(rel2id, f) 95 | 96 | def gen_data(source_dir, des_dir): 97 | """ 98 | 根据原始目录生成目标训练或测试等文件 99 | :param source_dir: eg: 标注的原始数据目录 100 | :param des_dir: eg: /Users/admin/git/OpenNRE/benchmark/brand 101 | :return: 102 | """ 103 | #保存处理好的数据 104 | data = [] 105 | # 计数,空的result的数据的个数 106 | empty_result_num = 0 107 | #标签是空的数据条数 108 | empty_labels_num = 0 109 | #result不是空的数据的条数 110 | result_num = 0 111 | files = os.listdir(source_dir) 112 | # 过滤出标注的文件 113 | json_files = [f for f in files if f.endswith('.json')] 114 | labels_cnt = collections.Counter() 115 | for jfile in json_files: 116 | jfile_path = os.path.join(source_dir, jfile) 117 | with open(jfile_path, 'r') as f: 118 | json_data = json.load(f) 119 | for d in json_data: 120 | # 包含brand, channel,requirement,和text 121 | data_content = d['data'] 122 | completions = d['completions'] 123 | # 只选取第一个标注的数据 124 | result = completions[0]['result'] 125 | if not result: 126 | # result为空的,过滤掉 127 | empty_result_num += 1 128 | continue 129 | else: 130 | result_num += 1 131 | #解析result,标注数据,包含2种,一个是关键字是品牌或需求,另一个是品牌和需求的关系 132 | brand_requirements = [r for r in result if r.get('from_name')] 133 | # 变成id和其它属性的字典 134 | brand_requirements_id_dict = {br['id']:br['value'] for br in brand_requirements} 135 | relations = [r for r in result if r.get('direction')] 136 | for rel in relations: 137 | # 关系, 是或否 138 | 
# 如果labels为空,也跳过 139 | if not rel['labels']: 140 | empty_labels_num += 1 141 | continue 142 | relation = rel['labels'][0] 143 | # 每个关系生成一条数据 144 | text = data_content['text'] 145 | # 头部实体的名字 146 | h_id = rel['from_id'] 147 | # 获取头部id对应的名称和位置 148 | h_value = brand_requirements_id_dict[h_id] 149 | h_start = h_value['start'] 150 | h_end = h_value['end'] 151 | h_name = h_value['text'] 152 | # 尾部实体的id 153 | t_id = rel['to_id'] 154 | t_value = brand_requirements_id_dict[t_id] 155 | t_start = t_value['start'] 156 | t_end = t_value['end'] 157 | t_name = t_value['text'] 158 | #校验数据 159 | assert relation in ["是","否"] 160 | # 统计标签 161 | labels_cnt[relation] += 1 162 | one_data = { 163 | 'text': text, 164 | 'h': { 165 | 'name': h_name, 166 | 'id': h_id, 167 | 'pos': [h_start, h_end] 168 | }, 169 | 't': { 170 | 'name': t_name, 171 | 'id': t_id, 172 | 'pos': [t_start, t_end] 173 | }, 174 | 'relation': relation 175 | } 176 | data.append(one_data) 177 | print(f"共收集到总的数据条目: {len(data)}, 跳过的空的数据: {empty_result_num}, 非空reuslt的条数{result_num}, 标签为空的数据的条数{empty_labels_num},标签的个数统计为{labels_cnt}") 178 | train_file = os.path.join(des_dir, 'brand_train.txt') 179 | dev_file = os.path.join(des_dir, 'brand_test.txt') 180 | test_file = os.path.join(des_dir, 'brand_val.txt') 181 | random.seed(6) 182 | random.shuffle(data) 183 | train_num = int(len(data) * 0.8) 184 | dev_num = int(len(data) * 0.1) 185 | train_data = data[:train_num] 186 | dev_data = data[train_num:train_num+dev_num] 187 | test_data = data[train_num+dev_num:] 188 | with open(train_file, 'w', encoding='utf-8') as f: 189 | for d in train_data: 190 | f.write(json.dumps(d,ensure_ascii=False) + '\n') 191 | with open(dev_file, 'w', encoding='utf-8') as f: 192 | for d in dev_data: 193 | f.write(json.dumps(d,ensure_ascii=False)+ '\n') 194 | with open(test_file, 'w', encoding='utf-8') as f: 195 | for d in test_data: 196 | f.write(json.dumps(d,ensure_ascii=False)+ '\n') 197 | print(f"训练集数量{len(train_data)}, 测试集数量{len(test_data)},开发集数量{len(dev_data)}") 198 | 199 | if __name__ == '__main__': 200 | # gen_rel2id() 201 | gen_data(source_dir='/opt/lavector/relation/', des_dir='/Users/admin/git/OpenNRE/benchmark/brand/') -------------------------------------------------------------------------------- /example/gen_chinese_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2021/3/19 10:58 上午 4 | # @File : gen_chinese_data.py 5 | # @Author: johnson 6 | # @Contact : github: johnson7788 7 | # @Desc : 根据Chinese-Literature-NER-RE-Dataset提供的数据格式,生成我们需要的训练数据格式 8 | # 由于Chinese-Literature-NER-RE-Dataset是文档级的数据,所以其实需要更高效的训练和预测方法 9 | import os 10 | import json 11 | import re 12 | import random 13 | 14 | def gen_rel2id(train_dir, destination='/Users/admin/git/OpenNRE/benchmark/liter/liter_rel2id.json'): 15 | """ 16 | 根据Chinese-Literature-NER-RE-Dataset的训练目录生成关系到id的映射 17 | :param train_dir: *.ann和*.txt结尾的文件 18 | :param destination: 输出的目标json文件 19 | :return: 20 | """ 21 | relations = [] 22 | files = os.listdir(train_dir) 23 | #过滤出标注的文件 24 | files = [f for f in files if f.endswith('.ann')] 25 | for file in files: 26 | annfile = os.path.join(train_dir,file) 27 | with open(annfile, 'r') as f: 28 | for line in f: 29 | if line.startswith('R'): 30 | line = line.strip() 31 | line_split = re.split('[\t ]', line) 32 | relation = line_split[1] 33 | if relation == 'Coreference': 34 | print(f"文件{annfile},行 {line}是有问题的") 35 | if relation not in relations: 36 | print(f'加入关系: {relation}') 37 | 
relations.append(relation) 38 | desdir = os.path.dirname(destination) 39 | if not os.path.exists(desdir): 40 | os.makedirs(desdir) 41 | assert len(relations) == 9, "关系必须是9个才对" 42 | rel2id = {rel:idx for idx, rel in enumerate(relations)} 43 | with open(destination, 'w', encoding='utf-8') as f: 44 | json.dump(rel2id, f) 45 | 46 | def gen_data(source_dir, des_dir, mini_data = False, truncate=-1): 47 | """ 48 | 根据原始目录生成目标训练或测试等文件 49 | :param source_dir: eg: /Users/admin/git/Chinese-Literature-NER-RE-Dataset/relation_extraction/Training 50 | :param des_dir: eg: /Users/admin/git/OpenNRE/benchmark/liter 51 | :return: 52 | """ 53 | #保存处理好的数据 54 | data = [] 55 | files = os.listdir(source_dir) 56 | # 过滤出标注的文件 57 | ann_files = [f for f in files if f.endswith('.ann')] 58 | text_files = [f for f in files if f.endswith('.txt')] 59 | #转出成不带文件后缀的key和文件名为value的字典 60 | ann_file_dict = {f.split('.')[0]:f for f in ann_files} 61 | text_file_dict = {f.split('.')[0]: f for f in text_files} 62 | for k, v in ann_file_dict.items(): 63 | if text_file_dict.get(k) is None: 64 | print(f"文件{v} 不存在对应的txt文件,错误") 65 | continue 66 | #开始读取ann 文件 67 | annfile = os.path.join(source_dir, v) 68 | text_name = text_file_dict.get(k) 69 | textfile = os.path.join(source_dir, text_name) 70 | with open(textfile, 'r') as f: 71 | text = "" 72 | text_len = [] 73 | for line in f: 74 | text_len.append(len(line)) 75 | if len(line) == 61: 76 | #固定的行长度是61 77 | line = line.strip() 78 | text += line 79 | # text = f.read() 80 | #保存所有实体 81 | entities = [] 82 | #保存所有关系 83 | rels = [] 84 | with open(annfile, 'r') as f: 85 | for line in f: 86 | line = line.strip() 87 | if line.startswith('R'): 88 | line_split = re.split('[\t ]', line) 89 | assert len(line_split) == 4, f"关系{annfile}的行 {line}不为4项" 90 | rels.append(line_split) 91 | if line.startswith('T'): 92 | line_split = re.split('[\t ]', line) 93 | if len(line_split) == 7: 94 | # 如果不为5,那么是有逗号隔开的,例如 T81 Metric 539 540;541 542 百 鸟 95 | # 只需要T81 Metric 539 540 百 96 | pos_stop = line_split[3].split(';')[0] 97 | line_split = line_split[:3] + [pos_stop] + [line_split[5]] 98 | elif len(line_split) == 5: 99 | pass 100 | else: 101 | raise Exception(f"实体 {annfile} 的行 {line} 不为5项或者7项,有问题,请检查") 102 | #把实体的索引,进行减法,因为每61个字符一行,我们去掉了一部分'\n',所以做减法 103 | pos_start = int(line_split[2]) 104 | pos_stop = int(line_split[3]) 105 | if pos_start > 61: 106 | pos_remind1 = pos_start // 61 107 | pos_start = pos_start -pos_remind1 108 | if pos_stop > 61: 109 | pos_remind2 = pos_stop //61 110 | pos_stop = pos_stop - pos_remind2 111 | line_split = line_split[:2] + [pos_start, pos_stop] + [line_split[-1]] 112 | entities.append(line_split) 113 | #检查实体, 保存成实体id:实体的type,实体start_idx, 实体stop_idx,实体的值 114 | ent_dict = {} 115 | for entity in entities: 116 | entity_id = entity[0] 117 | if ent_dict.get(entity_id) is not None: 118 | print(f"{annfile}: 实体id已经存在过了,冲突的id,请检查 {entity}") 119 | ent_dict[entity_id] = entity[1:] 120 | 121 | #开始分析所有关系 122 | for rel in rels: 123 | relation = rel[1] 124 | arg1, h1_entityid = rel[2].split(':') 125 | assert arg1 == 'Arg1', f"{rel}分隔的首个字符不是Arg1" 126 | #实体1的id处理 127 | h1_entity = ent_dict.get(h1_entityid) 128 | if h1_entity is None: 129 | print(f"关系{rel}中对应的实体id{h1_entityid}是不存在的,请检查") 130 | h1_type,h1_pos_start, h1_pos_stop, h1_entity_value = h1_entity 131 | h1_pos_start = int(h1_pos_start) 132 | h1_pos_stop = int(h1_pos_stop) 133 | arg2, h2_entityid = rel[3].split(':') 134 | assert arg2 == 'Arg2', f"{rel}分隔的首个字符不是Arg2" 135 | #实体2的id处理 136 | h2_entity = ent_dict.get(h2_entityid) 137 | if h2_entity is 
None: 138 | print(f"关系{rel}中对应的实体id{h2_entityid}是不存在的,请检查") 139 | h2_type, h2_pos_start, h2_pos_stop, h2_entity_value = h2_entity 140 | h2_pos_start = int(h2_pos_start) 141 | h2_pos_stop = int(h2_pos_stop) 142 | # 检查关键字的位置是否匹配 143 | def get_true_pos(text, value, pos1, pos2, rnum=16): 144 | #从上下加8个字符获取真实的位置 145 | index_true_text = text[pos1-rnum:pos2+rnum] 146 | print(f"实体1: {value}位置不匹配, 上下的2个位置是: {index_true_text},尝试修复") 147 | newpos1, newpos2 = pos1, pos2 148 | if value in index_true_text: 149 | sres = re.finditer(re.escape(value), text) 150 | for sv in sres: 151 | if sv.start() > pos1-rnum and sv.end() < pos2+rnum: 152 | newpos1, newpos2 = sv.start(), sv.end() 153 | break 154 | else: 155 | print("通过正则没有匹配到,请检查,用最后一个位置作为索引") 156 | newpos1, newpos2 = sv.start(), sv.end() 157 | else: 158 | print("上下浮动了16个,仍然没有匹配,请检查") 159 | sres = re.finditer(re.escape(value), text) 160 | min_dist = 100 161 | for sv in sres: 162 | min_dist = min(min_dist, sv.start() - pos1, sv.end() - pos2) 163 | if min_dist in [sv.start() - pos1, sv.end() - pos2]: 164 | newpos1, newpos2 = sv.start(), sv.end() 165 | if text[newpos1:newpos2] != value: 166 | assert text[newpos1:newpos2] == value, "仍然是匹配错误的位置,请检查" 167 | return newpos1, newpos2 168 | # 验证下文本中的实体在文档中的位置时正确的 169 | if text[h1_pos_start:h1_pos_stop] != h1_entity_value: 170 | h1_pos_start, h1_pos_stop = get_true_pos(text=text,value=h1_entity_value, pos1=h1_pos_start, pos2=h1_pos_stop) 171 | if text[h2_pos_start:h2_pos_stop] != h2_entity_value: 172 | h2_pos_start, h2_pos_stop = get_true_pos(text=text,value=h2_entity_value, pos1=h2_pos_start, pos2=h2_pos_stop) 173 | 174 | if truncate != -1: 175 | if abs(h1_pos_start - h2_pos_stop) > truncate: 176 | print(f'2个实体间的距离很大,超过了{truncate}长度') 177 | else: 178 | #开始截断数据, 只保留最大长度 179 | add_length = truncate - abs(h1_pos_start - h2_pos_stop) 180 | added = int(add_length/2) 181 | if h1_pos_start < h2_pos_stop: 182 | truncate_start = h1_pos_start - added 183 | truncate_end = h2_pos_stop + added 184 | else: 185 | truncate_start = h2_pos_stop - added 186 | truncate_end = h1_pos_start + added 187 | if truncate_start <0: 188 | truncate_start = 0 189 | truncate_text = text[truncate_start:truncate_end] 190 | else: 191 | truncate_text = text 192 | # 开始整理成一条数据 193 | one_data = { 194 | 'text': truncate_text, 195 | 'h': { 196 | 'name': h1_entity_value, 197 | 'id': h1_entityid, 198 | 'pos': [h1_pos_start, h1_pos_stop] 199 | }, 200 | 't': { 201 | 'name': h2_entity_value, 202 | 'id': h2_entityid, 203 | 'pos': [h2_pos_start, h2_pos_stop] 204 | }, 205 | 'relation': relation 206 | } 207 | 208 | data.append(one_data) 209 | train_file = os.path.join(des_dir, 'liter_train.txt') 210 | dev_file = os.path.join(des_dir, 'liter_test.txt') 211 | test_file = os.path.join(des_dir, 'liter_val.txt') 212 | print(f"一共处理了{len(ann_files)}个文件,生成{len(data)}条数据") 213 | random.shuffle(data) 214 | train_num = int(len(data) * 0.8) 215 | dev_num = int(len(data) * 0.1) 216 | train_data = data[:train_num] 217 | dev_data = data[train_num:train_num+dev_num] 218 | test_data = data[train_num+dev_num:] 219 | if mini_data: 220 | #选择前500条样本测试 221 | train_data = train_data[:500] 222 | dev_data = dev_data[:100] 223 | test_data = test_data[:100] 224 | with open(train_file, 'w', encoding='utf-8') as f: 225 | for d in train_data: 226 | f.write(json.dumps(d) + '\n') 227 | with open(dev_file, 'w', encoding='utf-8') as f: 228 | for d in dev_data: 229 | f.write(json.dumps(d)+ '\n') 230 | with open(test_file, 'w', encoding='utf-8') as f: 231 | for d in test_data: 232 | 
f.write(json.dumps(d)+ '\n') 233 | print(f"训练集数量{len(train_data)}, 测试集数量{len(test_data)},开发集数量{len(dev_data)}") 234 | 235 | if __name__ == '__main__': 236 | # gen_rel2id(train_dir='/Users/admin/git/Chinese-Literature-NER-RE-Dataset/relation_extraction/Training') 237 | gen_data(source_dir='/Users/admin/git/Chinese-Literature-NER-RE-Dataset/relation_extraction/Training', des_dir='/Users/admin/git/OpenNRE/benchmark/liter', mini_data=False, truncate=196) -------------------------------------------------------------------------------- /example/infer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Date : 2021/3/15 10:40 上午 4 | # @File : infer.py 5 | # @Author: johnson 6 | # @Contact : github: johnson7788 7 | # @Desc : 推理模型 8 | 9 | import opennre 10 | 11 | def infer_wiki80_cnn_softmax(): 12 | model = opennre.get_model('wiki80_cnn_softmax') 13 | result = model.infer({ 14 | 'text': 'He was the son of Máel Dúin mac Máele Fithrich, and grandson of the high king Áed Uaridnach (died 612).', 15 | 'h': {'pos': (18, 46)}, 't': {'pos': (78, 91)}}) 16 | print(result) 17 | 18 | 19 | def infer_wiki80_bert_softmax(): 20 | """ 21 | 有一些错误 22 | :return: 23 | """ 24 | model = opennre.get_model('wiki80_bert_softmax') 25 | result = model.infer({ 26 | 'text': 'He was the son of Máel Dúin mac Máele Fithrich, and grandson of the high king Áed Uaridnach (died 612).', 27 | 'h': {'pos': (18, 46)}, 't': {'pos': (78, 91)}}) 28 | print(result) 29 | 30 | 31 | def infer_wiki80_bertentity_softmax(): 32 | model = opennre.get_model('wiki80_bertentity_softmax') 33 | result = model.infer({ 34 | 'text': 'He was the son of Máel Dúin mac Máele Fithrich, and grandson of the high king Áed Uaridnach (died 612).', 35 | 'h': {'pos': (18, 46)}, 't': {'pos': (78, 91)}}) 36 | print(result) 37 | 38 | 39 | def infer_tacred_bertentity_softmax(): 40 | model = opennre.get_model('tacred_bertentity_softmax') 41 | result = model.infer({ 42 | 'text': 'He was the son of Máel Dúin mac Máele Fithrich, and grandson of the high king Áed Uaridnach (died 612).', 43 | 'h': {'pos': (18, 46)}, 't': {'pos': (78, 91)}}) 44 | print(result) 45 | 46 | def infer_tacred_bert_softmax(): 47 | model = opennre.get_model('tacred_bert_softmax') 48 | result = model.infer({ 49 | 'text': 'He was the son of Máel Dúin mac Máele Fithrich, and grandson of the high king Áed Uaridnach (died 612).', 50 | 'h': {'pos': (18, 46)}, 't': {'pos': (78, 91)}}) 51 | print(result) 52 | 53 | if __name__ == '__main__': 54 | infer_wiki80_bert_softmax() 55 | # infer_tacred_bertentity_softmax() 56 | # infer_tacred_bert_softmax() -------------------------------------------------------------------------------- /example/train_bag_pcnn_att.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import sys, json 3 | import torch 4 | import os 5 | import numpy as np 6 | import opennre 7 | from opennre import encoder, model, framework 8 | import sys 9 | import os 10 | import argparse 11 | import logging 12 | import random 13 | 14 | def set_seed(seed): 15 | random.seed(seed) 16 | np.random.seed(seed) 17 | torch.manual_seed(seed) 18 | torch.cuda.manual_seed_all(seed) 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--ckpt', default='', 22 | help='Checkpoint name') 23 | parser.add_argument('--only_test', action='store_true', 24 | help='Only run test') 25 | 26 | # Data 27 | parser.add_argument('--metric', default='auc', choices=['micro_f1', 
'auc'], 28 | help='Metric for picking up best checkpoint') 29 | parser.add_argument('--dataset', default='none', choices=['none', 'wiki_distant', 'nyt10'], 30 | help='Dataset. If not none, the following args can be ignored') 31 | parser.add_argument('--train_file', default='', type=str, 32 | help='Training data file') 33 | parser.add_argument('--val_file', default='', type=str, 34 | help='Validation data file') 35 | parser.add_argument('--test_file', default='', type=str, 36 | help='Test data file') 37 | parser.add_argument('--rel2id_file', default='', type=str, 38 | help='Relation to ID file') 39 | 40 | # Bag related 41 | parser.add_argument('--bag_size', type=int, default=0, 42 | help='Fixed bag size. If set to 0, use original bag sizes') 43 | 44 | # Hyper-parameters 45 | parser.add_argument('--batch_size', default=160, type=int, 46 | help='Batch size') 47 | parser.add_argument('--lr', default=0.1, type=float, 48 | help='Learning rate') 49 | parser.add_argument('--optim', default='sgd', type=str, 50 | help='Optimizer') 51 | parser.add_argument('--weight_decay', default=1e-5, type=float, 52 | help='Weight decay') 53 | parser.add_argument('--max_length', default=120, type=int, 54 | help='Maximum sentence length') 55 | parser.add_argument('--max_epoch', default=100, type=int, 56 | help='Max number of training epochs') 57 | 58 | # Others 59 | parser.add_argument('--seed', default=42, type=int, 60 | help='Random seed') 61 | 62 | args = parser.parse_args() 63 | 64 | # Set random seed 65 | set_seed(args.seed) 66 | 67 | # Some basic settings 68 | root_path = '.' 69 | sys.path.append(root_path) 70 | if not os.path.exists('ckpt'): 71 | os.mkdir('ckpt') 72 | if len(args.ckpt) == 0: 73 | args.ckpt = '{}_{}'.format(args.dataset, 'pcnn_att') 74 | ckpt = 'ckpt/{}.pth.tar'.format(args.ckpt) 75 | 76 | if args.dataset != 'none': 77 | opennre.download(args.dataset, root_path=root_path) 78 | args.train_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_train.txt'.format(args.dataset)) 79 | args.val_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_val.txt'.format(args.dataset)) 80 | if not os.path.exists(args.val_file): 81 | logging.info("Cannot find the validation file. Use the test file instead.") 82 | args.val_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_test.txt'.format(args.dataset)) 83 | args.test_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_test.txt'.format(args.dataset)) 84 | args.rel2id_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_rel2id.json'.format(args.dataset)) 85 | else: 86 | if not (os.path.exists(args.train_file) and os.path.exists(args.val_file) and os.path.exists(args.test_file) and os.path.exists(args.rel2id_file)): 87 | raise Exception('--train_file, --val_file, --test_file and --rel2id_file are not specified or files do not exist. 
Or specify --dataset') 88 | 89 | logging.info('Arguments:') 90 | for arg in vars(args): 91 | logging.info(' {}: {}'.format(arg, getattr(args, arg))) 92 | 93 | rel2id = json.load(open(args.rel2id_file)) 94 | 95 | # Download glove 96 | opennre.download('glove', root_path=root_path) 97 | word2id = json.load(open(os.path.join(root_path, 'pretrain/glove/glove.6B.50d_word2id.json'))) 98 | word2vec = np.load(os.path.join(root_path, 'pretrain/glove/glove.6B.50d_mat.npy')) 99 | 100 | # Define the sentence encoder 101 | sentence_encoder = opennre.encoder.PCNNEncoder( 102 | token2id=word2id, 103 | max_length=args.max_length, 104 | word_size=50, 105 | position_size=5, 106 | hidden_size=230, 107 | blank_padding=True, 108 | kernel_size=3, 109 | padding_size=1, 110 | word2vec=word2vec, 111 | dropout=0.5 112 | ) 113 | 114 | # Define the model 115 | model = opennre.model.BagAttention(sentence_encoder, len(rel2id), rel2id) 116 | 117 | # Define the whole training framework 118 | framework = opennre.framework.BagRE( 119 | train_path=args.train_file, 120 | val_path=args.val_file, 121 | test_path=args.test_file, 122 | model=model, 123 | ckpt=ckpt, 124 | batch_size=args.batch_size, 125 | max_epoch=args.max_epoch, 126 | lr=args.lr, 127 | weight_decay=args.weight_decay, 128 | opt=args.optim, 129 | bag_size=args.bag_size) 130 | 131 | # Train the model 132 | if not args.only_test: 133 | framework.train_model(args.metric) 134 | 135 | # Test the model 136 | framework.load_state_dict(torch.load(ckpt)['state_dict']) 137 | result = framework.eval_model(framework.test_loader) 138 | 139 | # Print the result 140 | logging.info('Test set results:') 141 | logging.info('AUC: {}'.format(result['auc'])) 142 | logging.info('Micro F1: {}'.format(result['micro_f1'])) 143 | -------------------------------------------------------------------------------- /example/train_supervised_bert.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import torch 3 | import numpy as np 4 | import json 5 | import opennre 6 | from opennre import encoder, model 7 | import sys 8 | import os 9 | import argparse 10 | import logging 11 | 12 | def doargs(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--pretrain_path', default='pretrain/bert_model', 15 | help='预训练的模型的名字,默认英文的忽略大小写的bert, 或者给定模型下载好的路径,从路径中加载') 16 | parser.add_argument('--ckpt', default='', 17 | help='Checkpoint 位置') 18 | parser.add_argument('--pooler', default='entity', choices=['cls', 'entity'], 19 | help='句子表达,使用cls还是实体的表达') 20 | parser.add_argument('--do_train', action='store_true',help='训练模型') 21 | parser.add_argument('--do_test', action='store_true',help='测试模型') 22 | parser.add_argument('--mask_entity', action='store_true', 23 | help='是否mask实体提及') 24 | 25 | # Data 26 | parser.add_argument('--metric', default='micro_f1', choices=['micro_f1', 'acc'], 27 | help='选择best checkpoint时使用哪个 Metric') 28 | parser.add_argument('--dataset', default='none', choices=['none', 'semeval', 'wiki80', 'tacred', 'liter','brand'], 29 | help='Dataset. 
If not none, the train/val/test/rel2id file args below are filled in automatically and can be ignored; otherwise specify each file explicitly') 30 | parser.add_argument('--train_file', default='', type=str, 31 | help='Training data file') 32 | parser.add_argument('--val_file', default='', type=str, 33 | help='Validation data file') 34 | parser.add_argument('--test_file', default='', type=str, 35 | help='Test data file') 36 | parser.add_argument('--rel2id_file', default='', type=str, 37 | help='Relation-to-ID mapping file') 38 | 39 | # Hyper-parameters 40 | parser.add_argument('--batch_size', default=16, type=int, 41 | help='Batch size') 42 | parser.add_argument('--lr', default=2e-5, type=float, 43 | help='Learning rate') 44 | parser.add_argument('--max_length', default=128, type=int, 45 | help='Maximum sequence length') 46 | parser.add_argument('--max_epoch', default=3, type=int, 47 | help='Maximum number of training epochs') 48 | 49 | args = parser.parse_args() 50 | return args 51 | 52 | def load_dataset_and_framework(): 53 | # Some basic settings 54 | root_path = '.' 55 | sys.path.append(root_path) 56 | if not os.path.exists('ckpt'): 57 | os.mkdir('ckpt') 58 | if len(args.ckpt) == 0: 59 | args.ckpt = '{}_{}_{}'.format(args.dataset, os.path.basename(args.pretrain_path), args.pooler) # basename keeps path separators out of the checkpoint file name 60 | ckpt = 'ckpt/{}.pth.tar'.format(args.ckpt) 61 | 62 | if args.dataset != 'none': 63 | opennre.download(args.dataset, root_path=root_path) 64 | args.train_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_train.txt'.format(args.dataset)) 65 | args.val_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_val.txt'.format(args.dataset)) 66 | args.test_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_test.txt'.format(args.dataset)) 67 | if not os.path.exists(args.test_file): 68 | logging.warning("Test file {} does not exist! Use val file instead".format(args.test_file)) 69 | args.test_file = args.val_file 70 | args.rel2id_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_rel2id.json'.format(args.dataset)) 71 | if args.dataset == 'wiki80': 72 | args.metric = 'acc' 73 | else: 74 | args.metric = 'micro_f1' 75 | else: 76 | if not (os.path.exists(args.train_file) and os.path.exists(args.val_file) and os.path.exists( 77 | args.test_file) and os.path.exists(args.rel2id_file)): 78 | raise Exception('--train_file, --val_file, --test_file and --rel2id_file 没有指定或文件不存在,请检查. 
或者指定 --dataset') 79 | 80 | logging.info('参数:') 81 | for arg in vars(args): 82 | logging.info(' {}: {}'.format(arg, getattr(args, arg))) 83 | 84 | rel2id = json.load(open(args.rel2id_file)) 85 | 86 | # Define the sentence encoder 87 | if args.pooler == 'entity': 88 | sentence_encoder = opennre.encoder.BERTEntityEncoder( 89 | max_length=args.max_length, 90 | pretrain_path=args.pretrain_path, 91 | mask_entity=args.mask_entity 92 | ) 93 | elif args.pooler == 'cls': 94 | sentence_encoder = opennre.encoder.BERTEncoder( 95 | max_length=args.max_length, 96 | pretrain_path=args.pretrain_path, 97 | mask_entity=args.mask_entity 98 | ) 99 | else: 100 | raise NotImplementedError 101 | 102 | # 初始化softmax模型 103 | model = opennre.model.SoftmaxNN(sentence_encoder, num_class=len(rel2id), rel2id=rel2id) 104 | 105 | # Define the whole training framework 106 | myframe = opennre.framework.SentenceRE( 107 | train_path=args.train_file, 108 | val_path=args.val_file, 109 | test_path=args.test_file, 110 | model=model, 111 | ckpt=ckpt, 112 | batch_size=args.batch_size, 113 | max_epoch=args.max_epoch, 114 | lr=args.lr, 115 | opt='adamw', 116 | parallel=False, # 是否GPU并发 117 | num_workers=0 # Dataloader的进程数,使用并发时使用 118 | ) 119 | return myframe, ckpt 120 | 121 | def dotrain(): 122 | #训练模型 123 | myframe.train_model('micro_f1') 124 | 125 | def dotest(ckpt): 126 | #加载训练好的模型,开始测试 127 | myframe.load_state_dict(torch.load(ckpt)['state_dict']) 128 | result = myframe.eval_model(myframe.test_loader) 129 | 130 | #打印结果 131 | logging.info('Test set results:') 132 | logging.info('Accuracy: {}'.format(result['acc'])) 133 | logging.info('Micro precision: {}'.format(result['micro_p'])) 134 | logging.info('Micro recall: {}'.format(result['micro_r'])) 135 | logging.info('Micro F1: {}'.format(result['micro_f1'])) 136 | 137 | if __name__ == '__main__': 138 | args = doargs() 139 | logfile = "train.log" 140 | logging.basicConfig( 141 | level=logging.DEBUG, 142 | format="%(asctime)s - [%(levelname)s] - %(module)s - %(message)s", 143 | handlers=[ 144 | logging.StreamHandler(), 145 | logging.FileHandler(logfile, mode='w', encoding='utf-8'), 146 | ] 147 | ) 148 | myframe, ckpt = load_dataset_and_framework() 149 | if args.do_train: 150 | dotrain() 151 | if args.do_test: 152 | dotest(ckpt) -------------------------------------------------------------------------------- /example/train_supervised_cnn.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import torch 3 | import numpy as np 4 | import json 5 | import opennre 6 | from opennre import encoder, model, framework 7 | import sys 8 | import os 9 | import argparse 10 | import logging 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--ckpt', default='', 14 | help='Checkpoint name') 15 | parser.add_argument('--only_test', action='store_true', 16 | help='Only run test') 17 | 18 | # Data 19 | parser.add_argument('--metric', default='micro_f1', choices=['micro_f1', 'acc'], 20 | help='选择best checkpoint时使用哪个 Metric') 21 | parser.add_argument('--dataset', default='wiki80', choices=['none', 'semeval', 'wiki80', 'tacred', 'liter'], 22 | help='Dataset. 
If not none, the train/val/test/rel2id file args below are filled in automatically and can be ignored; otherwise specify each file explicitly') 23 | parser.add_argument('--train_file', default='', type=str, 24 | help='Training data file') 25 | parser.add_argument('--val_file', default='', type=str, 26 | help='Validation data file') 27 | parser.add_argument('--test_file', default='', type=str, 28 | help='Test data file') 29 | parser.add_argument('--rel2id_file', default='', type=str, 30 | help='Relation-to-ID mapping file') 31 | 32 | parser.add_argument('--batch_size', default=16, type=int, 33 | help='Batch size') 34 | parser.add_argument('--lr', default=2e-5, type=float, 35 | help='Learning rate') 36 | parser.add_argument('--max_length', default=128, type=int, 37 | help='Maximum sequence length') 38 | parser.add_argument('--max_epoch', default=3, type=int, 39 | help='Maximum number of training epochs') 40 | parser.add_argument('--weight_decay', default=1e-5, type=float, 41 | help='Weight decay') 42 | 43 | args = parser.parse_args() 44 | 45 | # Some basic settings 46 | root_path = '.' 47 | sys.path.append(root_path) 48 | if not os.path.exists('ckpt'): 49 | os.mkdir('ckpt') 50 | if len(args.ckpt) == 0: 51 | args.ckpt = '{}_{}'.format(args.dataset, 'cnn') 52 | ckpt = 'ckpt/{}.pth.tar'.format(args.ckpt) 53 | 54 | if args.dataset != 'none': 55 | opennre.download(args.dataset, root_path=root_path) 56 | args.train_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_train.txt'.format(args.dataset)) 57 | args.val_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_val.txt'.format(args.dataset)) 58 | args.test_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_test.txt'.format(args.dataset)) 59 | args.rel2id_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_rel2id.json'.format(args.dataset)) 60 | if args.dataset == 'wiki80': 61 | args.metric = 'acc' 62 | else: 63 | args.metric = 'micro_f1' 64 | else: 65 | if not (os.path.exists(args.train_file) and os.path.exists(args.val_file) and os.path.exists(args.test_file) and os.path.exists(args.rel2id_file)): 66 | raise Exception('--train_file, --val_file, --test_file and --rel2id_file are not specified or files do not exist. 
Or specify --dataset') 67 | 68 | logging.info('Arguments:') 69 | for arg in vars(args): 70 | logging.info(' {}: {}'.format(arg, getattr(args, arg))) 71 | 72 | rel2id = json.load(open(args.rel2id_file)) 73 | 74 | # Download glove 75 | opennre.download('glove', root_path=root_path) 76 | word2id = json.load(open(os.path.join(root_path, 'pretrain/glove/glove.6B.50d_word2id.json'))) 77 | word2vec = np.load(os.path.join(root_path, 'pretrain/glove/glove.6B.50d_mat.npy')) 78 | 79 | # Define the sentence encoder 80 | sentence_encoder = opennre.encoder.CNNEncoder( 81 | token2id=word2id, 82 | max_length=args.max_length, 83 | word_size=50, 84 | position_size=5, 85 | hidden_size=230, 86 | blank_padding=True, 87 | kernel_size=3, 88 | padding_size=1, 89 | word2vec=word2vec, 90 | dropout=0.5 91 | ) 92 | 93 | 94 | # Define the model 95 | model = opennre.model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id) 96 | 97 | # Define the whole training framework 98 | framework = opennre.framework.SentenceRE( 99 | train_path=args.train_file, 100 | val_path=args.val_file, 101 | test_path=args.test_file, 102 | model=model, 103 | ckpt=ckpt, 104 | batch_size=args.batch_size, 105 | max_epoch=args.max_epoch, 106 | lr=args.lr, 107 | weight_decay=args.weight_decay, 108 | opt='sgd' 109 | ) 110 | 111 | # Train the model 112 | if not args.only_test: 113 | framework.train_model(args.metric) 114 | 115 | # Test 116 | framework.load_state_dict(torch.load(ckpt)['state_dict']) 117 | result = framework.eval_model(framework.test_loader) 118 | 119 | # Print the result 120 | logging.info('Test set results:') 121 | if args.metric == 'acc': 122 | logging.info('Accuracy: {}'.format(result['acc'])) 123 | else: 124 | logging.info('Micro precision: {}'.format(result['micro_p'])) 125 | logging.info('Micro recall: {}'.format(result['micro_r'])) 126 | logging.info('Micro F1: {}'.format(result['micro_f1'])) 127 | -------------------------------------------------------------------------------- /opennre/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .pretrain import check_root, get_model, download, download_pretrain 6 | import logging 7 | import os 8 | 9 | logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=os.environ.get("LOGLEVEL", "INFO")) 10 | 11 | def fix_seed(seed=12345): 12 | import torch 13 | import numpy as np 14 | import random 15 | torch.manual_seed(seed) # cpu 16 | torch.cuda.manual_seed(seed) # gpu 17 | np.random.seed(seed) # numpy 18 | random.seed(seed) # random and transforms 19 | torch.backends.cudnn.deterministic=True # cudnn 20 | -------------------------------------------------------------------------------- /opennre/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .cnn_encoder import CNNEncoder 6 | from .pcnn_encoder import PCNNEncoder 7 | from .bert_encoder import BERTEncoder, BERTEntityEncoder 8 | 9 | __all__ = [ 10 | 'CNNEncoder', 11 | 'PCNNEncoder', 12 | 'BERTEncoder', 13 | 'BERTEntityEncoder' 14 | ] -------------------------------------------------------------------------------- /opennre/encoder/base_encoder.py: -------------------------------------------------------------------------------- 1 | import math, logging 2 | 
import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from ..tokenization import WordTokenizer 7 | 8 | class BaseEncoder(nn.Module): 9 | 10 | def __init__(self, 11 | token2id, 12 | max_length=128, 13 | hidden_size=230, 14 | word_size=50, 15 | position_size=5, 16 | blank_padding=True, 17 | word2vec=None, 18 | mask_entity=False): 19 | """ 20 | Args: 21 | token2id: dictionary of token->idx mapping 22 | max_length: max length of sentence, used for postion embedding 23 | hidden_size: hidden size 24 | word_size: size of word embedding 25 | position_size: size of position embedding 26 | blank_padding: padding for CNN 27 | word2vec: pretrained word2vec numpy 28 | """ 29 | # Hyperparameters 30 | super().__init__() 31 | 32 | self.token2id = token2id 33 | self.max_length = max_length 34 | self.num_token = len(token2id) 35 | self.num_position = max_length * 2 36 | self.mask_entity = mask_entity 37 | 38 | if word2vec is None: 39 | self.word_size = word_size 40 | else: 41 | self.word_size = word2vec.shape[-1] 42 | 43 | self.position_size = position_size 44 | self.hidden_size = hidden_size 45 | self.input_size = word_size + position_size * 2 46 | self.blank_padding = blank_padding 47 | 48 | if not '[UNK]' in self.token2id: 49 | self.token2id['[UNK]'] = len(self.token2id) 50 | self.num_token += 1 51 | if not '[PAD]' in self.token2id: 52 | self.token2id['[PAD]'] = len(self.token2id) 53 | self.num_token += 1 54 | 55 | # Word embedding 56 | self.word_embedding = nn.Embedding(self.num_token, self.word_size) 57 | if word2vec is not None: 58 | logging.info("Initializing word embedding with word2vec.") 59 | word2vec = torch.from_numpy(word2vec) 60 | if self.num_token == len(word2vec) + 2: 61 | unk = torch.randn(1, self.word_size) / math.sqrt(self.word_size) 62 | blk = torch.zeros(1, self.word_size) 63 | self.word_embedding.weight.data.copy_(torch.cat([word2vec, unk, blk], 0)) 64 | else: 65 | self.word_embedding.weight.data.copy_(word2vec) 66 | 67 | # Position Embedding 68 | self.pos1_embedding = nn.Embedding(2 * max_length, self.position_size, padding_idx=0) 69 | self.pos2_embedding = nn.Embedding(2 * max_length, self.position_size, padding_idx=0) 70 | self.tokenizer = WordTokenizer(vocab=self.token2id, unk_token="[UNK]") 71 | 72 | def forward(self, token, pos1, pos2): 73 | """ 74 | Args: 75 | token: (B, L), index of tokens 76 | pos1: (B, L), relative position to head entity 77 | pos2: (B, L), relative position to tail entity 78 | Return: 79 | (B, H), representations for sentences 80 | """ 81 | # Check size of tensors 82 | pass 83 | 84 | def tokenize(self, item): 85 | """ 86 | Args: 87 | item: input instance, including sentence, entity positions, etc. 
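          Expected item format (inferred from the indexing code below; a pre-tokenized instance uses 'token' instead of 'text'):
          {'text': str or 'token': list of str, 'h': {'pos': [start, end]}, 't': {'pos': [start, end]}}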
88 | Return: 89 | index number of tokens and positions 90 | """ 91 | if 'text' in item: 92 | sentence = item['text'] 93 | is_token = False 94 | else: 95 | sentence = item['token'] 96 | is_token = True 97 | pos_head = item['h']['pos'] 98 | pos_tail = item['t']['pos'] 99 | 100 | # Sentence -> token 101 | if not is_token: 102 | if pos_head[0] > pos_tail[0]: 103 | pos_min, pos_max = [pos_tail, pos_head] 104 | rev = True 105 | else: 106 | pos_min, pos_max = [pos_head, pos_tail] 107 | rev = False 108 | sent_0 = self.tokenizer.tokenize(sentence[:pos_min[0]]) 109 | sent_1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]]) 110 | sent_2 = self.tokenizer.tokenize(sentence[pos_max[1]:]) 111 | ent_0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]]) 112 | ent_1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]]) 113 | if self.mask_entity: 114 | ent_0 = ['[UNK]'] 115 | ent_1 = ['[UNK]'] 116 | tokens = sent_0 + ent_0 + sent_1 + ent_1 + sent_2 117 | if rev: 118 | pos_tail = [len(sent_0), len(sent_0) + len(ent_0)] 119 | pos_head = [len(sent_0) + len(ent_0) + len(sent_1), len(sent_0) + len(ent_0) + len(sent_1) + len(ent_1)] 120 | else: 121 | pos_head = [len(sent_0), len(sent_0) + len(ent_0)] 122 | pos_tail = [len(sent_0) + len(ent_0) + len(sent_1), len(sent_0) + len(ent_0) + len(sent_1) + len(ent_1)] 123 | else: 124 | tokens = sentence 125 | 126 | # Token -> index 127 | if self.blank_padding: 128 | indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens, self.max_length, self.token2id['[PAD]'], self.token2id['[UNK]']) 129 | else: 130 | indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens, unk_id = self.token2id['[UNK]']) 131 | 132 | # Position -> index 133 | pos1 = [] 134 | pos2 = [] 135 | pos1_in_index = min(pos_head[0], self.max_length) 136 | pos2_in_index = min(pos_tail[0], self.max_length) 137 | for i in range(len(tokens)): 138 | pos1.append(min(i - pos1_in_index + self.max_length, 2 * self.max_length - 1)) 139 | pos2.append(min(i - pos2_in_index + self.max_length, 2 * self.max_length - 1)) 140 | 141 | if self.blank_padding: 142 | while len(pos1) < self.max_length: 143 | pos1.append(0) 144 | while len(pos2) < self.max_length: 145 | pos2.append(0) 146 | indexed_tokens = indexed_tokens[:self.max_length] 147 | pos1 = pos1[:self.max_length] 148 | pos2 = pos2[:self.max_length] 149 | 150 | indexed_tokens = torch.tensor(indexed_tokens).long().unsqueeze(0) # (1, L) 151 | pos1 = torch.tensor(pos1).long().unsqueeze(0) # (1, L) 152 | pos2 = torch.tensor(pos2).long().unsqueeze(0) # (1, L) 153 | 154 | return indexed_tokens, pos1, pos2 155 | -------------------------------------------------------------------------------- /opennre/encoder/bert_encoder.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import torch.nn as nn 4 | from transformers import BertModel, BertTokenizer 5 | from .base_encoder import BaseEncoder 6 | 7 | class BERTEncoder(nn.Module): 8 | def __init__(self, max_length, pretrain_path, blank_padding=True, mask_entity=False): 9 | """ 10 | Args: 11 | max_length: max length of sentence 12 | pretrain_path: path of pretrain model 13 | """ 14 | super().__init__() 15 | self.max_length = max_length 16 | self.blank_padding = blank_padding 17 | self.hidden_size = 768 18 | self.mask_entity = mask_entity 19 | logging.info(f'加载预训练的 BERT pre-trained checkpoint: {pretrain_path}') 20 | self.bert = BertModel.from_pretrained(pretrain_path) 21 | self.tokenizer = BertTokenizer.from_pretrained(pretrain_path) 22 | 
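    # Usage sketch (illustrative, added comment): tokenize() turns one data instance into tensors and
    # forward() encodes them into a (B, 768) sentence representation, e.g.
    #   enc = BERTEncoder(max_length=128, pretrain_path='pretrain/bert_model')  # path as used in example/train_supervised_bert.py
    #   item = {'text': sentence, 'h': {'pos': (h_start, h_end)}, 't': {'pos': (t_start, t_end)}}  # character offsets
    #   token, att_mask = enc.tokenize(item)
    #   rep = enc(token, att_mask)  # torch.Size([1, 768])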
23 | def forward(self, token, att_mask): 24 | """ 25 | Args: 26 | token: (B, L), index of tokens 27 | att_mask: (B, L), attention mask (1 for contents and 0 for padding) 28 | Return: 29 | (B, H), representations for sentences 30 | """ 31 | _, x = self.bert(token, attention_mask=att_mask, return_dict=False) 32 | return x 33 | 34 | def tokenize(self, item): 35 | """ 36 | Args: 37 | item: data instance containing 'text' / 'token', 'h' and 't' 38 | Return: 39 | Name of the relation of the sentence 40 | """ 41 | # Sentence -> token 42 | if 'text' in item: 43 | sentence = item['text'] 44 | is_token = False 45 | else: 46 | sentence = item['token'] 47 | is_token = True 48 | pos_head = item['h']['pos'] 49 | pos_tail = item['t']['pos'] 50 | 51 | pos_min = pos_head 52 | pos_max = pos_tail 53 | if pos_head[0] > pos_tail[0]: 54 | pos_min = pos_tail 55 | pos_max = pos_head 56 | rev = True 57 | else: 58 | rev = False 59 | 60 | if not is_token: 61 | sent0 = self.tokenizer.tokenize(sentence[:pos_min[0]]) 62 | ent0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]]) 63 | sent1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]]) 64 | ent1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]]) 65 | sent2 = self.tokenizer.tokenize(sentence[pos_max[1]:]) 66 | else: 67 | sent0 = self.tokenizer.tokenize(' '.join(sentence[:pos_min[0]])) 68 | ent0 = self.tokenizer.tokenize(' '.join(sentence[pos_min[0]:pos_min[1]])) 69 | sent1 = self.tokenizer.tokenize(' '.join(sentence[pos_min[1]:pos_max[0]])) 70 | ent1 = self.tokenizer.tokenize(' '.join(sentence[pos_max[0]:pos_max[1]])) 71 | sent2 = self.tokenizer.tokenize(' '.join(sentence[pos_max[1]:])) 72 | 73 | if self.mask_entity: 74 | ent0 = ['[unused4]'] if not rev else ['[unused5]'] 75 | ent1 = ['[unused5]'] if not rev else ['[unused4]'] 76 | else: 77 | ent0 = ['[unused0]'] + ent0 + ['[unused1]'] if not rev else ['[unused2]'] + ent0 + ['[unused3]'] 78 | ent1 = ['[unused2]'] + ent1 + ['[unused3]'] if not rev else ['[unused0]'] + ent1 + ['[unused1]'] 79 | 80 | re_tokens = ['[CLS]'] + sent0 + ent0 + sent1 + ent1 + sent2 + ['[SEP]'] 81 | 82 | indexed_tokens = self.tokenizer.convert_tokens_to_ids(re_tokens) 83 | avai_len = len(indexed_tokens) 84 | 85 | # Padding 86 | if self.blank_padding: 87 | while len(indexed_tokens) < self.max_length: 88 | indexed_tokens.append(0) # 0 is id for [PAD] 89 | indexed_tokens = indexed_tokens[:self.max_length] 90 | indexed_tokens = torch.tensor(indexed_tokens).long().unsqueeze(0) # (1, L) 91 | 92 | # Attention mask 93 | att_mask = torch.zeros(indexed_tokens.size()).long() # (1, L) 94 | att_mask[0, :avai_len] = 1 95 | 96 | return indexed_tokens, att_mask 97 | 98 | 99 | class BERTEntityEncoder(nn.Module): 100 | def __init__(self, max_length, pretrain_path, blank_padding=True, mask_entity=False): 101 | """ 102 | 加载huggface的预训练模型,初始化tokenizer 103 | Args: 104 | max_length: 最大序列长度 105 | pretrain_path: 预训练模型的路径 106 | blank_padding: bool 107 | mask_entity: bool, 是否mask掉实体,如果mask掉实体,那么实体的名称就会用特殊的字符表示,但是模型性能会下降 108 | """ 109 | super().__init__() 110 | self.max_length = max_length 111 | self.blank_padding = blank_padding 112 | self.hidden_size = 768 * 2 113 | self.mask_entity = mask_entity 114 | logging.info(f'加载预训练的 BERT pre-trained checkpoint: {pretrain_path}') 115 | self.bert = BertModel.from_pretrained(pretrain_path) 116 | self.tokenizer = BertTokenizer.from_pretrained(pretrain_path) 117 | self.linear = nn.Linear(self.hidden_size, self.hidden_size) 118 | 119 | def forward(self, token, att_mask, pos1, pos2): 120 | """ 121 | Args: 122 | 
token: (Batch_size, seq_length), index of tokens, 句子和实体拼接后的token 123 | att_mask: (Batch_size, seq_length), attention mask (1 for contents and 0 for padding) 124 | pos1: (Batch_size, 1), position of the head entity starter, [batch_size, 1] 实体1的开始位置 125 | pos2: (Batch_size, 1), position of the tail entity starter, [batch_size, 1] 实体2的结束位置 126 | Return: 127 | (B, 2H), representations for sentences 128 | """ 129 | # hidden [batch_size, seq_len, output_demision], 先经过bert 130 | hidden, _ = self.bert(token, attention_mask=att_mask, return_dict=False,) 131 | # 初始化一个向量 onehot_head shape [batch_size, seq_len] 132 | onehot_head = torch.zeros(hidden.size()[:2]).float().to(hidden.device) # (B, L) 133 | onehot_tail = torch.zeros(hidden.size()[:2]).float().to(hidden.device) # (B, L) 134 | #获取实体位置的向量, 135 | onehot_head = onehot_head.scatter_(1, pos1, 1) 136 | onehot_tail = onehot_tail.scatter_(1, pos2, 1) 137 | #head_hidden,tail_hidden --> [batch_size, hidden_demision], [16,768] 138 | head_hidden = (onehot_head.unsqueeze(2) * hidden).sum(1) # (B, H) 139 | tail_hidden = (onehot_tail.unsqueeze(2) * hidden).sum(1) # (B, H) 140 | #把实体的头和尾拼接起来, x shape (Batch_size, 2*hidden_demision), 放入线性模型 141 | x = torch.cat([head_hidden, tail_hidden], 1) 142 | x = self.linear(x) 143 | return x 144 | 145 | def tokenize(self, item): 146 | """ 147 | Args: 148 | item: 一条数据,包括 'text' 或 'token', 'h' and 't', eg: {'token': ['It', 'then', 'enters', 'the', 'spectacular', 'Clydach', 'Gorge', ',', 'dropping', 'about', 'to', 'Gilwern', 'and', 'its', 'confluence', 'with', 'the', 'River', 'Usk', 'Ordnance', 'Survey', 'Explorer', 'map', 'OL13', ',', '"', 'Brecon', 'Beacons', 'National', 'Park', ':', 'eastern', 'area', '"', '.'], 'h': {'name': 'gilwern', 'id': 'Q5562649', 'pos': [11, 12]}, 't': {'name': 'river usk', 'id': 'Q19699', 'pos': [17, 19]}, 'relation': 'located in or next to body of water'} 149 | Return: 150 | indexed_tokens,整理好的id, att_mask对应的mask, pos1实体1的起始位置, pos2实体2的结束位置 151 | """ 152 | # Sentence -> token, text表示文本没有进行token,是一个完整的句子,token表示文本token过了,是一个列表了 153 | if 'text' in item: 154 | sentence = item['text'] 155 | is_token = False 156 | else: 157 | sentence = item['token'] 158 | is_token = True 159 | # pos_head 第一个实体的位置 eg: [11, 12], pos_tail第二个实体的位置 160 | pos_head = item['h']['pos'] 161 | pos_tail = item['t']['pos'] 162 | pos_min = pos_head 163 | pos_max = pos_tail 164 | #确定哪个实体在句子的前面和后面 165 | if pos_head[0] > pos_tail[0]: 166 | pos_min = pos_tail 167 | pos_max = pos_head 168 | rev = True 169 | else: 170 | rev = False 171 | 172 | if not is_token: 173 | sent0 = self.tokenizer.tokenize(sentence[:pos_min[0]]) 174 | ent0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]]) 175 | sent1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]]) 176 | ent1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]]) 177 | sent2 = self.tokenizer.tokenize(sentence[pos_max[1]:]) 178 | else: 179 | #sent0 是句子到第一个实体的单词的位置, eg: ['it', 'then', 'enters', 'the', 'spectacular', 'cl', '##yd', '##ach', 'gorge', ',', 'dropping', 'about', 'to'] 180 | sent0 = self.tokenizer.tokenize(' '.join(sentence[:pos_min[0]])) 181 | # ent0实体的tokenizer, eg: ['gil', '##wer', '##n'] 182 | ent0 = self.tokenizer.tokenize(' '.join(sentence[pos_min[0]:pos_min[1]])) 183 | # sent1是第一个实体到第二个实体之间的句子, eg: ['and', 'its', 'confluence', 'with', 'the'] 184 | sent1 = self.tokenizer.tokenize(' '.join(sentence[pos_min[1]:pos_max[0]])) 185 | #是第二个实体, eg: ['river', 'us', '##k'] 186 | ent1 = self.tokenizer.tokenize(' '.join(sentence[pos_max[0]:pos_max[1]])) 187 | 
#第二个实体到句子末尾 eg: ['ordnance', 'survey', 'explorer', 'map', 'ol', '##13', ',', '"', 'br', '##ec', '##on', 'beacon', '##s', 'national', 'park', ':', 'eastern', 'area', '"', '.'] 188 | sent2 = self.tokenizer.tokenize(' '.join(sentence[pos_max[1]:])) 189 | 190 | if self.mask_entity: 191 | ent0 = ['[unused4]'] if not rev else ['[unused5]'] 192 | ent1 = ['[unused5]'] if not rev else ['[unused4]'] 193 | else: 194 | # 不mask实体,用特殊的mask围住实体 195 | ent0 = ['[unused0]'] + ent0 + ['[unused1]'] if not rev else ['[unused2]'] + ent0 + ['[unused3]'] 196 | ent1 = ['[unused2]'] + ent1 + ['[unused3]'] if not rev else ['[unused0]'] + ent1 + ['[unused1]'] 197 | #在句子开头和末尾加上special token 198 | re_tokens = ['[CLS]'] + sent0 + ent0 + sent1 + ent1 + sent2 + ['[SEP]'] 199 | # 实体的位置也改变了 200 | pos1 = 1 + len(sent0) if not rev else 1 + len(sent0 + ent0 + sent1) 201 | pos2 = 1 + len(sent0 + ent0 + sent1) if not rev else 1 + len(sent0) 202 | pos1 = min(self.max_length - 1, pos1) 203 | pos2 = min(self.max_length - 1, pos2) 204 | #把token转换成id 205 | indexed_tokens = self.tokenizer.convert_tokens_to_ids(re_tokens) 206 | avai_len = len(indexed_tokens) 207 | 208 | # Position, pos1 eg: tensor([[2]]) 209 | pos1 = torch.tensor([[pos1]]).long() 210 | pos2 = torch.tensor([[pos2]]).long() 211 | 212 | # Padding, 如果少于长度,开始padding,多余长度,截断 213 | if self.blank_padding: 214 | while len(indexed_tokens) < self.max_length: 215 | indexed_tokens.append(0) # 0 is id for [PAD] 216 | indexed_tokens = indexed_tokens[:self.max_length] 217 | indexed_tokens = torch.tensor(indexed_tokens).long().unsqueeze(0) # shape,[1, max_seq_length] torch.Size([1, 128]) 218 | 219 | # Attention mask, 只有真实的长度为1,其它的地方为0 220 | att_mask = torch.zeros(indexed_tokens.size()).long() # (1, L) 221 | att_mask[0, :avai_len] = 1 222 | 223 | return indexed_tokens, att_mask, pos1, pos2 224 | -------------------------------------------------------------------------------- /opennre/encoder/cnn_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ..module.nn import CNN 6 | from ..module.pool import MaxPool 7 | from .base_encoder import BaseEncoder 8 | 9 | class CNNEncoder(BaseEncoder): 10 | 11 | def __init__(self, 12 | token2id, 13 | max_length=128, 14 | hidden_size=230, 15 | word_size=50, 16 | position_size=5, 17 | blank_padding=True, 18 | word2vec=None, 19 | kernel_size=3, 20 | padding_size=1, 21 | dropout=0, 22 | activation_function=F.relu, 23 | mask_entity=False): 24 | """ 25 | Args: 26 | token2id: dictionary of token->idx mapping 27 | max_length: max length of sentence, used for postion embedding 28 | hidden_size: hidden size 29 | word_size: size of word embedding 30 | position_size: size of position embedding 31 | blank_padding: padding for CNN 32 | word2vec: pretrained word2vec numpy 33 | kernel_size: kernel_size size for CNN 34 | padding_size: padding_size for CNN 35 | """ 36 | # Hyperparameters 37 | super(CNNEncoder, self).__init__(token2id, max_length, hidden_size, word_size, position_size, blank_padding, word2vec, mask_entity=mask_entity) 38 | self.drop = nn.Dropout(dropout) 39 | self.kernel_size = kernel_size 40 | self.padding_size = padding_size 41 | self.act = activation_function 42 | 43 | self.conv = nn.Conv1d(self.input_size, self.hidden_size, self.kernel_size, padding=self.padding_size) 44 | self.pool = nn.MaxPool1d(self.max_length) 45 | 46 | def forward(self, token, pos1, pos2): 47 | """ 48 | Args: 49 | token: (B, L), index of tokens 50 | 
pos1: (B, L), relative position to head entity 51 | pos2: (B, L), relative position to tail entity 52 | Return: 53 | (B, EMBED), representations for sentences 54 | """ 55 | # Check size of tensors 56 | if len(token.size()) != 2 or token.size() != pos1.size() or token.size() != pos2.size(): 57 | raise Exception("Size of token, pos1 ans pos2 should be (B, L)") 58 | x = torch.cat([self.word_embedding(token), 59 | self.pos1_embedding(pos1), 60 | self.pos2_embedding(pos2)], 2) # (B, L, EMBED) 61 | x = x.transpose(1, 2) # (B, EMBED, L) 62 | x = self.act(self.conv(x)) # (B, H, L) 63 | x = self.pool(x).squeeze(-1) 64 | x = self.drop(x) 65 | return x 66 | 67 | def tokenize(self, item): 68 | return super().tokenize(item) 69 | -------------------------------------------------------------------------------- /opennre/encoder/pcnn_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ..module.nn import CNN 6 | from ..module.pool import MaxPool 7 | from .base_encoder import BaseEncoder 8 | 9 | from nltk import word_tokenize 10 | 11 | class PCNNEncoder(BaseEncoder): 12 | 13 | def __init__(self, 14 | token2id, 15 | max_length=128, 16 | hidden_size=230, 17 | word_size=50, 18 | position_size=5, 19 | blank_padding=True, 20 | word2vec=None, 21 | kernel_size=3, 22 | padding_size=1, 23 | dropout=0.0, 24 | activation_function=F.relu, 25 | mask_entity=False): 26 | """ 27 | Args: 28 | token2id: dictionary of token->idx mapping 29 | max_length: max length of sentence, used for postion embedding 30 | hidden_size: hidden size 31 | word_size: size of word embedding 32 | position_size: size of position embedding 33 | blank_padding: padding for CNN 34 | word2vec: pretrained word2vec numpy 35 | kernel_size: kernel_size size for CNN 36 | padding_size: padding_size for CNN 37 | """ 38 | # hyperparameters 39 | super().__init__(token2id, max_length, hidden_size, word_size, position_size, blank_padding, word2vec, mask_entity=mask_entity) 40 | self.drop = nn.Dropout(dropout) 41 | self.kernel_size = kernel_size 42 | self.padding_size = padding_size 43 | self.act = activation_function 44 | 45 | self.conv = nn.Conv1d(self.input_size, self.hidden_size, self.kernel_size, padding=self.padding_size) 46 | self.pool = nn.MaxPool1d(self.max_length) 47 | self.mask_embedding = nn.Embedding(4, 3) 48 | self.mask_embedding.weight.data.copy_(torch.FloatTensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])) 49 | self.mask_embedding.weight.requires_grad = False 50 | self._minus = -100 51 | 52 | self.hidden_size *= 3 53 | 54 | def forward(self, token, pos1, pos2, mask): 55 | """ 56 | Args: 57 | token: (B, L), index of tokens 58 | pos1: (B, L), relative position to head entity 59 | pos2: (B, L), relative position to tail entity 60 | Return: 61 | (B, EMBED), representations for sentences 62 | """ 63 | # Check size of tensors 64 | if len(token.size()) != 2 or token.size() != pos1.size() or token.size() != pos2.size(): 65 | raise Exception("Size of token, pos1 ans pos2 should be (B, L)") 66 | x = torch.cat([self.word_embedding(token), 67 | self.pos1_embedding(pos1), 68 | self.pos2_embedding(pos2)], 2) # (B, L, EMBED) 69 | x = x.transpose(1, 2) # (B, EMBED, L) 70 | x = self.conv(x) # (B, H, L) 71 | 72 | mask = 1 - self.mask_embedding(mask).transpose(1, 2) # (B, L) -> (B, L, 3) -> (B, 3, L) 73 | pool1 = self.pool(self.act(x + self._minus * mask[:, 0:1, :])) # (B, H, 1) 74 | pool2 = self.pool(self.act(x + self._minus * 
mask[:, 1:2, :])) 75 | pool3 = self.pool(self.act(x + self._minus * mask[:, 2:3, :])) 76 | x = torch.cat([pool1, pool2, pool3], 1) # (B, 3H, 1) 77 | x = x.squeeze(2) # (B, 3H) 78 | x = self.drop(x) 79 | 80 | return x 81 | 82 | def tokenize(self, item): 83 | """ 84 | Args: 85 | sentence: string, the input sentence 86 | pos_head: [start, end], position of the head entity 87 | pos_end: [start, end], position of the tail entity 88 | is_token: if is_token == True, sentence becomes an array of token 89 | Return: 90 | Name of the relation of the sentence 91 | """ 92 | if 'text' in item: 93 | sentence = item['text'] 94 | is_token = False 95 | else: 96 | sentence = item['token'] 97 | is_token = True 98 | pos_head = item['h']['pos'] 99 | pos_tail = item['t']['pos'] 100 | 101 | # Sentence -> token 102 | if not is_token: 103 | if pos_head[0] > pos_tail[0]: 104 | pos_min, pos_max = [pos_tail, pos_head] 105 | rev = True 106 | else: 107 | pos_min, pos_max = [pos_head, pos_tail] 108 | rev = False 109 | sent_0 = self.tokenizer.tokenize(sentence[:pos_min[0]]) 110 | sent_1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]]) 111 | sent_2 = self.tokenizer.tokenize(sentence[pos_max[1]:]) 112 | ent_0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]]) 113 | ent_1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]]) 114 | if self.mask_entity: 115 | ent_0 = ['[UNK]'] 116 | ent_1 = ['[UNK]'] 117 | tokens = sent_0 + ent_0 + sent_1 + ent_1 + sent_2 118 | if rev: 119 | pos_tail = [len(sent_0), len(sent_0) + len(ent_0)] 120 | pos_head = [len(sent_0) + len(ent_0) + len(sent_1), len(sent_0) + len(ent_0) + len(sent_1) + len(ent_1)] 121 | else: 122 | pos_head = [len(sent_0), len(sent_0) + len(ent_0)] 123 | pos_tail = [len(sent_0) + len(ent_0) + len(sent_1), len(sent_0) + len(ent_0) + len(sent_1) + len(ent_1)] 124 | else: 125 | tokens = sentence 126 | 127 | # Token -> index 128 | if self.blank_padding: 129 | indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens, self.max_length, self.token2id['[PAD]'], self.token2id['[UNK]']) 130 | else: 131 | indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens, unk_id = self.token2id['[UNK]']) 132 | 133 | # Position -> index 134 | pos1 = [] 135 | pos2 = [] 136 | pos1_in_index = min(pos_head[0], self.max_length) 137 | pos2_in_index = min(pos_tail[0], self.max_length) 138 | for i in range(len(tokens)): 139 | pos1.append(min(i - pos1_in_index + self.max_length, 2 * self.max_length - 1)) 140 | pos2.append(min(i - pos2_in_index + self.max_length, 2 * self.max_length - 1)) 141 | 142 | if self.blank_padding: 143 | while len(pos1) < self.max_length: 144 | pos1.append(0) 145 | while len(pos2) < self.max_length: 146 | pos2.append(0) 147 | indexed_tokens = indexed_tokens[:self.max_length] 148 | pos1 = pos1[:self.max_length] 149 | pos2 = pos2[:self.max_length] 150 | 151 | indexed_tokens = torch.tensor(indexed_tokens).long().unsqueeze(0) # (1, L) 152 | pos1 = torch.tensor(pos1).long().unsqueeze(0) # (1, L) 153 | pos2 = torch.tensor(pos2).long().unsqueeze(0) # (1, L) 154 | 155 | # Mask 156 | mask = [] 157 | pos_min = min(pos1_in_index, pos2_in_index) 158 | pos_max = max(pos1_in_index, pos2_in_index) 159 | for i in range(len(tokens)): 160 | if i <= pos_min: 161 | mask.append(1) 162 | elif i <= pos_max: 163 | mask.append(2) 164 | else: 165 | mask.append(3) 166 | # Padding 167 | if self.blank_padding: 168 | while len(mask) < self.max_length: 169 | mask.append(0) 170 | mask = mask[:self.max_length] 171 | 172 | mask = torch.tensor(mask).long().unsqueeze(0) # (1, L) 173 | 
return indexed_tokens, pos1, pos2, mask 174 | -------------------------------------------------------------------------------- /opennre/framework/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .data_loader import SentenceREDataset, SentenceRELoader, BagREDataset, BagRELoader 6 | from .sentence_re import SentenceRE 7 | from .bag_re import BagRE 8 | 9 | __all__ = [ 10 | 'SentenceREDataset', 11 | 'SentenceRELoader', 12 | 'SentenceRE', 13 | 'BagRE', 14 | 'BagREDataset', 15 | 'BagRELoader' 16 | ] 17 | -------------------------------------------------------------------------------- /opennre/framework/bag_re.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, optim 3 | import json 4 | from .data_loader import SentenceRELoader, BagRELoader 5 | from .utils import AverageMeter 6 | from tqdm import tqdm 7 | import os 8 | 9 | class BagRE(nn.Module): 10 | 11 | def __init__(self, 12 | model, 13 | train_path, 14 | val_path, 15 | test_path, 16 | ckpt, 17 | batch_size=32, 18 | max_epoch=100, 19 | lr=0.1, 20 | weight_decay=1e-5, 21 | opt='sgd', 22 | bag_size=0, 23 | loss_weight=False): 24 | 25 | super().__init__() 26 | self.max_epoch = max_epoch 27 | self.bag_size = bag_size 28 | # Load data 29 | if train_path != None: 30 | self.train_loader = BagRELoader( 31 | train_path, 32 | model.rel2id, 33 | model.sentence_encoder.tokenize, 34 | batch_size, 35 | True, 36 | bag_size=bag_size, 37 | entpair_as_bag=False) 38 | 39 | if val_path != None: 40 | self.val_loader = BagRELoader( 41 | val_path, 42 | model.rel2id, 43 | model.sentence_encoder.tokenize, 44 | batch_size, 45 | False, 46 | bag_size=bag_size, 47 | entpair_as_bag=True) 48 | 49 | if test_path != None: 50 | self.test_loader = BagRELoader( 51 | test_path, 52 | model.rel2id, 53 | model.sentence_encoder.tokenize, 54 | batch_size, 55 | False, 56 | bag_size=bag_size, 57 | entpair_as_bag=True 58 | ) 59 | # Model 60 | self.model = nn.DataParallel(model) 61 | # Criterion 62 | if loss_weight: 63 | self.criterion = nn.CrossEntropyLoss(weight=self.train_loader.dataset.weight) 64 | else: 65 | self.criterion = nn.CrossEntropyLoss() 66 | # Params and optimizer 67 | params = self.model.parameters() 68 | self.lr = lr 69 | if opt == 'sgd': 70 | self.optimizer = optim.SGD(params, lr, weight_decay=weight_decay) 71 | elif opt == 'adam': 72 | self.optimizer = optim.Adam(params, lr, weight_decay=weight_decay) 73 | elif opt == 'adamw': 74 | from transformers import AdamW 75 | params = list(self.named_parameters()) 76 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 77 | grouped_params = [ 78 | { 79 | 'params': [p for n, p in params if not any(nd in n for nd in no_decay)], 80 | 'weight_decay': 0.01, 81 | 'lr': lr, 82 | 'ori_lr': lr 83 | }, 84 | { 85 | 'params': [p for n, p in params if any(nd in n for nd in no_decay)], 86 | 'weight_decay': 0.0, 87 | 'lr': lr, 88 | 'ori_lr': lr 89 | } 90 | ] 91 | self.optimizer = AdamW(grouped_params, correct_bias=False) 92 | else: 93 | raise Exception("Invalid optimizer. 
Must be 'sgd' or 'adam' or 'bert_adam'.") 94 | # Cuda 95 | if torch.cuda.is_available(): 96 | self.cuda() 97 | # Ckpt 98 | self.ckpt = ckpt 99 | 100 | def train_model(self, metric='auc'): 101 | best_metric = 0 102 | for epoch in range(self.max_epoch): 103 | # Train 104 | self.train() 105 | print("=== Epoch %d train ===" % epoch) 106 | avg_loss = AverageMeter() 107 | avg_acc = AverageMeter() 108 | avg_pos_acc = AverageMeter() 109 | t = tqdm(self.train_loader) 110 | for iter, data in enumerate(t): 111 | if torch.cuda.is_available(): 112 | for i in range(len(data)): 113 | try: 114 | data[i] = data[i].cuda() 115 | except: 116 | pass 117 | label = data[0] 118 | bag_name = data[1] 119 | scope = data[2] 120 | args = data[3:] 121 | logits = self.model(label, scope, *args, bag_size=self.bag_size) 122 | loss = self.criterion(logits, label) 123 | score, pred = logits.max(-1) # (B) 124 | acc = float((pred == label).long().sum()) / label.size(0) 125 | pos_total = (label != 0).long().sum() 126 | pos_correct = ((pred == label).long() * (label != 0).long()).sum() 127 | if pos_total > 0: 128 | pos_acc = float(pos_correct) / float(pos_total) 129 | else: 130 | pos_acc = 0 131 | 132 | # Log 133 | avg_loss.update(loss.item(), 1) 134 | avg_acc.update(acc, 1) 135 | avg_pos_acc.update(pos_acc, 1) 136 | t.set_postfix(loss=avg_loss.avg, acc=avg_acc.avg, pos_acc=avg_pos_acc.avg) 137 | 138 | # Optimize 139 | loss.backward() 140 | self.optimizer.step() 141 | self.optimizer.zero_grad() 142 | 143 | # Val 144 | print("=== Epoch %d val ===" % epoch) 145 | result = self.eval_model(self.val_loader) 146 | print("AUC: %.4f" % result['auc']) 147 | print("Micro F1: %.4f" % (result['micro_f1'])) 148 | if result[metric] > best_metric: 149 | print("Best ckpt and saved.") 150 | torch.save({'state_dict': self.model.module.state_dict()}, self.ckpt) 151 | best_metric = result[metric] 152 | print("Best %s on val set: %f" % (metric, best_metric)) 153 | 154 | def eval_model(self, eval_loader): 155 | self.model.eval() 156 | with torch.no_grad(): 157 | t = tqdm(eval_loader) 158 | pred_result = [] 159 | for iter, data in enumerate(t): 160 | if torch.cuda.is_available(): 161 | for i in range(len(data)): 162 | try: 163 | data[i] = data[i].cuda() 164 | except: 165 | pass 166 | label = data[0] 167 | bag_name = data[1] 168 | scope = data[2] 169 | args = data[3:] 170 | logits = self.model(None, scope, *args, train=False, bag_size=self.bag_size) # results after softmax 171 | logits = logits.cpu().numpy() 172 | for i in range(len(logits)): 173 | for relid in range(self.model.module.num_class): 174 | if self.model.module.id2rel[relid] != 'NA': 175 | pred_result.append({ 176 | 'entpair': bag_name[i][:2], 177 | 'relation': self.model.module.id2rel[relid], 178 | 'score': logits[i][relid] 179 | }) 180 | result = eval_loader.dataset.eval(pred_result) 181 | return result 182 | 183 | def load_state_dict(self, state_dict): 184 | self.model.module.load_state_dict(state_dict) 185 | -------------------------------------------------------------------------------- /opennre/framework/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import os, random, json, logging 4 | import numpy as np 5 | import sklearn.metrics 6 | 7 | class SentenceREDataset(data.Dataset): 8 | """ 9 | Sentence-level relation extraction dataset 10 | """ 11 | def __init__(self, path, rel2id, tokenizer, kwargs): 12 | """ 13 | Args: 14 | path: 数据原始文件 15 | rel2id: dictionary of relation->id 
mapping, 关系到id的映射字典 16 | tokenizer: function of tokenizing,初始化的tokenizer 17 | kwargs: tokenizer的其它参数 18 | """ 19 | super().__init__() 20 | self.path = path 21 | self.tokenizer = tokenizer 22 | self.rel2id = rel2id 23 | self.kwargs = kwargs 24 | 25 | # Load the file 26 | f = open(path) 27 | self.data = [] 28 | for line in f.readlines(): 29 | line = line.rstrip() 30 | if len(line) > 0: 31 | self.data.append(eval(line)) 32 | f.close() 33 | logging.info("加载 RE 数据集 {} with {} 行和{} 个关系.".format(path, len(self.data), len(self.rel2id))) 34 | 35 | def __len__(self): 36 | return len(self.data) 37 | 38 | def __getitem__(self, index): 39 | """ 40 | 获取一条数据,数据经过tokenizer后的, 会连续获取一个batch_size的数据 41 | :param index: 47393 42 | :return: 包含5个元素的列表, 关系的id, 句子id, att_mask, 实体1的起始位置tensor, 实体2的结束位置 43 | """ 44 | item = self.data[index] 45 | seq = list(self.tokenizer(item, **self.kwargs)) 46 | # [self.rel2id[item['relation']]] 代表关系的id 47 | res = [self.rel2id[item['relation']]] + seq 48 | return res # label, seq1, seq2, ... 49 | 50 | def collate_fn(data): 51 | """ 52 | 对一个batch的数据进行处理,里面是一个列表,是batch_size大小,是上面__getitem__返回的每一条数据拼接成的一个batch_size大小的列表,这里每个原素是包含5个元素 53 | 经过list(zip(*data))处理后每个元素的第一维度是batch_size 54 | :return: 55 | """ 56 | data = list(zip(*data)) 57 | labels = data[0] 58 | # seqs是4种特征,句子id, att_mask, 实体1的起始位置tensor, 实体2的结束位置 59 | seqs = data[1:] 60 | batch_labels = torch.tensor(labels).long() # (B) 61 | batch_seqs = [] 62 | for seq in seqs: 63 | # 把每个特征的batch_size 拼接起来 [1,128] --> [batch_size, max_seq_length] , 然后放到一个列表中 64 | batch_seqs.append(torch.cat(seq, 0)) # (B, L) 65 | #返回包含一个列表的含4个元素,每个元素都是tensor 66 | return [batch_labels] + batch_seqs 67 | 68 | def eval(self, pred_result, use_name=False): 69 | """ 70 | Args: 71 | pred_result: 预测标签(id)的列表,在生成dataloader时确保`shuffle`参数设置为`False`。 72 | use_name: 如果True,`pred_result`包含预测的关系名,而不是id。 73 | Return: 74 | {'acc': acc, 'micro_p': micro_p, 'micro_r': micro_r, 'micro_f1': micro_f1} 75 | """ 76 | correct = 0 77 | total = len(self.data) 78 | correct_positive = 0 79 | pred_positive = 0 80 | gold_positive = 0 81 | neg = -1 82 | for name in ['NA', 'na', 'no_relation', 'Other', 'Others']: 83 | if name in self.rel2id: 84 | if use_name: 85 | neg = name 86 | else: 87 | neg = self.rel2id[name] 88 | break 89 | for i in range(total): 90 | if use_name: 91 | golden = self.data[i]['relation'] 92 | else: 93 | golden = self.rel2id[self.data[i]['relation']] 94 | if golden == pred_result[i]: 95 | correct += 1 96 | if golden != neg: 97 | correct_positive += 1 98 | if golden != neg: 99 | gold_positive +=1 100 | if pred_result[i] != neg: 101 | pred_positive += 1 102 | acc = float(correct) / float(total) 103 | try: 104 | micro_p = float(correct_positive) / float(pred_positive) 105 | except: 106 | micro_p = 0 107 | try: 108 | micro_r = float(correct_positive) / float(gold_positive) 109 | except: 110 | micro_r = 0 111 | try: 112 | micro_f1 = 2 * micro_p * micro_r / (micro_p + micro_r) 113 | except: 114 | micro_f1 = 0 115 | result = {'acc': acc, 'micro_p': micro_p, 'micro_r': micro_r, 'micro_f1': micro_f1} 116 | logging.info('评估结果 : {}.'.format(result)) 117 | return result 118 | 119 | def SentenceRELoader(path, rel2id, tokenizer, batch_size, 120 | shuffle, num_workers=8, collate_fn=SentenceREDataset.collate_fn, **kwargs): 121 | """ 122 | 加载数据,返回Dataloader 123 | :param path: 数据集文件 eg: './benchmark/wiki80/wiki80_train.txt' 124 | :param rel2id: 关系到id的映射字典 125 | :param tokenizer: 初始化的tokenizer 126 | :param batch_size: eg: 16 127 | :param shuffle: bool 128 | :param num_workers: 使用的进程数 
129 | :param collate_fn: 数据处理函数 130 | :param kwargs: tokenizer的其它参数 131 | :return: 132 | """ 133 | dataset = SentenceREDataset(path = path, rel2id = rel2id, tokenizer = tokenizer, kwargs=kwargs) 134 | data_loader = data.DataLoader(dataset=dataset, 135 | batch_size=batch_size, 136 | shuffle=shuffle, 137 | pin_memory=True, 138 | num_workers=num_workers, 139 | collate_fn=collate_fn) 140 | return data_loader 141 | 142 | class BagREDataset(data.Dataset): 143 | """ 144 | Bag-level relation extraction dataset. Note that relation of NA should be named as 'NA'. 145 | """ 146 | def __init__(self, path, rel2id, tokenizer, entpair_as_bag=False, bag_size=0, mode=None): 147 | """ 148 | Args: 149 | path: path of the input file 150 | rel2id: dictionary of relation->id mapping 151 | tokenizer: function of tokenizing 152 | entpair_as_bag: if True, bags are constructed based on same 153 | entity pairs instead of same relation facts (ignoring 154 | relation labels) 155 | """ 156 | super().__init__() 157 | self.tokenizer = tokenizer 158 | self.rel2id = rel2id 159 | self.entpair_as_bag = entpair_as_bag 160 | self.bag_size = bag_size 161 | 162 | # Load the file 163 | f = open(path) 164 | self.data = [] 165 | for line in f: 166 | line = line.rstrip() 167 | if len(line) > 0: 168 | self.data.append(eval(line)) 169 | f.close() 170 | 171 | # Construct bag-level dataset (a bag contains instances sharing the same relation fact) 172 | if mode == None: 173 | self.weight = np.ones((len(self.rel2id)), dtype=np.float32) 174 | self.bag_scope = [] 175 | self.name2id = {} 176 | self.bag_name = [] 177 | self.facts = {} 178 | for idx, item in enumerate(self.data): 179 | fact = (item['h']['id'], item['t']['id'], item['relation']) 180 | if item['relation'] != 'NA': 181 | self.facts[fact] = 1 182 | if entpair_as_bag: 183 | name = (item['h']['id'], item['t']['id']) 184 | else: 185 | name = fact 186 | if name not in self.name2id: 187 | self.name2id[name] = len(self.name2id) 188 | self.bag_scope.append([]) 189 | self.bag_name.append(name) 190 | self.bag_scope[self.name2id[name]].append(idx) 191 | self.weight[self.rel2id[item['relation']]] += 1.0 192 | self.weight = 1.0 / (self.weight ** 0.05) 193 | self.weight = torch.from_numpy(self.weight) 194 | else: 195 | pass 196 | 197 | def __len__(self): 198 | return len(self.bag_scope) 199 | 200 | def __getitem__(self, index): 201 | bag = self.bag_scope[index] 202 | if self.bag_size > 0: 203 | if self.bag_size <= len(bag): 204 | resize_bag = random.sample(bag, self.bag_size) 205 | else: 206 | resize_bag = bag + list(np.random.choice(bag, self.bag_size - len(bag))) 207 | bag = resize_bag 208 | 209 | seqs = None 210 | rel = self.rel2id[self.data[bag[0]]['relation']] 211 | for sent_id in bag: 212 | item = self.data[sent_id] 213 | seq = list(self.tokenizer(item)) 214 | if seqs is None: 215 | seqs = [] 216 | for i in range(len(seq)): 217 | seqs.append([]) 218 | for i in range(len(seq)): 219 | seqs[i].append(seq[i]) 220 | for i in range(len(seqs)): 221 | seqs[i] = torch.cat(seqs[i], 0) # (n, L), n is the size of bag 222 | return [rel, self.bag_name[index], len(bag)] + seqs 223 | 224 | def collate_fn(data): 225 | data = list(zip(*data)) 226 | label, bag_name, count = data[:3] 227 | seqs = data[3:] 228 | for i in range(len(seqs)): 229 | seqs[i] = torch.cat(seqs[i], 0) # (sumn, L) 230 | seqs[i] = seqs[i].expand((torch.cuda.device_count() if torch.cuda.device_count() > 0 else 1, ) + seqs[i].size()) 231 | scope = [] # (B, 2) 232 | start = 0 233 | for c in count: 234 | scope.append((start, start + c)) 235 
| start += c 236 | assert(start == seqs[0].size(1)) 237 | scope = torch.tensor(scope).long() 238 | label = torch.tensor(label).long() # (B) 239 | return [label, bag_name, scope] + seqs 240 | 241 | def collate_bag_size_fn(data): 242 | data = list(zip(*data)) 243 | label, bag_name, count = data[:3] 244 | seqs = data[3:] 245 | for i in range(len(seqs)): 246 | seqs[i] = torch.stack(seqs[i], 0) # (batch, bag, L) 247 | scope = [] # (B, 2) 248 | start = 0 249 | for c in count: 250 | scope.append((start, start + c)) 251 | start += c 252 | label = torch.tensor(label).long() # (B) 253 | return [label, bag_name, scope] + seqs 254 | 255 | 256 | def eval(self, pred_result): 257 | """ 258 | Args: 259 | pred_result: a list with dict {'entpair': (head_id, tail_id), 'relation': rel, 'score': score}. 260 | Note that relation of NA should be excluded. 261 | Return: 262 | {'prec': narray[...], 'rec': narray[...], 'mean_prec': xx, 'f1': xx, 'auc': xx} 263 | prec (precision) and rec (recall) are in micro style. 264 | prec (precision) and rec (recall) are sorted in the decreasing order of the score. 265 | f1 is the max f1 score of those precison-recall points 266 | """ 267 | sorted_pred_result = sorted(pred_result, key=lambda x: x['score'], reverse=True) 268 | prec = [] 269 | rec = [] 270 | correct = 0 271 | total = len(self.facts) 272 | for i, item in enumerate(sorted_pred_result): 273 | if (item['entpair'][0], item['entpair'][1], item['relation']) in self.facts: 274 | correct += 1 275 | prec.append(float(correct) / float(i + 1)) 276 | rec.append(float(correct) / float(total)) 277 | auc = sklearn.metrics.auc(x=rec, y=prec) 278 | np_prec = np.array(prec) 279 | np_rec = np.array(rec) 280 | f1 = (2 * np_prec * np_rec / (np_prec + np_rec + 1e-20)).max() 281 | mean_prec = np_prec.mean() 282 | return {'micro_p': np_prec, 'micro_r': np_rec, 'micro_p_mean': mean_prec, 'micro_f1': f1, 'auc': auc} 283 | 284 | def BagRELoader(path, rel2id, tokenizer, batch_size, 285 | shuffle, entpair_as_bag=False, bag_size=0, num_workers=8, 286 | collate_fn=BagREDataset.collate_fn): 287 | if bag_size == 0: 288 | collate_fn = BagREDataset.collate_fn 289 | else: 290 | collate_fn = BagREDataset.collate_bag_size_fn 291 | dataset = BagREDataset(path, rel2id, tokenizer, entpair_as_bag=entpair_as_bag, bag_size=bag_size) 292 | data_loader = data.DataLoader(dataset=dataset, 293 | batch_size=batch_size, 294 | shuffle=shuffle, 295 | pin_memory=True, 296 | num_workers=num_workers, 297 | collate_fn=collate_fn) 298 | return data_loader 299 | -------------------------------------------------------------------------------- /opennre/framework/sentence_re.py: -------------------------------------------------------------------------------- 1 | import os, logging, json 2 | from tqdm import tqdm 3 | import torch 4 | from torch import nn, optim 5 | from .data_loader import SentenceRELoader 6 | from .utils import AverageMeter 7 | 8 | class SentenceRE(nn.Module): 9 | """ 10 | 模型训练函数 11 | """ 12 | def __init__(self, 13 | model, #初始化后的模型 14 | train_path, #训练文件路径 15 | val_path, 16 | test_path, 17 | ckpt, #要保存的模型的checkpoint路径 18 | batch_size=32, 19 | max_epoch=100, 20 | lr=0.1, 21 | weight_decay=1e-5, 22 | warmup_step=300, 23 | opt='sgd', 24 | parallel = False, #是否使用并发的GPU 25 | num_workers = 0 # Dataloader 的并发数,只主进程加载 26 | ): 27 | 28 | super().__init__() 29 | self.max_epoch = max_epoch 30 | # Load data 31 | if train_path != None: 32 | self.train_loader = SentenceRELoader( 33 | train_path, 34 | model.rel2id, 35 | model.sentence_encoder.tokenize, 36 | batch_size, 37 
| shuffle=True, num_workers=num_workers) 38 | 39 | if val_path != None: 40 | self.val_loader = SentenceRELoader( 41 | val_path, 42 | model.rel2id, 43 | model.sentence_encoder.tokenize, 44 | batch_size, 45 | shuffle=False, num_workers=num_workers) 46 | 47 | if test_path != None: 48 | self.test_loader = SentenceRELoader( 49 | test_path, 50 | model.rel2id, 51 | model.sentence_encoder.tokenize, 52 | batch_size, 53 | shuffle=False, num_workers=num_workers 54 | ) 55 | # Model 56 | self.model = model 57 | if parallel: 58 | self.parallel_model = nn.DataParallel(self.model) 59 | else: 60 | self.parallel_model = None 61 | # Criterion 62 | self.criterion = nn.CrossEntropyLoss() 63 | # Params and optimizer 64 | params = self.parameters() 65 | self.lr = lr 66 | if opt == 'sgd': 67 | self.optimizer = optim.SGD(params, lr, weight_decay=weight_decay) 68 | elif opt == 'adam': 69 | self.optimizer = optim.Adam(params, lr, weight_decay=weight_decay) 70 | elif opt == 'adamw': # Optimizer for BERT 71 | from transformers import AdamW 72 | params = list(self.named_parameters()) 73 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 74 | grouped_params = [ 75 | { 76 | 'params': [p for n, p in params if not any(nd in n for nd in no_decay)], 77 | 'weight_decay': 0.01, 78 | 'lr': lr, 79 | 'ori_lr': lr 80 | }, 81 | { 82 | 'params': [p for n, p in params if any(nd in n for nd in no_decay)], 83 | 'weight_decay': 0.0, 84 | 'lr': lr, 85 | 'ori_lr': lr 86 | } 87 | ] 88 | self.optimizer = AdamW(grouped_params, correct_bias=False) 89 | else: 90 | raise Exception("无效优化器. Must be 'sgd' or 'adam' or 'adamw'.") 91 | # Warmup 92 | if warmup_step > 0: 93 | from transformers import get_linear_schedule_with_warmup 94 | training_steps = self.train_loader.dataset.__len__() // batch_size * self.max_epoch 95 | self.scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=warmup_step, num_training_steps=training_steps) 96 | else: 97 | self.scheduler = None 98 | # Cuda 99 | if torch.cuda.is_available(): 100 | logging.info("检测到GPU可用,使用GPU") 101 | self.cuda() 102 | # Ckpt 103 | self.ckpt = ckpt 104 | 105 | def train_model(self, metric='acc'): 106 | """ 107 | 训练模型 108 | :param metric: 使用哪个指标作为checkpoint的最佳指标,支持acc, micro_p, micro_r, micro_f1 109 | :return: 110 | """ 111 | best_metric = 0 112 | global_step = 0 113 | for epoch in range(self.max_epoch): 114 | self.train() 115 | logging.info("=== Epoch %d train ===" % epoch) 116 | avg_loss = AverageMeter() 117 | avg_acc = AverageMeter() 118 | t = tqdm(self.train_loader) 119 | # enumerate这里把dataset进行tokenizer化 120 | for iter, data in enumerate(t): 121 | if torch.cuda.is_available(): 122 | for i in range(len(data)): 123 | try: 124 | data[i] = data[i].cuda() 125 | except: 126 | pass 127 | # 一个batch_size的label, 16 128 | label = data[0] 129 | # args是一个包含4个元素的列表, [token ids, atten_mask, entity1_start_id, entity2_end_id] 130 | args = data[1:] 131 | # args是一个包含4个元素的列表作为特征放入模型, logits shape [batch_size, num_classes] 132 | if self.parallel_model: 133 | logits = self.parallel_model(*args) 134 | else: 135 | logits = self.model(*args) 136 | loss = self.criterion(logits, label) 137 | score, pred = logits.max(-1) # (B) 138 | acc = float((pred == label).long().sum()) / label.size(0) 139 | #记录日志 140 | avg_loss.update(loss.item(), 1) 141 | avg_acc.update(acc, 1) 142 | t.set_postfix(loss=avg_loss.avg, acc=avg_acc.avg) 143 | # Optimize 144 | loss.backward() 145 | self.optimizer.step() 146 | if self.scheduler is not None: 147 | self.scheduler.step() 148 | self.optimizer.zero_grad() 149 | 
global_step += 1 150 | # Val 151 | logging.info("=== Epoch %d val ===" % epoch) 152 | result = self.eval_model(self.val_loader) 153 | logging.info('Metric {} current / best: {} / {}'.format(metric, result[metric], best_metric)) 154 | if result[metric] > best_metric: 155 | logging.info(f"获得了更好的metric {result[metric]},保存模型") 156 | folder_path = '/'.join(self.ckpt.split('/')[:-1]) 157 | if not os.path.exists(folder_path): 158 | os.mkdir(folder_path) 159 | #保存模型 160 | torch.save({'state_dict': self.model.state_dict()}, self.ckpt) 161 | best_metric = result[metric] 162 | logging.info("Best %s on val set: %f" % (metric, best_metric)) 163 | 164 | def eval_model(self, eval_loader): 165 | """ 166 | 评估模型 167 | :param eval_loader: 评估数据集 168 | :return: 169 | """ 170 | self.eval() 171 | avg_acc = AverageMeter() 172 | pred_result = [] 173 | with torch.no_grad(): 174 | t = tqdm(eval_loader, desc='评估: ') 175 | for iter, data in enumerate(t): 176 | if torch.cuda.is_available(): 177 | for i in range(len(data)): 178 | try: 179 | data[i] = data[i].cuda() 180 | except: 181 | pass 182 | label = data[0] 183 | args = data[1:] 184 | if self.parallel_model: 185 | logits = self.parallel_model(*args) 186 | else: 187 | logits = self.model(*args) 188 | score, pred = logits.max(-1) # (B) 189 | # Save result 190 | for i in range(pred.size(0)): 191 | pred_result.append(pred[i].item()) 192 | # Log 193 | acc = float((pred == label).long().sum()) / label.size(0) 194 | avg_acc.update(acc, pred.size(0)) 195 | t.set_postfix(acc=avg_acc.avg) 196 | result = eval_loader.dataset.eval(pred_result) 197 | return result 198 | 199 | def load_state_dict(self, state_dict): 200 | self.model.load_state_dict(state_dict) 201 | 202 | -------------------------------------------------------------------------------- /opennre/framework/utils.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """ 3 | Computes and stores the average and current value of metrics. 4 | """ 5 | 6 | def __init__(self): 7 | self.reset() 8 | 9 | def reset(self): 10 | self.val = 0 11 | self.avg = 0 12 | self.sum = 0 13 | self.count = 0 14 | 15 | def update(self, val, n=0): 16 | self.val = val 17 | self.sum += val * n 18 | self.count += n 19 | self.avg = self.sum / (.0001 + self.count) 20 | 21 | def __str__(self): 22 | """ 23 | String representation for logging 24 | """ 25 | # for values that should be recorded exactly e.g. 
iteration number 26 | if self.count == 0: 27 | return str(self.val) 28 | # for stats 29 | return '%.4f (%.4f)' % (self.val, self.avg) -------------------------------------------------------------------------------- /opennre/model/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .base_model import SentenceRE, BagRE, FewShotRE, NER 6 | from .softmax_nn import SoftmaxNN 7 | from .bag_attention import BagAttention 8 | from .bag_average import BagAverage 9 | 10 | __all__ = [ 11 | 'SentenceRE', 12 | 'BagRE', 13 | 'FewShotRE', 14 | 'NER', 15 | 'SoftmaxNN', 16 | 'BagAttention' 17 | ] -------------------------------------------------------------------------------- /opennre/model/bag_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, optim 3 | from .base_model import BagRE 4 | 5 | class BagAttention(BagRE): 6 | """ 7 | Instance attention for bag-level relation extraction. 8 | """ 9 | 10 | def __init__(self, sentence_encoder, num_class, rel2id): 11 | """ 12 | Args: 13 | sentence_encoder: encoder for sentences 14 | num_class: number of classes 15 | id2rel: dictionary of id -> relation name mapping 16 | """ 17 | super().__init__() 18 | self.sentence_encoder = sentence_encoder 19 | self.num_class = num_class 20 | self.fc = nn.Linear(self.sentence_encoder.hidden_size, num_class) 21 | self.softmax = nn.Softmax(-1) 22 | self.rel2id = rel2id 23 | self.id2rel = {} 24 | self.drop = nn.Dropout() 25 | for rel, id in rel2id.items(): 26 | self.id2rel[id] = rel 27 | 28 | def infer(self, bag): 29 | """ 30 | Args: 31 | bag: bag of sentences with the same entity pair 32 | [{ 33 | 'text' or 'token': ..., 34 | 'h': {'pos': [start, end], ...}, 35 | 't': {'pos': [start, end], ...} 36 | }] 37 | Return: 38 | (relation, score) 39 | """ 40 | self.eval() 41 | tokens = [] 42 | pos1s = [] 43 | pos2s = [] 44 | masks = [] 45 | for item in bag: 46 | token, pos1, pos2, mask = self.sentence_encoder.tokenize(item) 47 | tokens.append(token) 48 | pos1s.append(pos1) 49 | pos2s.append(pos2) 50 | masks.append(mask) 51 | tokens = torch.cat(tokens, 0).unsqueeze(0) # (n, L) 52 | pos1s = torch.cat(pos1s, 0).unsqueeze(0) 53 | pos2s = torch.cat(pos2s, 0).unsqueeze(0) 54 | masks = torch.cat(masks, 0).unsqueeze(0) 55 | scope = torch.tensor([[0, len(bag)]]).long() # (1, 2) 56 | bag_logits = self.forward(None, scope, tokens, pos1s, pos2s, masks, train=False).squeeze(0) # (N) after softmax 57 | score, pred = bag_logits.max(0) 58 | score = score.item() 59 | pred = pred.item() 60 | rel = self.id2rel[pred] 61 | return (rel, score) 62 | 63 | def forward(self, label, scope, token, pos1, pos2, mask=None, train=True, bag_size=0): 64 | """ 65 | Args: 66 | label: (B), label of the bag 67 | scope: (B), scope for each bag 68 | token: (nsum, L), index of tokens 69 | pos1: (nsum, L), relative position to head entity 70 | pos2: (nsum, L), relative position to tail entity 71 | mask: (nsum, L), used for piece-wise CNN 72 | Return: 73 | logits, (B, N) 74 | """ 75 | if bag_size > 0: 76 | token = token.view(-1, token.size(-1)) 77 | pos1 = pos1.view(-1, pos1.size(-1)) 78 | pos2 = pos2.view(-1, pos2.size(-1)) 79 | if mask is not None: 80 | mask = mask.view(-1, mask.size(-1)) 81 | else: 82 | begin, end = scope[0][0], scope[-1][1] 83 | token = token[:, begin:end, :].view(-1, token.size(-1)) 84 | pos1 = pos1[:, 
begin:end, :].view(-1, pos1.size(-1)) 85 | pos2 = pos2[:, begin:end, :].view(-1, pos2.size(-1)) 86 | if mask is not None: 87 | mask = mask[:, begin:end, :].view(-1, mask.size(-1)) 88 | scope = torch.sub(scope, torch.zeros_like(scope).fill_(begin)) 89 | if mask is not None: 90 | rep = self.sentence_encoder(token, pos1, pos2, mask) # (nsum, H) 91 | else: 92 | rep = self.sentence_encoder(token, pos1, pos2) # (nsum, H) 93 | 94 | # Attention 95 | if train: 96 | if bag_size == 0: 97 | bag_rep = [] 98 | query = torch.zeros((rep.size(0))).long() 99 | if torch.cuda.is_available(): 100 | query = query.cuda() 101 | for i in range(len(scope)): 102 | query[scope[i][0]:scope[i][1]] = label[i] 103 | att_mat = self.fc.weight[query] # (nsum, H) 104 | att_score = (rep * att_mat).sum(-1) # (nsum) 105 | 106 | for i in range(len(scope)): 107 | bag_mat = rep[scope[i][0]:scope[i][1]] # (n, H) 108 | softmax_att_score = self.softmax(att_score[scope[i][0]:scope[i][1]]) # (n) 109 | bag_rep.append((softmax_att_score.unsqueeze(-1) * bag_mat).sum(0)) # (n, 1) * (n, H) -> (n, H) -> (H) 110 | bag_rep = torch.stack(bag_rep, 0) # (B, H) 111 | else: 112 | batch_size = label.size(0) 113 | query = label.unsqueeze(1) # (B, 1) 114 | att_mat = self.fc.weight[query] # (B, 1, H) 115 | rep = rep.view(batch_size, bag_size, -1) 116 | att_score = (rep * att_mat).sum(-1) # (B, bag) 117 | softmax_att_score = self.softmax(att_score) # (B, bag) 118 | bag_rep = (softmax_att_score.unsqueeze(-1) * rep).sum(1) # (B, bag, 1) * (B, bag, H) -> (B, bag, H) -> (B, H) 119 | bag_rep = self.drop(bag_rep) 120 | bag_logits = self.fc(bag_rep) # (B, N) 121 | else: 122 | if bag_size == 0: 123 | bag_logits = [] 124 | att_score = torch.matmul(rep, self.fc.weight.transpose(0, 1)) # (nsum, H) * (H, N) -> (nsum, N) 125 | for i in range(len(scope)): 126 | bag_mat = rep[scope[i][0]:scope[i][1]] # (n, H) 127 | softmax_att_score = self.softmax(att_score[scope[i][0]:scope[i][1]].transpose(0, 1)) # (N, (softmax)n) 128 | rep_for_each_rel = torch.matmul(softmax_att_score, bag_mat) # (N, n) * (n, H) -> (N, H) 129 | logit_for_each_rel = self.softmax(self.fc(rep_for_each_rel)) # ((each rel)N, (logit)N) 130 | logit_for_each_rel = logit_for_each_rel.diag() # (N) 131 | bag_logits.append(logit_for_each_rel) 132 | bag_logits = torch.stack(bag_logits,0) # after **softmax** 133 | else: 134 | batch_size = rep.size(0) // bag_size 135 | att_score = torch.matmul(rep, self.fc.weight.transpose(0, 1)) # (nsum, H) * (H, N) -> (nsum, N) 136 | att_score = att_score.view(batch_size, bag_size, -1) # (B, bag, N) 137 | rep = rep.view(batch_size, bag_size, -1) # (B, bag, H) 138 | softmax_att_score = self.softmax(att_score.transpose(1, 2)) # (B, N, (softmax)bag) 139 | rep_for_each_rel = torch.matmul(softmax_att_score, rep) # (B, N, bag) * (B, bag, H) -> (B, N, H) 140 | bag_logits = self.softmax(self.fc(rep_for_each_rel)).diagonal(dim1=1, dim2=2) # (B, (each rel)N) 141 | return bag_logits 142 | 143 | -------------------------------------------------------------------------------- /opennre/model/bag_average.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, optim 3 | from .base_model import BagRE 4 | 5 | class BagAverage(BagRE): 6 | """ 7 | Average policy for bag-level relation extraction. 
8 | """ 9 | 10 | def __init__(self, sentence_encoder, num_class, rel2id): 11 | """ 12 | Args: 13 | sentence_encoder: encoder for sentences 14 | num_class: number of classes 15 | id2rel: dictionary of id -> relation name mapping 16 | """ 17 | super().__init__() 18 | self.sentence_encoder = sentence_encoder 19 | self.num_class = num_class 20 | self.fc = nn.Linear(self.sentence_encoder.hidden_size, num_class) 21 | self.softmax = nn.Softmax(-1) 22 | self.rel2id = rel2id 23 | self.id2rel = {} 24 | self.drop = nn.Dropout() 25 | for rel, id in rel2id.items(): 26 | self.id2rel[id] = rel 27 | 28 | def infer(self, bag): 29 | """ 30 | Args: 31 | bag: bag of sentences with the same entity pair 32 | [{ 33 | 'text' or 'token': ..., 34 | 'h': {'pos': [start, end], ...}, 35 | 't': {'pos': [start, end], ...} 36 | }] 37 | Return: 38 | (relation, score) 39 | """ 40 | pass 41 | 42 | """ 43 | tokens = [] 44 | pos1s = [] 45 | pos2s = [] 46 | masks = [] 47 | for item in bag: 48 | if 'text' in item: 49 | token, pos1, pos2, mask = self.tokenizer(item['text'], 50 | item['h']['pos'], item['t']['pos'], is_token=False, padding=True) 51 | else: 52 | token, pos1, pos2, mask = self.tokenizer(item['token'], 53 | item['h']['pos'], item['t']['pos'], is_token=True, padding=True) 54 | tokens.append(token) 55 | pos1s.append(pos1) 56 | pos2s.append(pos2) 57 | masks.append(mask) 58 | tokens = torch.cat(tokens, 0) # (n, L) 59 | pos1s = torch.cat(pos1s, 0) 60 | pos2s = torch.cat(pos2s, 0) 61 | masks = torch.cat(masks, 0) 62 | scope = torch.tensor([[0, len(bag)]]).long() # (1, 2) 63 | bag_logits = self.forward(None, scope, tokens, pos1s, pos2s, masks, train=False).squeeze(0) # (N) after softmax 64 | score, pred = bag_logits.max() 65 | score = score.item() 66 | pred = pred.item() 67 | rel = self.id2rel[pred] 68 | return (rel, score) 69 | """ 70 | 71 | def forward(self, label, scope, token, pos1, pos2, mask=None, train=True, bag_size=None): 72 | """ 73 | Args: 74 | label: (B), label of the bag 75 | scope: (B), scope for each bag 76 | token: (nsum, L), index of tokens 77 | pos1: (nsum, L), relative position to head entity 78 | pos2: (nsum, L), relative position to tail entity 79 | mask: (nsum, L), used for piece-wise CNN 80 | Return: 81 | logits, (B, N) 82 | """ 83 | if mask: 84 | rep = self.sentence_encoder(token, pos1, pos2, mask) # (nsum, H) 85 | else: 86 | rep = self.sentence_encoder(token, pos1, pos2) # (nsum, H) 87 | 88 | # Average 89 | bag_rep = [] 90 | if bag_size is None: 91 | for i in range(len(scope)): 92 | bag_rep.append(rep[scope[i][0]:scope[i][1]].mean(0)) 93 | bag_rep = torch.stack(bag_rep, 0) # (B, H) 94 | else: 95 | batch_size = label.size(0) 96 | rep = rep.view(batch_size, bag_size, -1) # (B, bag, H) 97 | bag_rep = rep.mean(1) # (B, H) 98 | bag_rep = self.drop(bag_rep) 99 | bag_logits = self.fc(bag_rep) # (B, N) 100 | 101 | return bag_logits 102 | 103 | -------------------------------------------------------------------------------- /opennre/model/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import json 4 | 5 | class SentenceRE(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def infer(self, item): 10 | """ 11 | Args: 12 | item: {'text' or 'token', 'h': {'pos': [start, end]}, 't': ...} 13 | Return: 14 | (Name of the relation of the sentence, score) 15 | """ 16 | raise NotImplementedError 17 | 18 | 19 | class BagRE(nn.Module): 20 | def __init__(self): 21 | super().__init__() 22 | 23 | def infer(self, 
bag): 24 | """ 25 | Args: 26 | bag: bag of sentences with the same entity pair 27 | [{ 28 | 'text' or 'token': ..., 29 | 'h': {'pos': [start, end], ...}, 30 | 't': {'pos': [start, end], ...} 31 | }] 32 | Return: 33 | (relation, score) 34 | """ 35 | raise NotImplementedError 36 | 37 | class FewShotRE(nn.Module): 38 | def __init__(self): 39 | super().__init__() 40 | 41 | def infer(self, support, query): 42 | """ 43 | Args: 44 | support: supporting set. 45 | [{'text' or 'token': ..., 46 | 'h': {'pos': [start, end], ...}, 47 | 't': {'pos': [start, end], ...}, 48 | 'relation': ...}] 49 | query: same format as support 50 | Return: 51 | [(relation, score), ...] 52 | """ 53 | 54 | class NER(nn.Module): 55 | def __init__(self): 56 | super().__init__() 57 | 58 | def ner(self, sentence, is_token=False): 59 | """ 60 | Args: 61 | sentence: string, the input sentence 62 | is_token: if is_token == True, senetence becomes an array of token 63 | Return: 64 | [{name: xx, pos: [start, end]}], a list of named entities 65 | """ 66 | raise NotImplementedError 67 | -------------------------------------------------------------------------------- /opennre/model/softmax_nn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, optim 3 | from .base_model import SentenceRE 4 | 5 | class SoftmaxNN(SentenceRE): 6 | """ 7 | Softmax分类器用于句子级关系抽取。 8 | """ 9 | 10 | def __init__(self, sentence_encoder, num_class, rel2id): 11 | """ 12 | Args: 13 | sentence_encoder: encoder for sentences, 初始化的模型 14 | num_class: number of classes, 类别数量 15 | id2rel: dictionary of id -> relation name mapping, 字典格式,关系到id的映射 16 | """ 17 | super().__init__() 18 | self.sentence_encoder = sentence_encoder 19 | self.num_class = num_class 20 | self.fc = nn.Linear(self.sentence_encoder.hidden_size, num_class) 21 | self.softmax = nn.Softmax(-1) 22 | self.rel2id = rel2id 23 | self.id2rel = {} 24 | self.drop = nn.Dropout() 25 | for rel, id in rel2id.items(): 26 | self.id2rel[id] = rel 27 | 28 | def infer(self, item): 29 | self.eval() 30 | item = self.sentence_encoder.tokenize(item) 31 | logits = self.forward(*item) 32 | logits = self.softmax(logits) 33 | score, pred = logits.max(-1) 34 | score = score.item() 35 | pred = pred.item() 36 | return self.id2rel[pred], score 37 | 38 | def forward(self, *args): 39 | """ 40 | Args: 41 | args: depends on the encoder 42 | Return: 43 | logits, (B, N) 44 | """ 45 | #调用enccoder模型 46 | rep = self.sentence_encoder(*args) # (B, H) 47 | rep = self.drop(rep) 48 | logits = self.fc(rep) # (B, N) 49 | return logits 50 | -------------------------------------------------------------------------------- /opennre/module/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | -------------------------------------------------------------------------------- /opennre/module/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .cnn import CNN 6 | from .rnn import RNN 7 | from .lstm import LSTM 8 | 9 | __all__ = [ 10 | 'CNN', 11 | 'RNN', 12 | 'LSTM', 13 | ] -------------------------------------------------------------------------------- /opennre/module/nn/cnn.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class CNN(nn.Module): 6 | 7 | def __init__(self, input_size=50, hidden_size=256, dropout=0, kernel_size=3, padding=1, activation_function=F.relu): 8 | """ 9 | Args: 10 | input_size: dimention of input embedding 11 | kernel_size: kernel_size for CNN 12 | padding: padding for CNN 13 | hidden_size: hidden size 14 | """ 15 | super().__init__() 16 | self.conv = nn.Conv1d(input_size, hidden_size, kernel_size, padding=padding) 17 | self.act = activation_function 18 | self.dropout = nn.Dropout(dropout) 19 | 20 | def forward(self, x): 21 | """ 22 | Args: 23 | input features: (B, L, I_EMBED) 24 | Return: 25 | output features: (B, H_EMBED) 26 | """ 27 | # Check size of tensors 28 | x = x.transpose(1, 2) # (B, I_EMBED, L) 29 | x = self.conv(x) # (B, H_EMBED, L) 30 | x = self.act(x) # (B, H_EMBED, L) 31 | x = self.dropout(x) # (B, H_EMBED, L) 32 | x = x.transpose(1, 2) # (B, L, H_EMBED) 33 | return x 34 | -------------------------------------------------------------------------------- /opennre/module/nn/lstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class LSTM(nn.Module): 5 | 6 | def __init__(self, input_size=50, hidden_size=256, dropout=0, bidirectional=False, num_layers=1, activation_function="tanh"): 7 | """ 8 | Args: 9 | input_size: dimention of input embedding 10 | hidden_size: hidden size 11 | dropout: dropout layer on the outputs of each RNN layer except the last layer 12 | bidirectional: if it is a bidirectional RNN 13 | num_layers: number of recurrent layers 14 | activation_function: the activation function of RNN, tanh/relu 15 | """ 16 | super().__init__() 17 | if bidirectional: 18 | hidden_size /= 2 19 | self.lstm = nn.LSTM(input_size, 20 | hidden_size, 21 | num_layers, 22 | nonlinearity=activation_function, 23 | dropout=dropout, 24 | bidirectional=bidirectional) 25 | 26 | def forward(self, x): 27 | """ 28 | Args: 29 | input features: (B, L, I_EMBED) 30 | Return: 31 | output features: (B, L, H_EMBED) 32 | """ 33 | # Check size of tensors 34 | x = x.transpose(0, 1) # (L, B, I_EMBED) 35 | x, h, c = self.lstm(x) # (L, B, H_EMBED) 36 | x = x.transpose(0, 1) # (B, L, I_EMBED) 37 | return x 38 | -------------------------------------------------------------------------------- /opennre/module/nn/rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class RNN(nn.Module): 5 | 6 | def __init__(self, input_size=50, hidden_size=256, dropout=0, bidirectional=False, num_layers=1, activation_function="tanh"): 7 | """ 8 | Args: 9 | input_size: dimention of input embedding 10 | hidden_size: hidden size 11 | dropout: dropout layer on the outputs of each RNN layer except the last layer 12 | bidirectional: if it is a bidirectional RNN 13 | num_layers: number of recurrent layers 14 | activation_function: the activation function of RNN, tanh/relu 15 | """ 16 | super().__init__() 17 | if bidirectional: 18 | hidden_size /= 2 19 | self.rnn = nn.RNN(input_size, 20 | hidden_size, 21 | num_layers, 22 | nonlinearity=activation_function, 23 | dropout=dropout, 24 | bidirectional=bidirectional) 25 | 26 | def forward(self, x): 27 | """ 28 | Args: 29 | input features: (B, L, I_EMBED) 30 | Return: 31 | output features: (B, L, H_EMBED) 32 | """ 33 | # Check size of tensors 34 | x = 
x.transpose(0, 1) # (L, B, I_EMBED) 35 | x, h = self.rnn(x) # (L, B, H_EMBED) 36 | x = x.transpose(0, 1) # (B, L, I_EMBED) 37 | return x 38 | -------------------------------------------------------------------------------- /opennre/module/pool/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .max_pool import MaxPool 6 | from .avg_pool import AvgPool 7 | 8 | __all__ = [ 9 | 'MaxPool', 10 | 'AvgPool' 11 | ] -------------------------------------------------------------------------------- /opennre/module/pool/avg_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | class AvgPool(nn.Module): 7 | 8 | def __init__(self, kernel_size, segment_num=None): 9 | """ 10 | Args: 11 | input_size: dimention of input embedding 12 | kernel_size: kernel_size for CNN 13 | padding: padding for CNN 14 | hidden_size: hidden size 15 | """ 16 | super().__init__() 17 | self.segment_num = segment_num 18 | if self.segment_num != None: 19 | self.mask_embedding = nn.Embedding(segment_num + 1, segment_num) 20 | self.mask_embedding.weight.data.copy_(torch.FloatTensor(np.concatenate([np.zeros(segment_num), np.identity(segment_num)], axis = 0))) 21 | self.mask_embedding.weight.requires_grad = False 22 | self.pool = nn.AvgPool1d(kernel_size) 23 | 24 | def forward(self, x, mask=None): 25 | """ 26 | Args: 27 | input features: (B, L, I_EMBED) 28 | Return: 29 | output features: (B, H_EMBED) 30 | """ 31 | # Check size of tensors 32 | if mask == None or self.segment_num == None or self.segment_num == 1: 33 | x = x.transpose(1, 2) # (B, L, I_EMBED) -> (B, I_EMBED, L) 34 | x = self.pool(x).squeeze(-1) # (B, I_EMBED, 1) -> (B, I_EMBED) 35 | return x 36 | else: 37 | B, L, I_EMBED = x.size()[:2] 38 | mask = self.mask_embedding(mask).transpose(1, 2).unsqueeze(2) # (B, L) -> (B, L, S) -> (B, S, L) -> (B, S, 1, L) 39 | x = x.transpose(1, 2).unsqueeze(1) # (B, L, I_EMBED) -> (B, I_EMBED, L) -> (B, 1, I_EMBED, L) 40 | x = (x * mask).view([-1, I_EMBED, L]) # (B, S, I_EMBED, L) -> (B * S, I_EMBED, L) 41 | x = self.pool(x).squeeze(-1) # (B * S, I_EMBED, 1) -> (B * S, I_EMBED) 42 | x = x.view([B, -1]) - self._minus # (B, S * I_EMBED) 43 | return x -------------------------------------------------------------------------------- /opennre/module/pool/max_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | class MaxPool(nn.Module): 7 | 8 | def __init__(self, kernel_size, segment_num=None): 9 | """ 10 | Args: 11 | input_size: dimention of input embedding 12 | kernel_size: kernel_size for CNN 13 | padding: padding for CNN 14 | hidden_size: hidden size 15 | """ 16 | super().__init__() 17 | self.segment_num = segment_num 18 | if self.segment_num != None: 19 | self.mask_embedding = nn.Embedding(segment_num + 1, segment_num) 20 | self.mask_embedding.weight.data.copy_(torch.FloatTensor(np.concatenate([np.zeros((1, segment_num)), np.identity(segment_num)], axis=0))) 21 | self.mask_embedding.weight.requires_grad = False 22 | self._minus = -100 23 | self.pool = nn.MaxPool1d(kernel_size) 24 | 25 | def forward(self, x, mask=None): 26 | """ 27 | Args: 28 | input features: (B, L, I_EMBED) 29 | 
Return: 30 | output features: (B, H_EMBED) 31 | """ 32 | # Check size of tensors 33 | if mask is None or self.segment_num is None or self.segment_num == 1: 34 | x = x.transpose(1, 2) # (B, L, I_EMBED) -> (B, I_EMBED, L) 35 | x = self.pool(x).squeeze(-1) # (B, I_EMBED, 1) -> (B, I_EMBED) 36 | return x 37 | else: 38 | B, L, I_EMBED = x.size()[:3] 39 | # mask = 1 - self.mask_embedding(mask).transpose(1, 2).unsqueeze(2) # (B, L) -> (B, L, S) -> (B, S, L) -> (B, S, 1, L) 40 | # x = x.transpose(1, 2).unsqueeze(1) # (B, L, I_EMBED) -> (B, I_EMBED, L) -> (B, 1, I_EMBED, L) 41 | # x = (x + self._minus * mask).contiguous().view([-1, I_EMBED, L]) # (B, S, I_EMBED, L) -> (B * S, I_EMBED, L) 42 | # x = self.pool(x).squeeze(-1) # (B * S, I_EMBED, 1) -> (B * S, I_EMBED) 43 | # x = x.view([B, -1]) # (B, S * I_EMBED) 44 | # return x 45 | mask = 1 - self.mask_embedding(mask).transpose(1, 2) 46 | x = x.transpose(1, 2) 47 | pool1 = self.pool(x + self._minus * mask[:, 0:1, :]) 48 | pool2 = self.pool(x + self._minus * mask[:, 1:2, :]) 49 | pool3 = self.pool(x + self._minus * mask[:, 2:3, :]) 50 | 51 | x = torch.cat([pool1, pool2, pool3], 1) 52 | # x = x.squeeze(-1) 53 | return x 54 | -------------------------------------------------------------------------------- /opennre/pretrain.py: -------------------------------------------------------------------------------- 1 | from . import encoder 2 | from . import model 3 | from . import framework 4 | import torch 5 | import os 6 | import sys 7 | import json 8 | import numpy as np 9 | import logging 10 | 11 | root_url = "https://thunlp.oss-cn-qingdao.aliyuncs.com/" 12 | default_root_path = os.path.join(os.getenv('HOME'), '.opennre') 13 | 14 | def check_root(root_path=default_root_path): 15 | if not os.path.exists(root_path): 16 | os.mkdir(root_path) 17 | os.mkdir(os.path.join(root_path, 'benchmark')) 18 | os.mkdir(os.path.join(root_path, 'pretrain')) 19 | os.mkdir(os.path.join(root_path, 'pretrain/nre')) 20 | 21 | def download_wiki80(root_path=default_root_path): 22 | check_root() 23 | if not os.path.exists(os.path.join(root_path, 'benchmark/wiki80')): 24 | os.mkdir(os.path.join(root_path, 'benchmark/wiki80')) 25 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki80') + ' ' + root_url + 'opennre/benchmark/wiki80/wiki80_rel2id.json') 26 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki80') + ' ' + root_url + 'opennre/benchmark/wiki80/wiki80_train.txt') 27 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki80') + ' ' + root_url + 'opennre/benchmark/wiki80/wiki80_val.txt') 28 | 29 | def download_tacred(root_path=default_root_path): 30 | check_root() 31 | if not os.path.exists(os.path.join(root_path, 'benchmark/tacred')): 32 | os.mkdir(os.path.join(root_path, 'benchmark/tacred')) 33 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/tacred') + ' ' + root_url + 'opennre/benchmark/tacred/tacred_rel2id.json') 34 | logging.info('Due to copyright limits, we only provide rel2id for TACRED. 
Please download TACRED manually and convert the data to OpenNRE format if needed.') 35 | 36 | def download_nyt10(root_path=default_root_path): 37 | check_root() 38 | if not os.path.exists(os.path.join(root_path, 'benchmark/nyt10')): 39 | os.mkdir(os.path.join(root_path, 'benchmark/nyt10')) 40 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/nyt10') + ' ' + root_url + 'opennre/benchmark/nyt10/nyt10_rel2id.json') 41 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/nyt10') + ' ' + root_url + 'opennre/benchmark/nyt10/nyt10_train.txt') 42 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/nyt10') + ' ' + root_url + 'opennre/benchmark/nyt10/nyt10_test.txt') 43 | 44 | def download_wiki_distant(root_path=default_root_path): 45 | check_root() 46 | if not os.path.exists(os.path.join(root_path, 'benchmark/wiki_distant')): 47 | os.mkdir(os.path.join(root_path, 'benchmark/wiki_distant')) 48 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki_distant') + ' ' + root_url + 'opennre/benchmark/wiki_distant/wiki_distant_rel2id.json') 49 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki_distant') + ' ' + root_url + 'opennre/benchmark/wiki_distant/wiki_distant_train.txt') 50 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki_distant') + ' ' + root_url + 'opennre/benchmark/wiki_distant/wiki_distant_test.txt') 51 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki_distant') + ' ' + root_url + 'opennre/benchmark/wiki_distant/wiki_distant_val.txt') 52 | 53 | def download_semeval(root_path=default_root_path): 54 | check_root() 55 | if not os.path.exists(os.path.join(root_path, 'benchmark/semeval')): 56 | os.mkdir(os.path.join(root_path, 'benchmark/semeval')) 57 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/semeval') + ' ' + root_url + 'opennre/benchmark/semeval/semeval_rel2id.json') 58 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/semeval') + ' ' + root_url + 'opennre/benchmark/semeval/semeval_train.txt') 59 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/semeval') + ' ' + root_url + 'opennre/benchmark/semeval/semeval_test.txt') 60 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/semeval') + ' ' + root_url + 'opennre/benchmark/semeval/semeval_val.txt') 61 | 62 | def download_glove(root_path=default_root_path): 63 | check_root() 64 | if not os.path.exists(os.path.join(root_path, 'pretrain/glove')): 65 | os.mkdir(os.path.join(root_path, 'pretrain/glove')) 66 | os.system('wget -P ' + os.path.join(root_path, 'pretrain/glove') + ' ' + root_url + 'opennre/pretrain/glove/glove.6B.50d_mat.npy') 67 | os.system('wget -P ' + os.path.join(root_path, 'pretrain/glove') + ' ' + root_url + 'opennre/pretrain/glove/glove.6B.50d_word2id.json') 68 | 69 | def download_bert_base_uncased(root_path=default_root_path): 70 | check_root() 71 | if not os.path.exists(os.path.join(root_path, 'pretrain/bert-base-uncased')): 72 | os.mkdir(os.path.join(root_path, 'pretrain/bert-base-uncased')) 73 | os.system('wget -P ' + os.path.join(root_path, 'pretrain/bert-base-uncased') + ' ' + root_url + 'opennre/pretrain/bert-base-uncased/config.json') 74 | os.system('wget -P ' + os.path.join(root_path, 'pretrain/bert-base-uncased') + ' ' + root_url + 'opennre/pretrain/bert-base-uncased/pytorch_model.bin') 75 | os.system('wget -P ' + os.path.join(root_path, 'pretrain/bert-base-uncased') + ' ' + root_url + 'opennre/pretrain/bert-base-uncased/vocab.txt') 76 | 77 | def download_pretrain(model_name, 
root_path=default_root_path): 78 | ckpt = os.path.join(root_path, 'pretrain/nre/' + model_name + '.pth.tar') 79 | if not os.path.exists(ckpt): 80 | os.system('wget -P ' + os.path.join(root_path, 'pretrain/nre') + ' ' + root_url + 'opennre/pretrain/nre/' + model_name + '.pth.tar') 81 | 82 | def download(name, root_path=default_root_path): 83 | if not os.path.exists(os.path.join(root_path, 'benchmark')): 84 | os.mkdir(os.path.join(root_path, 'benchmark')) 85 | if not os.path.exists(os.path.join(root_path, 'pretrain')): 86 | os.mkdir(os.path.join(root_path, 'pretrain')) 87 | if name == 'nyt10': 88 | download_nyt10(root_path=root_path) 89 | elif name == 'wiki_distant': 90 | download_wiki_distant(root_path=root_path) 91 | elif name == 'semeval': 92 | download_semeval(root_path=root_path) 93 | elif name == 'wiki80': 94 | download_wiki80(root_path=root_path) 95 | elif name == 'tacred': 96 | download_tacred(root_path=root_path) 97 | elif name == 'glove': 98 | download_glove(root_path=root_path) 99 | elif name == 'liter': 100 | pass 101 | elif name == 'brand': 102 | pass 103 | elif name == 'bert_base_uncased': 104 | download_bert_base_uncased(root_path=root_path) 105 | else: 106 | raise Exception('不能找到对应的数据') 107 | 108 | def get_model(model_name, root_path=default_root_path): 109 | check_root() 110 | ckpt = os.path.join(root_path, 'pretrain/nre/' + model_name + '.pth.tar') 111 | if model_name == 'wiki80_cnn_softmax': 112 | download_pretrain(model_name, root_path=root_path) 113 | download('glove', root_path=root_path) 114 | download('wiki80', root_path=root_path) 115 | wordi2d = json.load(open(os.path.join(root_path, 'pretrain/glove/glove.6B.50d_word2id.json'))) 116 | word2vec = np.load(os.path.join(root_path, 'pretrain/glove/glove.6B.50d_mat.npy')) 117 | rel2id = json.load(open(os.path.join(root_path, 'benchmark/wiki80/wiki80_rel2id.json'))) 118 | sentence_encoder = encoder.CNNEncoder(token2id=wordi2d, 119 | max_length=40, 120 | word_size=50, 121 | position_size=5, 122 | hidden_size=230, 123 | blank_padding=True, 124 | kernel_size=3, 125 | padding_size=1, 126 | word2vec=word2vec, 127 | dropout=0.5) 128 | m = model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id) 129 | m.load_state_dict(torch.load(ckpt, map_location='cpu')['state_dict']) 130 | return m 131 | elif model_name in ['wiki80_bert_softmax', 'wiki80_bertentity_softmax']: 132 | download_pretrain(model_name, root_path=root_path) 133 | download('bert_base_uncased', root_path=root_path) 134 | download('wiki80', root_path=root_path) 135 | rel2id = json.load(open(os.path.join(root_path, 'benchmark/wiki80/wiki80_rel2id.json'))) 136 | if 'entity' in model_name: 137 | sentence_encoder = encoder.BERTEntityEncoder( 138 | max_length=80, pretrain_path=os.path.join(root_path, 'pretrain/bert-base-uncased')) 139 | else: 140 | sentence_encoder = encoder.BERTEncoder( 141 | max_length=80, pretrain_path=os.path.join(root_path, 'pretrain/bert-base-uncased')) 142 | m = model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id) 143 | m.load_state_dict(torch.load(ckpt, map_location='cpu')['state_dict']) 144 | return m 145 | elif model_name in ['tacred_bert_softmax', 'tacred_bertentity_softmax']: 146 | download_pretrain(model_name, root_path=root_path) 147 | download('bert_base_uncased', root_path=root_path) 148 | download('tacred', root_path=root_path) 149 | rel2id = json.load(open(os.path.join(root_path, 'benchmark/tacred/tacred_rel2id.json'))) 150 | if 'entity' in model_name: 151 | sentence_encoder = encoder.BERTEntityEncoder( 152 | max_length=80, 
pretrain_path=os.path.join(root_path, 'pretrain/bert-base-uncased')) 153 | else: 154 | sentence_encoder = encoder.BERTEncoder( 155 | max_length=80, pretrain_path=os.path.join(root_path, 'pretrain/bert-base-uncased')) 156 | m = model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id) 157 | m.load_state_dict(torch.load(ckpt, map_location='cpu')['state_dict']) 158 | return m 159 | else: 160 | raise NotImplementedError 161 | -------------------------------------------------------------------------------- /opennre/tokenization/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | from .basic_tokenizer import BasicTokenizer 8 | from .word_piece_tokenizer import WordpieceTokenizer 9 | from .word_tokenizer import WordTokenizer 10 | from .bert_tokenizer import BertTokenizer 11 | 12 | __all__ = [ 13 | 'BasicTokenizer', 14 | 'WordpieceTokenizer', 15 | 'WordTokenizer', 16 | 'BertTokenizer', 17 | ] 18 | 19 | 20 | -------------------------------------------------------------------------------- /opennre/tokenization/basic_tokenizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """BasicTokenizer classes.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from .utils import (convert_to_unicode, 23 | clean_text, 24 | split_on_whitespace, 25 | split_on_punctuation, 26 | tokenize_chinese_chars, 27 | strip_accents) 28 | 29 | class BasicTokenizer(object): 30 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 31 | 32 | def __init__(self, 33 | do_lower_case=True, 34 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 35 | """Constructs a BasicTokenizer. 36 | Args: 37 | do_lower_case: Whether to lower case the input. 38 | """ 39 | self.do_lower_case = do_lower_case 40 | self.never_split = never_split 41 | 42 | def tokenize(self, text): 43 | """Tokenizes a piece of text.""" 44 | text = convert_to_unicode(text) 45 | text = clean_text(text) 46 | text = tokenize_chinese_chars(text) 47 | # This was added on November 1st, 2018 for the multilingual and Chinese 48 | # models. This is also applied to the English models now, but it doesn't 49 | # matter since the English models were not trained on any Chinese data 50 | # and generally don't have any Chinese data in them (there are Chinese 51 | # characters in the vocabulary because Wikipedia does have some Chinese 52 | # words in the English Wikipedia.). 
53 | orig_tokens = split_on_whitespace(text) 54 | split_tokens = [] 55 | current_positions = [] 56 | for token in orig_tokens: 57 | if self.do_lower_case and token not in self.never_split: 58 | token = token.lower() 59 | token = strip_accents(token) 60 | current_positions.append([]) 61 | current_positions[-1].append(len(split_tokens)) 62 | split_tokens.extend(split_on_punctuation(token)) 63 | current_positions[-1].append(len(split_tokens)) 64 | return split_tokens, current_positions 65 | -------------------------------------------------------------------------------- /opennre/tokenization/bert_tokenizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """BertTokenizer classes.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | import unicodedata 24 | import six 25 | 26 | from .basic_tokenizer import BasicTokenizer 27 | from .word_piece_tokenizer import WordpieceTokenizer 28 | from .utils import (convert_to_unicode, 29 | load_vocab, 30 | convert_by_vocab, 31 | convert_tokens_to_ids, 32 | convert_ids_to_tokens) 33 | 34 | class BertTokenizer(object): 35 | 36 | def __init__(self, 37 | vocab = None, 38 | do_lower_case = True, 39 | do_basic_tokenize = True, 40 | never_split=["[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"]): 41 | 42 | self.vocab = load_vocab(vocab) 43 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 44 | self.basic_tokenizer = BasicTokenizer(do_lower_case = do_lower_case, never_split = never_split) 45 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab = self.vocab) 46 | self.do_basic_tokenize = do_basic_tokenize 47 | 48 | def tokenize(self, text): 49 | split_tokens = [] 50 | if self.do_basic_tokenize: 51 | tokens, _ = self.basic_tokenizer.tokenize(text) 52 | text = " ".join(tokens) 53 | split_tokens, current_positions = self.wordpiece_tokenizer.tokenize(text) 54 | return split_tokens, current_positions 55 | 56 | def convert_tokens_to_ids(self, tokens): 57 | return convert_by_vocab(self.vocab, tokens) 58 | 59 | def convert_ids_to_tokens(self, ids): 60 | return convert_by_vocab(self.inv_vocab, ids) 61 | -------------------------------------------------------------------------------- /opennre/tokenization/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import collections 6 | import unicodedata 7 | import six 8 | 9 | def is_whitespace(char): 10 | """ Checks whether `chars` is a whitespace character. 11 | \t, \n, and \r are technically contorl characters but we treat them 12 | as whitespace since they are generally considered as such. 
13 | """ 14 | if char == " " or char == "\t" or char == "\n" or char == "\r": 15 | return True 16 | cat = unicodedata.category(char) 17 | if cat == "Zs": 18 | return True 19 | return False 20 | 21 | def is_control(char): 22 | """ Checks whether `chars` is a control character. 23 | These are technically control characters but we count them as whitespace characters. 24 | """ 25 | if char == "\t" or char == "\n" or char == "\r": 26 | return False 27 | cat = unicodedata.category(char) 28 | if cat.startswith("C"): 29 | return True 30 | return False 31 | 32 | def is_punctuation(char): 33 | """ Checks whether `chars` is a punctuation character. 34 | We treat all non-letter/number ASCII as punctuation. Characters such as "^", "$", and "`" are not in the Unicode. 35 | Punctuation class but we treat them as punctuation anyways, for consistency. 36 | """ 37 | cp = ord(char) 38 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 39 | return True 40 | cat = unicodedata.category(char) 41 | if cat.startswith("P"): 42 | return True 43 | return False 44 | 45 | def is_chinese_char(cp): 46 | """ Checks whether CP is the codepoint of a CJK character. 47 | This defines a "chinese character" as anything in the CJK Unicode block: 48 | https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 49 | Note that the CJK Unicode block is NOT all Japanese and Korean characters, 50 | despite its name. The modern Korean Hangul alphabet is a different block, 51 | as is Japanese Hiragana and Katakana. Those alphabets are used to write 52 | space-separated words, so they are not treated specially and handled 53 | like the all of the other languages. 54 | """ 55 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or 56 | (cp >= 0x3400 and cp <= 0x4DBF) or 57 | (cp >= 0x20000 and cp <= 0x2A6DF) or 58 | (cp >= 0x2A700 and cp <= 0x2B73F) or 59 | (cp >= 0x2B740 and cp <= 0x2B81F) or 60 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 61 | (cp >= 0xF900 and cp <= 0xFAFF) or 62 | (cp >= 0x2F800 and cp <= 0x2FA1F)): 63 | return True 64 | return False 65 | 66 | def convert_to_unicode(text): 67 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 68 | if six.PY3: 69 | if isinstance(text, str): 70 | return text 71 | elif isinstance(text, bytes): 72 | return text.decode("utf-8", "ignore") 73 | else: 74 | raise ValueError("Unsupported string type: %s" % (type(text))) 75 | elif six.PY2: 76 | if isinstance(text, str): 77 | return text.decode("utf-8", "ignore") 78 | elif isinstance(text, unicode): 79 | return text 80 | else: 81 | raise ValueError("Unsupported string type: %s" % (type(text))) 82 | else: 83 | raise ValueError("Not running on Python2 or Python 3?") 84 | 85 | def clean_text(text): 86 | output = [] 87 | for char in text: 88 | cp = ord(char) 89 | if cp == 0 or cp == 0xfffd or is_control(char): 90 | continue 91 | if is_whitespace(char): 92 | output.append(" ") 93 | else: 94 | output.append(char) 95 | return "".join(output) 96 | 97 | def split_on_whitespace(text): 98 | """ Runs basic whitespace cleaning and splitting on a peice of text. 
99 | e.g, 'a b c' -> ['a', 'b', 'c'] 100 | """ 101 | text = text.strip() 102 | if not text: 103 | return [] 104 | return text.split() 105 | 106 | def split_on_punctuation(text): 107 | """Splits punctuation on a piece of text.""" 108 | start_new_word = True 109 | output = [] 110 | for char in text: 111 | if is_punctuation(char): 112 | output.append([char]) 113 | start_new_word = True 114 | else: 115 | if start_new_word: 116 | output.append([]) 117 | start_new_word = False 118 | output[-1].append(char) 119 | return ["".join(x) for x in output] 120 | 121 | def tokenize_chinese_chars(text): 122 | """Adds whitespace around any CJK character.""" 123 | output = [] 124 | for char in text: 125 | cp = ord(char) 126 | if is_chinese_char(cp): 127 | output.append(" ") 128 | output.append(char) 129 | output.append(" ") 130 | else: 131 | output.append(char) 132 | return "".join(output) 133 | 134 | def strip_accents(text): 135 | """Strips accents from a piece of text.""" 136 | text = unicodedata.normalize("NFD", text) 137 | output = [] 138 | for char in text: 139 | cat = unicodedata.category(char) 140 | if cat == "Mn": 141 | continue 142 | output.append(char) 143 | return "".join(output) 144 | 145 | def load_vocab(vocab_file): 146 | """Loads a vocabulary file into a dictionary.""" 147 | if vocab_file == None: 148 | raise ValueError("Unsupported string type: %s" % (type(text))) 149 | if isinstance(vocab_file, str) or isinstance(vocab_file, bytes): 150 | vocab = collections.OrderedDict() 151 | index = 0 152 | with open(vocab_file, "r", encoding="utf-8") as reader: 153 | while True: 154 | token = reader.readline() 155 | if not token: 156 | break 157 | token = token.strip() 158 | vocab[token] = index 159 | index += 1 160 | return vocab 161 | else: 162 | return vocab_file 163 | 164 | def printable_text(text): 165 | """ Returns text encoded in a way suitable for print or `tf.logging`. 166 | These functions want `str` for both Python2 and Python3, but in one case 167 | it's a Unicode string and in the other it's a byte string. 
168 | """ 169 | if six.PY3: 170 | if isinstance(text, str): 171 | return text 172 | elif isinstance(text, bytes): 173 | return text.decode("utf-8", "ignore") 174 | else: 175 | raise ValueError("Unsupported string type: %s" % (type(text))) 176 | elif six.PY2: 177 | if isinstance(text, str): 178 | return text 179 | elif isinstance(text, unicode): 180 | return text.encode("utf-8") 181 | else: 182 | raise ValueError("Unsupported string type: %s" % (type(text))) 183 | else: 184 | raise ValueError("Not running on Python2 or Python 3?") 185 | 186 | def convert_by_vocab(vocab, items, max_seq_length = None, blank_id = 0, unk_id = 1, uncased = True): 187 | """Converts a sequence of [tokens|ids] using the vocab.""" 188 | output = [] 189 | for item in items: 190 | if uncased: 191 | item = item.lower() 192 | if item in vocab: 193 | output.append(vocab[item]) 194 | else: 195 | output.append(unk_id) 196 | if max_seq_length != None: 197 | if len(output) > max_seq_length: 198 | output = output[:max_seq_length] 199 | else: 200 | while len(output) < max_seq_length: 201 | output.append(blank_id) 202 | return output 203 | 204 | def convert_tokens_to_ids(vocab, tokens, max_seq_length = None, blank_id = 0, unk_id = 1): 205 | return convert_by_vocab(vocab, tokens, max_seq_length, blank_id, unk_id) 206 | 207 | def convert_ids_to_tokens(inv_vocab, ids): 208 | return convert_by_vocab(inv_vocab, ids) 209 | 210 | def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng): 211 | """Truncates a pair of sequences to a maximum sequence length.""" 212 | while True: 213 | total_length = len(tokens_a) + len(tokens_b) 214 | if total_length <= max_num_tokens: 215 | break 216 | trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b 217 | assert len(trunc_tokens) >= 1 218 | # We want to sometimes truncate from the front and sometimes from the 219 | # back to add more randomness and avoid biases. 220 | if rng.random() < 0.5: 221 | del trunc_tokens[0] 222 | else: 223 | trunc_tokens.pop() 224 | 225 | def create_int_feature(values): 226 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 227 | return feature 228 | 229 | def create_float_feature(values): 230 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) 231 | return feature 232 | 233 | def add_token(tokens_a, tokens_b = None): 234 | assert len(tokens_a) >= 1 235 | 236 | tokens = [] 237 | segment_ids = [] 238 | 239 | tokens.append("[CLS]") 240 | segment_ids.append(0) 241 | 242 | for token in tokens_a: 243 | tokens.append(token) 244 | segment_ids.append(0) 245 | 246 | tokens.append("[SEP]") 247 | segment_ids.append(0) 248 | 249 | if tokens_b != None: 250 | assert len(tokens_b) >= 1 251 | 252 | for token in tokens_b: 253 | tokens.append(token) 254 | segment_ids.append(1) 255 | 256 | tokens.append("[SEP]") 257 | segment_ids.append(1) 258 | 259 | return tokens, segment_ids 260 | -------------------------------------------------------------------------------- /opennre/tokenization/word_piece_tokenizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """WordpieceTokenizer classes.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import unicodedata 23 | 24 | from .utils import (load_vocab, 25 | convert_to_unicode, 26 | clean_text, 27 | split_on_whitespace, 28 | convert_by_vocab, 29 | tokenize_chinese_chars) 30 | 31 | class WordpieceTokenizer(object): 32 | """Runs WordPiece tokenziation.""" 33 | 34 | def __init__(self, vocab = None, unk_token="[UNK]", max_input_chars_per_word=200): 35 | self.vocab = load_vocab(vocab) 36 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 37 | self.unk_token = unk_token 38 | self.max_input_chars_per_word = max_input_chars_per_word 39 | 40 | def tokenize(self, text): 41 | """ Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform tokenization 42 | using the given vocabulary. 43 | 44 | For example: 45 | input = "unaffable" 46 | output = ["un", "##aff", "##able"] 47 | 48 | Args: 49 | text: A single token or whitespace separated tokens. This should have already been passed through `BasicTokenizer`. 50 | Returns: 51 | output_tokens: A list of wordpiece tokens. 52 | current_positions: A list of the current positions for the original words in text . 53 | """ 54 | text = convert_to_unicode(text) 55 | text = clean_text(text) 56 | text = tokenize_chinese_chars(text) 57 | output_tokens = [] 58 | current_positions = [] 59 | token_list = split_on_whitespace(text) 60 | for chars in token_list: 61 | if len(chars) > self.max_input_chars_per_word: 62 | output_tokens.append(self.unk_token) 63 | continue 64 | is_bad = False 65 | start = 0 66 | sub_tokens = [] 67 | while start < len(chars): 68 | end = len(chars) 69 | if start > 0: 70 | substr = "##" + chars[start:end] 71 | else: 72 | substr = chars[start:end] 73 | cur_substr = None 74 | while start < end: 75 | if substr in self.vocab: 76 | cur_substr = substr 77 | break 78 | end -= 1 79 | substr = substr[:-1] 80 | if cur_substr is None: 81 | is_bad = True 82 | break 83 | else: 84 | sub_tokens.append(cur_substr) 85 | start = end 86 | current_positions.append([]) 87 | if is_bad: 88 | current_positions[-1].append(len(output_tokens)) 89 | output_tokens.append(self.unk_token) 90 | current_positions[-1].append(len(output_tokens)) 91 | else: 92 | current_positions[-1].append(len(output_tokens)) 93 | output_tokens.extend(sub_tokens) 94 | current_positions[-1].append(len(output_tokens)) 95 | 96 | return output_tokens, current_positions 97 | 98 | def convert_tokens_to_ids(self, tokens): 99 | return convert_by_vocab(self.vocab, tokens) 100 | 101 | def convert_ids_to_tokens(self, ids): 102 | return convert_by_vocab(self.inv_vocab, ids) 103 | -------------------------------------------------------------------------------- /opennre/tokenization/word_tokenizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """WordTokenizer class."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import unicodedata
23 | 
24 | from .utils import (load_vocab,
25 |                     convert_to_unicode,
26 |                     clean_text,
27 |                     split_on_whitespace,
28 |                     convert_by_vocab,
29 |                     tokenize_chinese_chars)
30 | 
31 | class WordTokenizer(object):
32 |     """Runs whitespace word tokenization."""
33 | 
34 |     def __init__(self, vocab=None, unk_token="[UNK]"):
35 |         self.vocab = load_vocab(vocab)
36 |         self.inv_vocab = {v: k for k, v in self.vocab.items()}
37 |         self.unk_token = unk_token
38 | 
39 |     def tokenize(self, text):
40 |         """ Splits a piece of text into word-level tokens on whitespace, after
41 |         converting it to unicode, cleaning control characters and padding
42 |         Chinese characters with spaces.
43 | 
44 |         For example:
45 |             input = "an unaffable person"
46 |             output = ["an", "unaffable", "person"]
47 | 
48 |         Args:
49 |             text: A piece of raw text.
50 |         Returns:
51 |             token_list: A list of whitespace-separated word tokens.
52 |         """
53 |         text = convert_to_unicode(text)
54 |         text = clean_text(text)
55 |         text = tokenize_chinese_chars(text)
56 |         token_list = split_on_whitespace(text)
57 |         return token_list
58 | 
59 |     def convert_tokens_to_ids(self, tokens, max_seq_length=None, blank_id=0, unk_id=1, uncased=True):
60 |         return convert_by_vocab(self.vocab, tokens, max_seq_length, blank_id, unk_id, uncased=uncased)
61 | 
62 |     def convert_ids_to_tokens(self, ids):
63 |         return convert_by_vocab(self.inv_vocab, ids)
64 | 
-------------------------------------------------------------------------------- /pretrain/download_bert.sh: --------------------------------------------------------------------------------
1 | mkdir -p bert-base-uncased
2 | wget -P bert-base-uncased https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/pretrain/bert-base-uncased/config.json
3 | wget -P bert-base-uncased https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/pretrain/bert-base-uncased/pytorch_model.bin
4 | wget -P bert-base-uncased https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/pretrain/bert-base-uncased/vocab.txt
5 | 
-------------------------------------------------------------------------------- /pretrain/download_glove.sh: --------------------------------------------------------------------------------
1 | mkdir -p glove
2 | wget -P glove https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/pretrain/glove/glove.6B.50d_mat.npy
3 | wget -P glove https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/pretrain/glove/glove.6B.50d_word2id.json
4 | 
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | torch
2 | transformers==4.1.1
3 | pytest
4 | scikit-learn
5 | scipy
6 | nltk
7 | 
-------------------------------------------------------------------------------- /setup.py: --------------------------------------------------------------------------------
1 | import setuptools
2 | with open("README.md", "r") as fh:
3 |     long_description = fh.read()
4 | setuptools.setup(
5 |     name='opennre',
6 |     version='0.1',
7 |     author="Tianyu Gao",
8 |     author_email="gaotianyu1350@126.com",
9 |     description="An open source toolkit for relation extraction",
10 |     long_description=long_description,
11 |     long_description_content_type="text/markdown",
12 |     url="https://github.com/thunlp/opennre",
13 |     packages=setuptools.find_packages(),
14 |     classifiers=[
15 |         "Programming Language :: Python :: 3",
16 |         "License :: OSI Approved :: MIT License",
17 |         "Operating System :: POSIX :: Linux",
18 |     ],
19 |     setup_requires=['wheel']
20 | )
21 | 
-------------------------------------------------------------------------------- /tests/test_inference.py: --------------------------------------------------------------------------------
1 | import unittest
2 | import opennre
3 | 
4 | class TestInference(unittest.TestCase):
5 | 
6 |     def test_wiki80_cnn_softmax(self):
7 |         model = opennre.get_model('wiki80_cnn_softmax')
8 |         result = model.infer({'text': 'He was the son of Máel Dúin mac Máele Fithrich, and grandson of the high king Áed Uaridnach (died 612).', 'h': {'pos': (18, 46)}, 't': {'pos': (78, 91)}})
9 |         print(result)
10 |         self.assertEqual(result[0], 'father')
11 |         self.assertAlmostEqual(result[1], 0.7500484585762024, delta=1e-6)
12 | 
13 | if __name__ == '__main__':
14 |     unittest.main()
15 | 
--------------------------------------------------------------------------------
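
The tokenization utilities above are easiest to follow end-to-end. Below is a minimal usage sketch, assuming `pretrain/download_bert.sh` has already been run so that `pretrain/bert-base-uncased/vocab.txt` exists; the path, the example text and the annotated values are illustrative, not taken from the repository.

```python
# Sketch: wordpiece-tokenize a word, wrap it with [CLS]/[SEP], and map it to
# padded input ids. Paths and example text are assumptions for this sketch.
from opennre.tokenization.word_piece_tokenizer import WordpieceTokenizer
from opennre.tokenization.utils import add_token, convert_by_vocab

tokenizer = WordpieceTokenizer(vocab='pretrain/bert-base-uncased/vocab.txt')

# Greedy longest-match-first tokenization. current_positions records the
# [start, end) span of wordpieces produced by each whitespace-separated word.
pieces, positions = tokenizer.tokenize('unaffable')
# With the vocabulary from the docstring example:
#   pieces    == ['un', '##aff', '##able']
#   positions == [[0, 3]]

# Wrap the pieces with [CLS]/[SEP] and build the matching segment ids.
tokens, segment_ids = add_token(pieces)

# Map tokens to ids, padding with blank_id up to max_seq_length (or truncating
# past it). uncased=False keeps the bracketed special tokens untouched.
input_ids = convert_by_vocab(tokenizer.vocab, tokens,
                             max_seq_length=16, blank_id=0, unk_id=1,
                             uncased=False)
```

The `current_positions` spans are what make it possible to carry word-level annotations, such as entity offsets, over onto the wordpiece sequence after tokenization.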
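
The padding, truncation and out-of-vocabulary behaviour of `convert_by_vocab` is easy to miss from its signature alone, so here is a self-contained toy illustration; the vocabulary and its ids are made up for the example.

```python
from opennre.tokenization.utils import convert_by_vocab

# Made-up toy vocabulary; real vocabularies come from load_vocab().
vocab = {'the': 2, 'movie': 3, 'was': 4, 'great': 5}

# uncased=True (the default) lowercases tokens before lookup, unknown tokens
# map to unk_id, and the result is padded with blank_id up to max_seq_length.
ids = convert_by_vocab(vocab, ['The', 'movie', 'was', 'terrific'],
                       max_seq_length=6, blank_id=0, unk_id=1)
# 'The' -> 'the' -> 2, 'movie' -> 3, 'was' -> 4, 'terrific' is OOV -> 1,
# then two pads: ids == [2, 3, 4, 1, 0, 0]
```

When `max_seq_length` is left as `None` the list is returned unpadded, which is how `WordpieceTokenizer.convert_tokens_to_ids` uses it.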