├── .gitignore ├── LICENSE ├── README.md ├── benchmark ├── download_fewrel.sh ├── download_nyt10.sh ├── download_nyt10m.sh ├── download_semeval.sh ├── download_wiki20m.sh ├── download_wiki80.sh └── rel4wiki80.md ├── example ├── test_multilabel_bert.py ├── test_multilabel_cnn.py ├── train_bag_bert.py ├── train_bag_cnn.py ├── train_supervised_bert.py └── train_supervised_cnn.py ├── opennre ├── __init__.py ├── encoder │ ├── __init__.py │ ├── base_encoder.py │ ├── bert_encoder.py │ ├── cnn_encoder.py │ └── pcnn_encoder.py ├── framework │ ├── __init__.py │ ├── bag_re.py │ ├── data_loader.py │ ├── multi_label_sentence_re.py │ ├── sentence_re.py │ └── utils.py ├── model │ ├── __init__.py │ ├── bag_attention.py │ ├── bag_average.py │ ├── bag_one.py │ ├── base_model.py │ ├── sigmoid_nn.py │ └── softmax_nn.py ├── module │ ├── __init__.py │ ├── nn │ │ ├── __init__.py │ │ ├── cnn.py │ │ ├── lstm.py │ │ └── rnn.py │ └── pool │ │ ├── __init__.py │ │ ├── avg_pool.py │ │ └── max_pool.py ├── pretrain.py └── tokenization │ ├── __init__.py │ ├── basic_tokenizer.py │ ├── bert_tokenizer.py │ ├── utils.py │ ├── word_piece_tokenizer.py │ └── word_tokenizer.py ├── pretrain ├── download_bert.sh └── download_glove.sh ├── requirements.txt ├── setup.py └── tests └── test_inference.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # test 107 | test.py 108 | 109 | # ckpt 110 | ckpt 111 | 112 | # vscode 113 | .vscode 114 | 115 | # tacred 116 | benchmark/tacred 117 | *.swp 118 | 119 | 120 | # data and pretrain 121 | pretrain 122 | benchmark 123 | !benchmark/*.sh 124 | !pretrain/*.sh 125 | 126 | # test env 127 | .test 128 | 129 | # package 130 | opennre.egg-info 131 | 132 | *.sh 133 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Tianyu Gao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OpenNRE (sub-project of OpenSKL) 2 | 3 | OpenNRE is a sub-project of OpenSKL, providing an **Open**-source **N**eural **R**elation **E**xtraction toolkit for extracting structured knowledge from plain text, with [ATT](https://aclanthology.org/P16-1200.pdf) as key features to consider relation-associated text information. 4 | 5 | ## Overview 6 | 7 | OpenNRE is an open-source and extensible toolkit that provides a unified framework to implement relation extraction models.
We unify the input and output interfaces of different relation extraction models and provide scalable options for each model. The toolkit covers both supervised and distantly supervised settings, and is compatible with both conventional neural networks and pre-trained language models. 8 | 9 | Relation extraction is a natural language processing (NLP) task aiming at extracting relations (e.g., *founder of*) between entities (e.g., **Bill Gates** and **Microsoft**). For example, from the sentence *Bill Gates founded Microsoft*, we can extract the relation triple (**Bill Gates**, *founder of*, **Microsoft**). 10 | 11 | Relation extraction is a crucial technique in automatic knowledge graph construction. By using relation extraction, we can accumulatively extract new relation facts and expand the knowledge graph, which, as a way for machines to understand the human world, has many downstream applications like question answering, recommender systems and search engines. If you want to learn more about neural relation extraction, visit another project of ours ([NREPapers](https://github.com/thunlp/NREPapers)). 12 | 13 | It's our honor to help you better explore relation extraction with our OpenNRE toolkit! You can refer to our [document](https://opennre-docs.readthedocs.io/en/latest/) for more details about this project. 14 | 15 | 24 | 25 | ## Models 26 | In this toolkit, we support CNN-based relation extraction models including standard CNN and our proposed [CNN+ATT](https://aclanthology.org/P16-1200v2.pdf). We also implement methods based on pre-trained language models (BERT). 27 | 28 | ## Evaluation 29 | 30 | To validate the effectiveness of this toolkit, we employ the Bag-Level Relation Extraction task for evaluation. 31 | 32 | ### Settings 33 | 34 | We utilize the NYT10 dataset, which is a distantly supervised collection derived from the New York Times corpus and Freebase. We mainly experiment with the CNN-ATT model, which employs instance-level attention and shows superior performance compared with the vanilla CNN. 35 | 36 | ### Results 37 | 38 | We report AUC and F1 scores of the two models. The right two columns marked with (\*) indicate the results sourced from [Gao et al.(2021)](https://aclanthology.org/2021.findings-acl.112.pdf) and [Lin et al.(2016)](https://aclanthology.org/P16-1200v2.pdf). The results show that our implementation of the CNN-ATT model is slightly better than that of the original paper, and also confirm the better performance of CNN-ATT over the standard CNN model. 39 | 40 | | Model | AUC | F1 | AUC(Paper \*) | F1(Paper \*) | 41 | | :-: | :-: | :-: | :-: | :-: | 42 | | CNN | - | - | 0.212 | 0.318 | 43 | | CNN-ATT | 0.333 | 0.397 | 0.318 | 0.380 | 44 | 45 | ## Usage 46 | 47 | ### Installation 48 | 49 | #### Install as a Python Package 50 | 51 | We are now working on deploying OpenNRE as a Python package. Coming soon! 52 | 53 | #### Using Git Repository 54 | 55 | Clone the repository from our GitHub page (don't forget to star us!): 56 | 57 | ```bash 58 | git clone https://github.com/thunlp/OpenNRE.git 59 | ``` 60 | 61 | If it is too slow, you can try: 62 | ``` 63 | git clone https://github.com/thunlp/OpenNRE.git --depth 1 64 | ``` 65 | 66 | Then install all the requirements: 67 | 68 | ``` 69 | pip install -r requirements.txt 70 | ``` 71 | 72 | **Note**: Please choose an appropriate PyTorch version based on your machine (related to your CUDA version). For details, refer to https://pytorch.org/.
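For instance, a GPU machine could install a build matching its CUDA version like this (a minimal sketch; the `cu118` index URL below is only an example, pick the wheel index that matches your own CUDA setup from the PyTorch site):

```bash
# Example only: replace cu118 with the build matching your CUDA version,
# or omit --index-url entirely for a CPU-only install.
pip install torch --index-url https://download.pytorch.org/whl/cu118
```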
73 | 74 | Then install the package with 75 | ``` 76 | python setup.py install 77 | ``` 78 | 79 | If you also want to modify the code, run this: 80 | ``` 81 | python setup.py develop 82 | ``` 83 | 84 | Note that we have excluded all data and pretrained files for fast deployment. You can manually download them by running scripts in the ``benchmark`` and ``pretrain`` folders. For example, if you want to download the FewRel dataset, you can run 85 | 86 | ```bash 87 | bash benchmark/download_fewrel.sh 88 | ``` 89 | ### Data 90 | 91 | You can go into the `benchmark` folder and download datasets using our scripts. We also list some information about the datasets in [this document](https://opennre-docs.readthedocs.io/en/latest/get_started/benchmark.html#bag-level-relation-extraction). We provide two distantly-supervised datasets with human-annotated test sets, **NYT10m** and **Wiki20m**. Check the [datasets](#datasets) section for details. 92 | 93 | ### Easy Start 94 | 95 | Make sure you have installed OpenNRE as instructed above. Then import our package and load pre-trained models. 96 | 97 | ```python 98 | >>> import opennre 99 | >>> model = opennre.get_model('wiki80_cnn_softmax') 100 | ``` 101 | 102 | Note that it may take a few minutes to download the checkpoint and data the first time. Then use `infer` to perform sentence-level relation extraction: 103 | 104 | ```python 105 | >>> model.infer({'text': 'He was the son of Máel Dúin mac Máele Fithrich, and grandson of the high king Áed Uaridnach (died 612).', 'h': {'pos': (18, 46)}, 't': {'pos': (78, 91)}}) 106 | ('father', 0.5108704566955566) 107 | ``` 108 | 109 | You will get the relation result and its confidence score. 110 | 111 | If you want to use the model on your GPU, just run 112 | ```python 113 | >>> model = model.cuda() 114 | ``` 115 | before calling the inference function. 116 | 117 | For now, we have the following available models: 118 | 119 | * `wiki80_cnn_softmax`: trained on `wiki80` dataset with a CNN encoder. 120 | * `wiki80_bert_softmax`: trained on `wiki80` dataset with a BERT encoder. 121 | * `wiki80_bertentity_softmax`: trained on `wiki80` dataset with a BERT encoder (using entity representation concatenation). 122 | * `tacred_bert_softmax`: trained on `TACRED` dataset with a BERT encoder. 123 | * `tacred_bertentity_softmax`: trained on `TACRED` dataset with a BERT encoder (using entity representation concatenation). 124 | 125 | ### Training 126 | 127 | You can train your own models on your own data with OpenNRE. In the `example` folder we provide example training code for supervised RE models and bag-level RE models. You can either use our provided datasets or your own datasets. For example, you can use the following script to train a PCNN-ATT bag-level model on the NYT10m dataset (NYT10 with a manually annotated test set). The ATT algorithm is a typical method for combining a bag of sentences to extract relations between entities. 128 | 129 | ```bash 130 | python example/train_bag_cnn.py \ 131 | --metric auc \ 132 | --dataset nyt10m \ 133 | --batch_size 160 \ 134 | --lr 0.1 \ 135 | --weight_decay 1e-5 \ 136 | --max_epoch 100 \ 137 | --max_length 128 \ 138 | --seed 42 \ 139 | --encoder pcnn \ 140 | --aggr att 141 | ``` 142 | 143 | Or use the following script to train a BERT model on the Wiki80 dataset: 144 | ```bash 145 | python example/train_supervised_bert.py \ 146 | --pretrain_path bert-base-uncased \ 147 | --dataset wiki80 148 | ``` 149 | 150 | We provide many options in the example training code and you can check them out for detailed instructions.
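After training, you can reload the saved checkpoint and evaluate it with the same framework classes the example scripts use. Below is a minimal sketch based on `example/train_supervised_bert.py`; the checkpoint path is an assumption (the default name produced by the Wiki80/BERT command above), and the validation file is reused as the test set because the Wiki80 download ships no test split.

```python
# Minimal evaluation sketch mirroring example/train_supervised_bert.py.
# Assumes the Wiki80 data has been downloaded (bash benchmark/download_wiki80.sh)
# and that training produced the default checkpoint name below.
import json
import torch
import opennre
from opennre import encoder, model, framework

ckpt = 'ckpt/wiki80_bert-base-uncased_entity.pth.tar'  # assumed default checkpoint name
rel2id = json.load(open('benchmark/wiki80/wiki80_rel2id.json'))

# Build the same encoder/classifier combination used for training
sentence_encoder = opennre.encoder.BERTEntityEncoder(
    max_length=128, pretrain_path='bert-base-uncased', mask_entity=False)
re_model = opennre.model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id)

fw = opennre.framework.SentenceRE(
    train_path='benchmark/wiki80/wiki80_train.txt',
    val_path='benchmark/wiki80/wiki80_val.txt',
    test_path='benchmark/wiki80/wiki80_val.txt',  # Wiki80 has no test split
    model=re_model, ckpt=ckpt,
    batch_size=64, max_epoch=3, lr=2e-5, opt='adamw')

# Load the trained weights and evaluate on the held-out set
fw.load_state_dict(torch.load(ckpt)['state_dict'])
result = fw.eval_model(fw.test_loader)
print('Accuracy:', result['acc'])
```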
151 | 152 | ## Citation 153 | 154 | If you find OpenNRE useful for your research, please consider citing the following paper: 155 | 156 | ``` 157 | @inproceedings{han-etal-2019-opennre, 158 | title = "{O}pen{NRE}: An Open and Extensible Toolkit for Neural Relation Extraction", 159 | author = "Han, Xu and Gao, Tianyu and Yao, Yuan and Ye, Deming and Liu, Zhiyuan and Sun, Maosong", 160 | booktitle = "Proceedings of EMNLP-IJCNLP: System Demonstrations", 161 | year = "2019", 162 | url = "https://www.aclweb.org/anthology/D19-3029", 163 | doi = "10.18653/v1/D19-3029", 164 | pages = "169--174" 165 | } 166 | ``` 167 | 168 | This package is mainly contributed by [Tianyu Gao](https://github.com/gaotianyu1350), [Xu Han](https://github.com/THUCSTHanxu13), [Shulin Cao](https://github.com/ShulinCao), [Lumin Tang](https://github.com/Tsingularity), [Yankai Lin](https://github.com/Mrlyk423), and [Zhiyuan Liu](http://nlp.csai.tsinghua.edu.cn/~lzy/). 169 | 170 | 171 | ****************** 172 | ## About OpenSKL 173 | The OpenSKL project aims to harness the power of both structured knowledge and natural languages via representation learning. All sub-projects of OpenSKL, under the categories of **Algorithm**, **Resource** and **Application**, are as follows. 174 | 175 | - **Algorithm**: 176 | - [OpenKE](https://www.github.com/thunlp/OpenKE) 177 | - An effective and efficient toolkit for representing structured knowledge in large-scale knowledge graphs as embeddings, with TransR and PTransE as key features to handle complex relations and relational paths. 178 | - This toolkit also includes three repositories: 179 | - [KB2E](https://www.github.com/thunlp/KB2E) 180 | - [TensorFlow-Transx](https://www.github.com/thunlp/TensorFlow-Transx) 181 | - [Fast-TransX](https://www.github.com/thunlp/Fast-TransX) 182 | - [ERNIE](https://github.com/thunlp/ERNIE) 183 | - An effective and efficient toolkit for augmenting pre-trained language models with knowledge graph representations. 184 | - [OpenNE](https://www.github.com/thunlp/OpenNE) 185 | - An effective and efficient toolkit for representing nodes in large-scale graphs as embeddings, with [TADW](https://www.ijcai.org/Proceedings/15/Papers/299.pdf) as key features to incorporate text attributes of nodes. 186 | - [OpenNRE](https://www.github.com/thunlp/OpenNRE) 187 | - An effective and efficient toolkit for implementing neural networks for extracting structured knowledge from text, with [ATT](https://aclanthology.org/P16-1200.pdf) as key features to consider relation-associated text information. 188 | - This toolkit also includes two repositories: 189 | - [JointNRE](https://www.github.com/thunlp/JointNRE) 190 | - [NRE](https://github.com/thunlp/NRE) 191 | - **Resource**: 192 | - The embeddings of large-scale knowledge graphs pre-trained by OpenKE, covering three typical large-scale knowledge graphs: Wikidata, Freebase, and XLORE. The embeddings are free to use under the [MIT license](https://opensource.org/license/mit/), and please click the following link to submit [download requests](http://139.129.163.161/download/wikidata). 193 | - OpenKE-Wikidata 194 | - Wikidata is a free and collaborative database, collecting structured data to provide support for Wikipedia. The original Wikidata contains 20,982,733 entities, 594 relations and 68,904,773 triplets. In particular, Wikidata-5M is the core subgraph of Wikidata, containing 5,040,986 high-frequency entities from Wikidata with their corresponding 927 relations and 24,267,796 triplets.
195 | - [TransE version](http://139.129.163.161/download/wikidata): Knowledge embeddings of Wikidata pre-trained by OpenKE. 196 | - [TransR version](http://139.129.163.161/download/wikidata) of Wikidata-5M: Knowledge embeddings of Wikidata-5M pre-trained by OpenKE. 197 | - OpenKE-Freebase 198 | - Freebase was a large collaborative knowledge base consisting of data composed mainly by its community members. It was an online collection of structured data harvested from many sources. Freebase contains 86,054,151 entities, 14,824 relations and 338,586,276 triplets. 199 | - [TransE version](http://139.129.163.161/download/wikidata): Knowledge embeddings of Freebase pre-trained by OpenKE. 200 | - OpenKE-XLORE 201 | - XLORE is one of the most popular Chinese knowledge graphs developed by THUKEG. XLORE contains 10,572,209 entities, 138,581 relations and 35,954,249 triplets. 202 | - [TransE version](http://139.129.163.161/download/wikidata): Knowledge embeddings of XLORE pre-trained by OpenKE. 203 | - **Application**: 204 | - [Knowledge-Plugin](https://github.com/THUNLP/Knowledge-Plugin) 205 | - An effective and efficient toolkit of plug-and-play knowledge injection for pre-trained language models. Knowledge-Plugin is general for all kinds of knowledge graph embeddings mentioned above. In the toolkit, we plug the TransR version of Wikidata-5M into BERT as an example of applications. With the TransR embedding, we enhance the knowledge ability of BERT without fine-tuning the original model, e.g., up to 8% improvement on question answering. 206 | -------------------------------------------------------------------------------- /benchmark/download_fewrel.sh: -------------------------------------------------------------------------------- 1 | mkdir fewrel 2 | wget -P fewrel https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/fewrel/fewrel_train.txt 3 | wget -P fewrel https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/fewrel/fewrel_train_rel2id.json 4 | wget -P fewrel https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/fewrel/fewrel_val.txt 5 | wget -P fewrel https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/fewrel/fewrel_val_rel2id.json 6 | -------------------------------------------------------------------------------- /benchmark/download_nyt10.sh: -------------------------------------------------------------------------------- 1 | mkdir nyt10 2 | wget -P nyt10 https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/nyt10/nyt10_rel2id.json 3 | wget -P nyt10 https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/nyt10/nyt10_train.txt 4 | wget -P nyt10 https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/nyt10/nyt10_test.txt 5 | -------------------------------------------------------------------------------- /benchmark/download_nyt10m.sh: -------------------------------------------------------------------------------- 1 | mkdir nyt10m 2 | wget -P nyt10m https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/nyt10m/nyt10m_rel2id.json 3 | wget -P nyt10m https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/nyt10m/nyt10m_train.txt 4 | wget -P nyt10m https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/nyt10m/nyt10m_val.txt 5 | wget -P nyt10m https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/nyt10m/nyt10m_test.txt 6 | -------------------------------------------------------------------------------- /benchmark/download_semeval.sh: -------------------------------------------------------------------------------- 1 | mkdir 
semeval 2 | wget -P semeval https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/semeval/semeval_rel2id.json 3 | wget -P semeval https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/semeval/semeval_train.txt 4 | wget -P semeval https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/semeval/semeval_val.txt 5 | wget -P semeval https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/semeval/semeval_test.txt 6 | -------------------------------------------------------------------------------- /benchmark/download_wiki20m.sh: -------------------------------------------------------------------------------- 1 | mkdir wiki20m 2 | wget -P wiki20m https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/wiki20m/wiki20m_rel2id.json 3 | wget -P wiki20m https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/wiki20m/wiki20m_train.txt 4 | wget -P wiki20m https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/wiki20m/wiki20m_val.txt 5 | wget -P wiki20m https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/wiki20m/wiki20m_test.txt 6 | -------------------------------------------------------------------------------- /benchmark/download_wiki80.sh: -------------------------------------------------------------------------------- 1 | mkdir wiki80 2 | wget -P wiki80 https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/wiki80/wiki80_rel2id.json 3 | wget -P wiki80 https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/wiki80/wiki80_train.txt 4 | wget -P wiki80 https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/wiki80/wiki80_val.txt 5 | -------------------------------------------------------------------------------- /benchmark/rel4wiki80.md: -------------------------------------------------------------------------------- 1 | 2 | | Wikidata ID | Name | Description| 3 | | ---- | ---- | ---- | 4 | |P6 | head of government | head of the executive power of this town, city, municipality, state, country, or other governmental body | 5 | |P17 | country | sovereign state of this item; don't use on humans | 6 | |P22 | father | male parent of the subject. For stepfather, use "stepparent" (P3448) | 7 | |P27 | country of citizenship | the object is a country that recognizes the subject as its citizen | 8 | |P31 | instance of | that class of which this subject is a particular example and member. (Subject typically an individual member with Proper Name label.) 
Different from P279 (subclass of) | 9 | |P39 | position held | subject currently or formerly holds the object position or public office | 10 | |P57 | director | director(s) of this motion picture, TV-series, stageplay, video game or similar | 11 | |P58 | screenwriter | author(s) of the screenplay or script for this work | 12 | |P84 | architect | person or architectural firm that designed this building | 13 | |P86 | composer | person(s) who wrote the music; also use P676 for lyricist | 14 | |P101 | field of work | specialization of a person or organization; see P106 for the occupation | 15 | |P102 | member of political party | the political party of which this politician is or has been a member | 16 | |P105 | taxon rank | level in a taxonomic hierarchy | 17 | |P106 | occupation | occupation of a person; see also "field of work" (Property:P101), "position held" (Property:P39) | 18 | |P118 | league | league in which team or player plays or has played in | 19 | |P123 | publisher | organization or person responsible for publishing books, periodicals, games or software | 20 | |P127 | owned by | owner of the subject | 21 | |P131 | located in the administrative territorial entity | the item is located on the territory of the following administrative entity. Use P276 (location) for specifying the location of non-administrative places and for items about events | 22 | |P135 | movement | literary, artistic, scientific or philosophical movement associated with this person or work | 23 | |P136 | genre | creative work's genre or an artist's field of work (P101). Use main subject (P921) to relate creative works to their topic | 24 | |P137 | operator | person or organization that operates the equipment, facility, or service; use country for diplomatic missions | 25 | |P140 | religion | religion of a person, organization or religious building, or associated with this subject | 26 | |P150 | contains administrative territorial entity | (list of) direct subdivisions of an administrative territorial entity | 27 | |P156 | followed by | immediately following item in some series of which the subject is part. Use P1366 (replaced by) if the item is replaced, e.g. political offices, states | 28 | |P159 | headquarters location | specific location where an organization's headquarters is or has been situated | 29 | |P175 | performer | performer involved in the performance or the recording of a work | 30 | |P176 | manufacturer | manufacturer or producer of this product | 31 | |P178 | developer | organisation or person that developed this item | 32 | |P241 | military branch | branch to which this military unit, award, office, or person belongs, e.g. Royal Navy | 33 | |P264 | record label | brand and trademark associated with the marketing of subject music recordings and music videos | 34 | |P276 | location | location of the item, physical object or event is within. In case of an administrative entity use P131. In case of a distinct terrain feature use P706. 
| 35 | |P306 | operating system | operating system (OS) on which a software works or the OS installed on hardware | 36 | |P355 | subsidiary | subsidiary of a company or organization, opposite of parent company (P749) | 37 | |P400 | platform | platform for which a work has been developed or released / specific platform version of a software developed | 38 | |P403 | mouth of the watercourse | the body of water to which the watercourse drains | 39 | |P407 | language of work or name | language associated with this work or name (for persons use P103 and P1412) | 40 | |P449 | original network | network(s) the radio or television show was originally aired on, including | 41 | |P460 | said to be the same as | this item is said to be the same as that item, but the statement is disputed | 42 | |P466 | occupant | person or organization occupying a building or facility | 43 | |P495 | country of origin | country of origin of the creative work or subject item | 44 | |P527 | has part | part of this subject. Inverse property of "part of" (P361). | 45 | |P551 | residence | the place where the person is, or has been, resident | 46 | |P674 | characters | characters which appear in this item (like plays, operas, operettas, books, comics, films, TV series, video games) | 47 | |P706 | located on terrain feature | located on the specified landform. Should not be used when the value is only political/administrative (provinces, states, countries, etc.). Use P131 for administrative entity. | 48 | |P710 | participant | person, group of people or organization (object) that actively takes/took part in the event (subject). Preferably qualify with "object has role" (P3831). Use P1923 for team participants. | 49 | |P740 | location of formation | location where a group or organization was formed | 50 | |P750 | distributor | distributor of a creative work | 51 | |P800 | notable work | notable scientific, artistic or literary work, or other work of significance among subject's works | 52 | |P931 | place served by transport hub | city or region served by this transport hub (airport, train station, etc.) | 53 | |P937 | work location | location where persons were active | 54 | |P974 | tributary | stream or river that flows into this main stem (or parent) river | 55 | |P991 | successful candidate | person(s) elected after the election | 56 | |P1001 | applies to jurisdiction | the item (an institution, law, public office ...) belongs to or has power over or applies to the value (a territorial jurisdiction: a country, state, municipality, ...) | 57 | |P1303 | instrument | musical instrument that a person plays | 58 | |P1344 | participant of | event a person or an organization was a participant in, inverse of P710 or P1923 | 59 | |P1346 | winner | winner of an event - do not use for wars or battles | 60 | |P1408 | licensed to broadcast to | place that a radio/TV station is licensed/required to broadcast to | 61 | |P1411 | nominated for | award nomination received by a person, organisation or creative work (inspired from "award received" (Property:P166)) | 62 | |P1435 | heritage designation | heritage designation of a historical site | 63 | |P1877 | after a work by | artist whose work strongly inspired/ was copied in this item | 64 | |P1923 | participating teams | Like 'Participant' (P710) but for teams. 
For an event like a cycle race or a football match you can use this property to list the teams and P710 to list the individuals (with 'member of sports team' (P54)' as a qualifier for the individuals) | 65 | |P3373 | sibling | the subject has the object as their sibling (brother, sister, etc.). Use "relative" (P1038) for siblings-in-law (brother-in-law, sister-in-law, etc.) and step-siblings (step-brothers, step-sisters, etc.) | 66 | |P3450 | sports season of | property that shows the competition of which the item is a season | 67 | |P4552 | mountain range | range or subrange to which the geographical item belongs | 68 | |P25 | mother | female parent of the subject. For stepmother, use "stepparent" (P3448) | 69 | |P26 | spouse | the subject has the object as their spouse (husband, wife, partner, etc.). Use "unmarried partner" (P451) for non-married companions | 70 | |P40 | child | subject has the object in their family as their offspring son or daughter (independently of their age) | 71 | |P59 | constellation | the area of the celestial sphere of which the subject is a part (from a scientific standpoint, not an astrological one) | 72 | |P155 | follows | immediately prior item in some series of which the subject is part. Use P1365 (replaces) if the preceding item was replaced, e.g. political offices, states and there is no identity between precedent and following geographic unit | 73 | |P177 | crosses | obstacle (body of water, road, ...) which this bridge crosses over or this tunnel goes under | 74 | |P206 | located in or next to body of water | sea, lake or river | 75 | |P361 | part of | object of which the subject is a part. Inverse property of "has part" (P527). See also "has parts of the class" (P2670). | 76 | |P364 | original language of work | language in which a film or a performance work was originally created. Deprecated for written works; use P407 ("language of work or name") instead. | 77 | |P410 | military rank | military rank achieved by a person (should usually have a "start date" qualifier), or military rank associated with a position | 78 | |P412 | voice type | person's voice type. expected values: soprano, mezzo-soprano, contralto, countertenor, tenor, baritone, bass (and derivatives) | 79 | |P413 | position played on team / speciality | position or specialism of a player on a team, e.g. Small Forward | 80 | |P463 | member of | organization or club to which the subject belongs. Do not use for membership in ethnic or social groups, nor for holding a position such as a member of parliament (use P39 for that). 
| 81 | |P641 | sport | sport in which the entity participates or belongs to | 82 | |P921 | main subject | primary topic of a work (see also P180: depicts) | 83 | |P2094 | competition class | official classification by a regulating body under which the subject (events, teams, participants, or equipment) qualifies for inclusion | 84 | -------------------------------------------------------------------------------- /example/test_multilabel_bert.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import torch 3 | import numpy as np 4 | import json 5 | import opennre 6 | from opennre import encoder, model, framework 7 | import sys 8 | import os 9 | import argparse 10 | import logging 11 | import random 12 | 13 | def set_seed(seed): 14 | random.seed(seed) 15 | np.random.seed(seed) 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--pretrain_path', default='bert-base-uncased', 22 | help='Pre-trained ckpt path / model name (hugginface)') 23 | parser.add_argument('--ckpt', default='', 24 | help='Checkpoint name') 25 | parser.add_argument('--pooler', default='entity', choices=['cls', 'entity'], 26 | help='Sentence representation pooler') 27 | parser.add_argument('--only_test', action='store_true', 28 | help='Only run test') 29 | parser.add_argument('--mask_entity', action='store_true', 30 | help='Mask entity mentions') 31 | 32 | # Data 33 | parser.add_argument('--metric', default='micro_f1', choices=['micro_f1', 'acc'], 34 | help='Metric for picking up best checkpoint') 35 | parser.add_argument('--dataset', default='none', choices=['none', 'semeval', 'wiki80', 'tacred'], 36 | help='Dataset. If not none, the following args can be ignored') 37 | parser.add_argument('--train_file', default='', type=str, 38 | help='Training data file') 39 | parser.add_argument('--val_file', default='', type=str, 40 | help='Validation data file') 41 | parser.add_argument('--test_file', default='', type=str, 42 | help='Test data file') 43 | parser.add_argument('--rel2id_file', default='', type=str, 44 | help='Relation to ID file') 45 | 46 | # Hyper-parameters 47 | parser.add_argument('--batch_size', default=64, type=int, 48 | help='Batch size') 49 | parser.add_argument('--lr', default=2e-5, type=float, 50 | help='Learning rate') 51 | parser.add_argument('--max_length', default=128, type=int, 52 | help='Maximum sentence length') 53 | parser.add_argument('--max_epoch', default=3, type=int, 54 | help='Max number of training epochs') 55 | 56 | # Seed 57 | parser.add_argument('--seed', default=42, type=int, 58 | help='Seed') 59 | 60 | args = parser.parse_args() 61 | 62 | # Set random seed 63 | set_seed(args.seed) 64 | 65 | # Some basic settings 66 | root_path = '.' 
67 | sys.path.append(root_path) 68 | if not os.path.exists('ckpt'): 69 | os.mkdir('ckpt') 70 | if len(args.ckpt) == 0: 71 | args.ckpt = '{}_{}_{}'.format(args.dataset, args.pretrain_path, args.pooler) 72 | ckpt = 'ckpt/{}.pth.tar'.format(args.ckpt) 73 | 74 | if args.dataset != 'none': 75 | opennre.download(args.dataset, root_path=root_path) 76 | args.train_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_train.txt'.format(args.dataset)) 77 | args.val_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_val.txt'.format(args.dataset)) 78 | args.test_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_test.txt'.format(args.dataset)) 79 | if not os.path.exists(args.test_file): 80 | logging.warn("Test file {} does not exist! Use val file instead".format(args.test_file)) 81 | args.test_file = args.val_file 82 | args.rel2id_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_rel2id.json'.format(args.dataset)) 83 | if args.dataset == 'wiki80': 84 | args.metric = 'acc' 85 | else: 86 | args.metric = 'micro_f1' 87 | else: 88 | if not (os.path.exists(args.train_file) and os.path.exists(args.val_file) and os.path.exists(args.test_file) and os.path.exists(args.rel2id_file)): 89 | raise Exception('--train_file, --val_file, --test_file and --rel2id_file are not specified or files do not exist. Or specify --dataset') 90 | 91 | logging.info('Arguments:') 92 | for arg in vars(args): 93 | logging.info(' {}: {}'.format(arg, getattr(args, arg))) 94 | 95 | rel2id = json.load(open(args.rel2id_file)) 96 | 97 | # Define the sentence encoder 98 | if args.pooler == 'entity': 99 | sentence_encoder = opennre.encoder.BERTEntityEncoder( 100 | max_length=args.max_length, 101 | pretrain_path=args.pretrain_path, 102 | mask_entity=args.mask_entity 103 | ) 104 | elif args.pooler == 'cls': 105 | sentence_encoder = opennre.encoder.BERTEncoder( 106 | max_length=args.max_length, 107 | pretrain_path=args.pretrain_path, 108 | mask_entity=args.mask_entity 109 | ) 110 | else: 111 | raise NotImplementedError 112 | 113 | # Define the model 114 | model = opennre.model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id) 115 | 116 | # Define the whole training framework 117 | framework = opennre.framework.MultiLabelSentenceRE( 118 | train_path=args.train_file, 119 | val_path=args.val_file, 120 | test_path=args.test_file, 121 | model=model, 122 | ckpt=ckpt, 123 | batch_size=args.batch_size, 124 | max_epoch=args.max_epoch, 125 | lr=args.lr, 126 | opt='adamw' 127 | ) 128 | 129 | # Train the model 130 | if not args.only_test: 131 | framework.train_model('micro_f1') 132 | 133 | # Test 134 | framework.load_state_dict(torch.load(ckpt)['state_dict']) 135 | result = framework.eval_model(framework.test_loader) 136 | 137 | # Print the result 138 | logging.info('Test set results:') 139 | logging.info('Accuracy: {}'.format(result['acc'])) 140 | logging.info('Micro precision: {}'.format(result['micro_p'])) 141 | logging.info('Micro recall: {}'.format(result['micro_r'])) 142 | logging.info('Micro F1: {}'.format(result['micro_f1'])) 143 | logging.info('Macro precision: {}'.format(result['macro_p'])) 144 | logging.info('Macro recall: {}'.format(result['macro_r'])) 145 | logging.info('Macro F1: {}'.format(result['macro_f1'])) 146 | -------------------------------------------------------------------------------- /example/test_multilabel_cnn.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import torch 3 | import numpy as np 4 | import json 5 | import opennre 6 
| from opennre import encoder, model, framework 7 | import sys 8 | import os 9 | import argparse 10 | import logging 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--ckpt', default='', 14 | help='Checkpoint name') 15 | parser.add_argument('--only_test', action='store_true', 16 | help='Only run test') 17 | 18 | # Data 19 | parser.add_argument('--metric', default='micro_f1', choices=['micro_f1', 'acc'], 20 | help='Metric for picking up best checkpoint') 21 | parser.add_argument('--dataset', default='none', choices=['none', 'semeval', 'wiki80', 'tacred'], 22 | help='Dataset. If not none, the following args can be ignored') 23 | parser.add_argument('--train_file', default='', type=str, 24 | help='Training data file') 25 | parser.add_argument('--val_file', default='', type=str, 26 | help='Validation data file') 27 | parser.add_argument('--test_file', default='', type=str, 28 | help='Test data file') 29 | parser.add_argument('--rel2id_file', default='', type=str, 30 | help='Relation to ID file') 31 | 32 | # Hyper-parameters 33 | parser.add_argument('--batch_size', default=32, type=int, 34 | help='Batch size') 35 | parser.add_argument('--lr', default=1e-1, type=float, 36 | help='Learning rate') 37 | parser.add_argument('--weight_decay', default=1e-5, type=float, 38 | help='Weight decay') 39 | parser.add_argument('--max_length', default=128, type=int, 40 | help='Maximum sentence length') 41 | parser.add_argument('--max_epoch', default=100, type=int, 42 | help='Max number of training epochs') 43 | 44 | parser.add_argument('--encoder', default='cnn', choices=['pcnn', 'cnn']) 45 | 46 | args = parser.parse_args() 47 | 48 | # Some basic settings 49 | root_path = '.' 50 | sys.path.append(root_path) 51 | if not os.path.exists('ckpt'): 52 | os.mkdir('ckpt') 53 | if len(args.ckpt) == 0: 54 | args.ckpt = '{}_{}'.format(args.dataset, 'cnn') 55 | ckpt = 'ckpt/{}.pth.tar'.format(args.ckpt) 56 | 57 | if args.dataset != 'none': 58 | opennre.download(args.dataset, root_path=root_path) 59 | args.train_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_train.txt'.format(args.dataset)) 60 | args.val_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_val.txt'.format(args.dataset)) 61 | args.test_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_test.txt'.format(args.dataset)) 62 | args.rel2id_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_rel2id.json'.format(args.dataset)) 63 | if args.dataset == 'wiki80': 64 | args.metric = 'acc' 65 | else: 66 | args.metric = 'micro_f1' 67 | else: 68 | if not (os.path.exists(args.train_file) and os.path.exists(args.val_file) and os.path.exists(args.test_file) and os.path.exists(args.rel2id_file)): 69 | raise Exception('--train_file, --val_file, --test_file and --rel2id_file are not specified or files do not exist. 
Or specify --dataset') 70 | 71 | logging.info('Arguments:') 72 | for arg in vars(args): 73 | logging.info(' {}: {}'.format(arg, getattr(args, arg))) 74 | 75 | rel2id = json.load(open(args.rel2id_file)) 76 | 77 | # Download glove 78 | opennre.download('glove', root_path=root_path) 79 | word2id = json.load(open(os.path.join(root_path, 'pretrain/glove/glove.6B.50d_word2id.json'))) 80 | word2vec = np.load(os.path.join(root_path, 'pretrain/glove/glove.6B.50d_mat.npy')) 81 | 82 | # Define the sentence encoder 83 | if args.encoder == 'cnn': 84 | sentence_encoder = opennre.encoder.CNNEncoder( 85 | token2id=word2id, 86 | max_length=args.max_length, 87 | word_size=50, 88 | position_size=5, 89 | hidden_size=230, 90 | blank_padding=True, 91 | kernel_size=3, 92 | padding_size=1, 93 | word2vec=word2vec, 94 | dropout=0.5 95 | ) 96 | elif args.encoder == 'pcnn': 97 | sentence_encoder = opennre.encoder.PCNNEncoder( 98 | token2id=word2id, 99 | max_length=args.max_length, 100 | word_size=50, 101 | position_size=5, 102 | hidden_size=230, 103 | blank_padding=True, 104 | kernel_size=3, 105 | padding_size=1, 106 | word2vec=word2vec, 107 | dropout=0.5 108 | ) 109 | else: 110 | raise NotImplementedError 111 | 112 | # Define the model 113 | model = opennre.model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id) 114 | 115 | # Define the whole training framework 116 | framework = opennre.framework.MultiLabelSentenceRE( 117 | train_path=args.train_file, 118 | val_path=args.val_file, 119 | test_path=args.test_file, 120 | model=model, 121 | ckpt=ckpt, 122 | batch_size=args.batch_size, 123 | max_epoch=args.max_epoch, 124 | lr=args.lr, 125 | weight_decay=args.weight_decay, 126 | opt='sgd' 127 | ) 128 | 129 | # Train the model 130 | if not args.only_test: 131 | framework.train_model(args.metric) 132 | 133 | # Test 134 | framework.load_state_dict(torch.load(ckpt)['state_dict']) 135 | result = framework.eval_model(framework.test_loader) 136 | 137 | # Print the result 138 | logging.info('Test set results:') 139 | logging.info('Accuracy: %.5f' % (result['acc'])) 140 | logging.info('Micro precision: %.5f' % (result['micro_p'])) 141 | logging.info('Micro recall: %.5f' % (result['micro_r'])) 142 | logging.info('Micro F1: %.5f' % (result['micro_f1'])) 143 | logging.info('Macro precision: %.5f' % (result['macro_p'])) 144 | logging.info('Macro recall: %.5f' % (result['macro_r'])) 145 | logging.info('Macro F1: %.5f' % (result['macro_f1'])) 146 | logging.info("AUC: %.5f" % (result['auc'])) 147 | logging.info('P@100: %.5f' % (result['p@100'])) 148 | logging.info('P@200: %.5f' % (result['p@200'])) 149 | logging.info('P@300: %.5f' % (result['p@300'])) 150 | logging.info("Max micro F1: %.5f" % (result['max_micro_f1'])) 151 | logging.info("| threshold: %.5f" % (result['max_micro_f1_threshold'])) 152 | -------------------------------------------------------------------------------- /example/train_bag_bert.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import sys, json 3 | import torch 4 | import os 5 | import numpy as np 6 | import opennre 7 | import argparse 8 | import logging 9 | import random 10 | 11 | def set_seed(seed): 12 | random.seed(seed) 13 | np.random.seed(seed) 14 | torch.manual_seed(seed) 15 | torch.cuda.manual_seed_all(seed) 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--pretrain_path', default='bert-base-uncased', 19 | help='Pre-trained ckpt path / model name (hugginface)') 20 | parser.add_argument('--ckpt', default='', 21 | help='Checkpoint 
name') 22 | parser.add_argument('--result', default='', 23 | help='Result name') 24 | parser.add_argument('--pooler', default='entity', choices=['cls', 'entity'], 25 | help='Sentence representation pooler') 26 | parser.add_argument('--only_test', action='store_true', 27 | help='Only run test') 28 | parser.add_argument('--mask_entity', action='store_true', 29 | help='Mask entity mentions') 30 | 31 | # Data 32 | parser.add_argument('--metric', default='auc', choices=['micro_f1', 'auc'], 33 | help='Metric for picking up best checkpoint') 34 | parser.add_argument('--dataset', default='none', choices=['none', 'wiki_distant', 'nyt10', 'nyt10m', 'wiki20m'], 35 | help='Dataset. If not none, the following args can be ignored') 36 | parser.add_argument('--train_file', default='', type=str, 37 | help='Training data file') 38 | parser.add_argument('--val_file', default='', type=str, 39 | help='Validation data file') 40 | parser.add_argument('--test_file', default='', type=str, 41 | help='Test data file') 42 | parser.add_argument('--rel2id_file', default='', type=str, 43 | help='Relation to ID file') 44 | 45 | # Bag related 46 | parser.add_argument('--bag_size', type=int, default=4, 47 | help='Fixed bag size. If set to 0, use original bag sizes') 48 | 49 | # Hyper-parameters 50 | parser.add_argument('--batch_size', default=16, type=int, 51 | help='Batch size') 52 | parser.add_argument('--lr', default=2e-5, type=float, 53 | help='Learning rate') 54 | parser.add_argument('--max_length', default=128, type=int, 55 | help='Maximum sentence length') 56 | parser.add_argument('--max_epoch', default=3, type=int, 57 | help='Max number of training epochs') 58 | 59 | # Exp 60 | parser.add_argument('--aggr', default='att', choices=['one', 'att', 'avg']) 61 | 62 | 63 | # Seed 64 | parser.add_argument('--seed', default=42, type=int, 65 | help='Seed') 66 | 67 | args = parser.parse_args() 68 | 69 | # Set random seed 70 | set_seed(args.seed) 71 | args = parser.parse_args() 72 | 73 | # Some basic settings 74 | root_path = '.' 75 | sys.path.append(root_path) 76 | if not os.path.exists('ckpt'): 77 | os.mkdir('ckpt') 78 | ckpt = 'ckpt/{}.pth.tar'.format(args.ckpt) 79 | 80 | if args.dataset != 'none': 81 | opennre.download(args.dataset, root_path=root_path) 82 | args.train_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_train.txt'.format(args.dataset)) 83 | args.val_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_val.txt'.format(args.dataset)) 84 | if not os.path.exists(args.val_file): 85 | logging.info("Cannot find the validation file. Use the test file instead.") 86 | args.val_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_test.txt'.format(args.dataset)) 87 | args.test_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_test.txt'.format(args.dataset)) 88 | args.rel2id_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_rel2id.json'.format(args.dataset)) 89 | else: 90 | if not (os.path.exists(args.train_file) and os.path.exists(args.val_file) and os.path.exists(args.test_file) and os.path.exists(args.rel2id_file)): 91 | raise Exception('--train_file, --val_file, --test_file and --rel2id_file are not specified or files do not exist. 
Or specify --dataset') 92 | 93 | logging.info('Arguments:') 94 | for arg in vars(args): 95 | logging.info(' {}: {}'.format(arg, getattr(args, arg))) 96 | 97 | rel2id = json.load(open(args.rel2id_file)) 98 | 99 | 100 | # Define the sentence encoder 101 | if args.pooler == 'entity': 102 | sentence_encoder = opennre.encoder.BERTEntityEncoder( 103 | max_length=args.max_length, 104 | pretrain_path=args.pretrain_path, 105 | mask_entity=args.mask_entity 106 | ) 107 | elif args.pooler == 'cls': 108 | sentence_encoder = opennre.encoder.BERTEncoder( 109 | max_length=args.max_length, 110 | pretrain_path=args.pretrain_path, 111 | mask_entity=args.mask_entity 112 | ) 113 | else: 114 | raise NotImplementedError 115 | 116 | 117 | # Define the model 118 | if args.aggr == 'att': 119 | model = opennre.model.BagAttention(sentence_encoder, len(rel2id), rel2id) 120 | elif args.aggr == 'avg': 121 | model = opennre.model.BagAverage(sentence_encoder, len(rel2id), rel2id) 122 | elif args.aggr == 'one': 123 | model = opennre.model.BagOne(sentence_encoder, len(rel2id), rel2id) 124 | else: 125 | raise NotImplementedError 126 | 127 | # Define the whole training framework 128 | framework = opennre.framework.BagRE( 129 | train_path=args.train_file, 130 | val_path=args.val_file, 131 | test_path=args.test_file, 132 | model=model, 133 | ckpt=ckpt, 134 | batch_size=args.batch_size, 135 | max_epoch=args.max_epoch, 136 | lr=args.lr, 137 | opt="adamw", 138 | bag_size=args.bag_size 139 | ) 140 | 141 | # Train the model 142 | if not args.only_test: 143 | framework.train_model(args.metric) 144 | 145 | # Test the model 146 | framework.load_state_dict(torch.load(ckpt)['state_dict']) 147 | result = framework.eval_model(framework.test_loader) 148 | 149 | # Print the result 150 | logging.info('Test set results:') 151 | logging.info('AUC: %.5f' % (result['auc'])) 152 | logging.info('Maximum micro F1: %.5f' % (result['max_micro_f1'])) 153 | logging.info('Maximum macro F1: %.5f' % (result['max_macro_f1'])) 154 | logging.info('Micro F1: %.5f' % (result['micro_f1'])) 155 | logging.info('Macro F1: %.5f' % (result['macro_f1'])) 156 | logging.info('P@100: %.5f' % (result['p@100'])) 157 | logging.info('P@200: %.5f' % (result['p@200'])) 158 | logging.info('P@300: %.5f' % (result['p@300'])) 159 | 160 | 161 | # Save precision/recall points 162 | np.save('result/{}_p.npy'.format(args.result), result['np_prec']) 163 | np.save('result/{}_r.npy'.format(args.result), result['np_rec']) 164 | json.dump(result['max_micro_f1_each_relation'], open('result/{}_mmicrof1_rel.json'.format(args.result), 'w'), ensure_ascii=False) 165 | -------------------------------------------------------------------------------- /example/train_bag_cnn.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import sys, json 3 | import torch 4 | import os 5 | import numpy as np 6 | import opennre 7 | import argparse 8 | import logging 9 | import random 10 | 11 | def set_seed(seed): 12 | random.seed(seed) 13 | np.random.seed(seed) 14 | torch.manual_seed(seed) 15 | torch.cuda.manual_seed_all(seed) 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--ckpt', default='', 19 | help='Checkpoint name') 20 | parser.add_argument('--result', default='', 21 | help='Save result name') 22 | parser.add_argument('--only_test', action='store_true', 23 | help='Only run test') 24 | 25 | # Data 26 | parser.add_argument('--metric', default='auc', choices=['micro_f1', 'auc'], 27 | help='Metric for picking up best checkpoint') 28 
| parser.add_argument('--dataset', default='none', choices=['none', 'wiki_distant', 'nyt10', 'nyt10m', 'wiki20m'], 29 | help='Dataset. If not none, the following args can be ignored') 30 | parser.add_argument('--train_file', default='', type=str, 31 | help='Training data file') 32 | parser.add_argument('--val_file', default='', type=str, 33 | help='Validation data file') 34 | parser.add_argument('--test_file', default='', type=str, 35 | help='Test data file') 36 | parser.add_argument('--rel2id_file', default='', type=str, 37 | help='Relation to ID file') 38 | 39 | # Bag related 40 | parser.add_argument('--bag_size', type=int, default=0, 41 | help='Fixed bag size. If set to 0, use original bag sizes') 42 | 43 | # Hyper-parameters 44 | parser.add_argument('--batch_size', default=160, type=int, 45 | help='Batch size') 46 | parser.add_argument('--lr', default=0.1, type=float, 47 | help='Learning rate') 48 | parser.add_argument('--optim', default='sgd', type=str, 49 | help='Optimizer') 50 | parser.add_argument('--weight_decay', default=1e-5, type=float, 51 | help='Weight decay') 52 | parser.add_argument('--max_length', default=128, type=int, 53 | help='Maximum sentence length') 54 | parser.add_argument('--max_epoch', default=100, type=int, 55 | help='Max number of training epochs') 56 | 57 | # Others 58 | parser.add_argument('--seed', default=42, type=int, 59 | help='Random seed') 60 | 61 | # Exp 62 | parser.add_argument('--encoder', default='pcnn', choices=['pcnn', 'cnn']) 63 | parser.add_argument('--aggr', default='att', choices=['one', 'att', 'avg']) 64 | 65 | args = parser.parse_args() 66 | 67 | # Set random seed 68 | set_seed(args.seed) 69 | 70 | # Some basic settings 71 | root_path = '.' 72 | sys.path.append(root_path) 73 | if not os.path.exists('ckpt'): 74 | os.mkdir('ckpt') 75 | if len(args.ckpt) == 0: 76 | args.ckpt = '{}_{}'.format(args.dataset, 'pcnn_att') 77 | ckpt = 'ckpt/{}.pth.tar'.format(args.ckpt) 78 | 79 | if args.dataset != 'none': 80 | opennre.download(args.dataset, root_path=root_path) 81 | args.train_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_train.txt'.format(args.dataset)) 82 | args.val_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_val.txt'.format(args.dataset)) 83 | if not os.path.exists(args.val_file): 84 | logging.info("Cannot find the validation file. Use the test file instead.") 85 | args.val_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_test.txt'.format(args.dataset)) 86 | args.test_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_test.txt'.format(args.dataset)) 87 | args.rel2id_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_rel2id.json'.format(args.dataset)) 88 | else: 89 | if not (os.path.exists(args.train_file) and os.path.exists(args.val_file) and os.path.exists(args.test_file) and os.path.exists(args.rel2id_file)): 90 | raise Exception('--train_file, --val_file, --test_file and --rel2id_file are not specified or files do not exist. 
Or specify --dataset') 91 | 92 | logging.info('Arguments:') 93 | for arg in vars(args): 94 | logging.info(' {}: {}'.format(arg, getattr(args, arg))) 95 | 96 | rel2id = json.load(open(args.rel2id_file)) 97 | 98 | # Download glove 99 | opennre.download('glove', root_path=root_path) 100 | word2id = json.load(open(os.path.join(root_path, 'pretrain/glove/glove.6B.50d_word2id.json'))) 101 | word2vec = np.load(os.path.join(root_path, 'pretrain/glove/glove.6B.50d_mat.npy')) 102 | 103 | # Define the sentence encoder 104 | if args.encoder == 'pcnn': 105 | sentence_encoder = opennre.encoder.PCNNEncoder( 106 | token2id=word2id, 107 | max_length=args.max_length, 108 | word_size=50, 109 | position_size=5, 110 | hidden_size=230, 111 | blank_padding=True, 112 | kernel_size=3, 113 | padding_size=1, 114 | word2vec=word2vec, 115 | dropout=0.5 116 | ) 117 | elif args.encoder == 'cnn': 118 | sentence_encoder = opennre.encoder.CNNEncoder( 119 | token2id=word2id, 120 | max_length=args.max_length, 121 | word_size=50, 122 | position_size=5, 123 | hidden_size=230, 124 | blank_padding=True, 125 | kernel_size=3, 126 | padding_size=1, 127 | word2vec=word2vec, 128 | dropout=0.5 129 | ) 130 | else: 131 | raise NotImplementedError 132 | 133 | 134 | # Define the model 135 | if args.aggr == 'att': 136 | model = opennre.model.BagAttention(sentence_encoder, len(rel2id), rel2id) 137 | elif args.aggr == 'avg': 138 | model = opennre.model.BagAverage(sentence_encoder, len(rel2id), rel2id) 139 | elif args.aggr == 'one': 140 | model = opennre.model.BagOne(sentence_encoder, len(rel2id), rel2id) 141 | else: 142 | raise NotImplementedError 143 | 144 | # Define the whole training framework 145 | framework = opennre.framework.BagRE( 146 | train_path=args.train_file, 147 | val_path=args.val_file, 148 | test_path=args.test_file, 149 | model=model, 150 | ckpt=ckpt, 151 | batch_size=args.batch_size, 152 | max_epoch=args.max_epoch, 153 | lr=args.lr, 154 | weight_decay=args.weight_decay, 155 | opt=args.optim, 156 | bag_size=args.bag_size) 157 | 158 | # Train the model 159 | if not args.only_test: 160 | framework.train_model(args.metric) 161 | 162 | # Test the model 163 | framework.load_state_dict(torch.load(ckpt)['state_dict']) 164 | result = framework.eval_model(framework.test_loader) 165 | 166 | # Print the result 167 | logging.info('Test set results:') 168 | logging.info('AUC: %.5f' % (result['auc'])) 169 | logging.info('Maximum micro F1: %.5f' % (result['max_micro_f1'])) 170 | logging.info('Maximum macro F1: %.5f' % (result['max_macro_f1'])) 171 | logging.info('Micro F1: %.5f' % (result['micro_f1'])) 172 | logging.info('Macro F1: %.5f' % (result['macro_f1'])) 173 | logging.info('P@100: %.5f' % (result['p@100'])) 174 | logging.info('P@200: %.5f' % (result['p@200'])) 175 | logging.info('P@300: %.5f' % (result['p@300'])) 176 | 177 | # Save precision/recall points 178 | np.save('result/{}_p.npy'.format(args.result), result['np_prec']) 179 | np.save('result/{}_r.npy'.format(args.result), result['np_rec']) 180 | json.dump(result['max_micro_f1_each_relation'], open('result/{}_mmicrof1_rel.json'.format(args.result), 'w'), ensure_ascii=False) 181 | -------------------------------------------------------------------------------- /example/train_supervised_bert.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import torch 3 | import numpy as np 4 | import json 5 | import opennre 6 | from opennre import encoder, model, framework 7 | import sys 8 | import os 9 | import argparse 10 | import 
logging 11 | import random 12 | 13 | def set_seed(seed): 14 | random.seed(seed) 15 | np.random.seed(seed) 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--pretrain_path', default='bert-base-uncased', 22 | help='Pre-trained ckpt path / model name (hugginface)') 23 | parser.add_argument('--ckpt', default='', 24 | help='Checkpoint name') 25 | parser.add_argument('--pooler', default='entity', choices=['cls', 'entity'], 26 | help='Sentence representation pooler') 27 | parser.add_argument('--only_test', action='store_true', 28 | help='Only run test') 29 | parser.add_argument('--mask_entity', action='store_true', 30 | help='Mask entity mentions') 31 | 32 | # Data 33 | parser.add_argument('--metric', default='micro_f1', choices=['micro_f1', 'acc'], 34 | help='Metric for picking up best checkpoint') 35 | parser.add_argument('--dataset', default='none', choices=['none', 'semeval', 'wiki80', 'tacred'], 36 | help='Dataset. If not none, the following args can be ignored') 37 | parser.add_argument('--train_file', default='', type=str, 38 | help='Training data file') 39 | parser.add_argument('--val_file', default='', type=str, 40 | help='Validation data file') 41 | parser.add_argument('--test_file', default='', type=str, 42 | help='Test data file') 43 | parser.add_argument('--rel2id_file', default='', type=str, 44 | help='Relation to ID file') 45 | 46 | # Hyper-parameters 47 | parser.add_argument('--batch_size', default=64, type=int, 48 | help='Batch size') 49 | parser.add_argument('--lr', default=2e-5, type=float, 50 | help='Learning rate') 51 | parser.add_argument('--max_length', default=128, type=int, 52 | help='Maximum sentence length') 53 | parser.add_argument('--max_epoch', default=3, type=int, 54 | help='Max number of training epochs') 55 | 56 | # Seed 57 | parser.add_argument('--seed', default=42, type=int, 58 | help='Seed') 59 | 60 | args = parser.parse_args() 61 | 62 | # Set random seed 63 | set_seed(args.seed) 64 | 65 | # Some basic settings 66 | root_path = '.' 67 | sys.path.append(root_path) 68 | if not os.path.exists('ckpt'): 69 | os.mkdir('ckpt') 70 | if len(args.ckpt) == 0: 71 | args.ckpt = '{}_{}_{}'.format(args.dataset, args.pretrain_path, args.pooler) 72 | ckpt = 'ckpt/{}.pth.tar'.format(args.ckpt) 73 | 74 | if args.dataset != 'none': 75 | opennre.download(args.dataset, root_path=root_path) 76 | args.train_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_train.txt'.format(args.dataset)) 77 | args.val_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_val.txt'.format(args.dataset)) 78 | args.test_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_test.txt'.format(args.dataset)) 79 | if not os.path.exists(args.test_file): 80 | logging.warn("Test file {} does not exist! Use val file instead".format(args.test_file)) 81 | args.test_file = args.val_file 82 | args.rel2id_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_rel2id.json'.format(args.dataset)) 83 | if args.dataset == 'wiki80': 84 | args.metric = 'acc' 85 | else: 86 | args.metric = 'micro_f1' 87 | else: 88 | if not (os.path.exists(args.train_file) and os.path.exists(args.val_file) and os.path.exists(args.test_file) and os.path.exists(args.rel2id_file)): 89 | raise Exception('--train_file, --val_file, --test_file and --rel2id_file are not specified or files do not exist. 
Or specify --dataset') 90 | 91 | logging.info('Arguments:') 92 | for arg in vars(args): 93 | logging.info(' {}: {}'.format(arg, getattr(args, arg))) 94 | 95 | rel2id = json.load(open(args.rel2id_file)) 96 | 97 | # Define the sentence encoder 98 | if args.pooler == 'entity': 99 | sentence_encoder = opennre.encoder.BERTEntityEncoder( 100 | max_length=args.max_length, 101 | pretrain_path=args.pretrain_path, 102 | mask_entity=args.mask_entity 103 | ) 104 | elif args.pooler == 'cls': 105 | sentence_encoder = opennre.encoder.BERTEncoder( 106 | max_length=args.max_length, 107 | pretrain_path=args.pretrain_path, 108 | mask_entity=args.mask_entity 109 | ) 110 | else: 111 | raise NotImplementedError 112 | 113 | # Define the model 114 | model = opennre.model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id) 115 | 116 | # Define the whole training framework 117 | framework = opennre.framework.SentenceRE( 118 | train_path=args.train_file, 119 | val_path=args.val_file, 120 | test_path=args.test_file, 121 | model=model, 122 | ckpt=ckpt, 123 | batch_size=args.batch_size, 124 | max_epoch=args.max_epoch, 125 | lr=args.lr, 126 | opt='adamw' 127 | ) 128 | 129 | # Train the model 130 | if not args.only_test: 131 | framework.train_model('micro_f1') 132 | 133 | # Test 134 | framework.load_state_dict(torch.load(ckpt)['state_dict']) 135 | result = framework.eval_model(framework.test_loader) 136 | 137 | # Print the result 138 | logging.info('Test set results:') 139 | logging.info('Accuracy: {}'.format(result['acc'])) 140 | logging.info('Micro precision: {}'.format(result['micro_p'])) 141 | logging.info('Micro recall: {}'.format(result['micro_r'])) 142 | logging.info('Micro F1: {}'.format(result['micro_f1'])) 143 | -------------------------------------------------------------------------------- /example/train_supervised_cnn.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import torch 3 | import numpy as np 4 | import json 5 | import opennre 6 | from opennre import encoder, model, framework 7 | import sys 8 | import os 9 | import argparse 10 | import logging 11 | import random 12 | 13 | def set_seed(seed): 14 | random.seed(seed) 15 | np.random.seed(seed) 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--ckpt', default='', 21 | help='Checkpoint name') 22 | parser.add_argument('--only_test', action='store_true', 23 | help='Only run test') 24 | parser.add_argument('--encoder', default='pcnn', choices=['pcnn', 'cnn']) 25 | 26 | # Data 27 | parser.add_argument('--metric', default='micro_f1', choices=['micro_f1', 'acc'], 28 | help='Metric for picking up best checkpoint') 29 | parser.add_argument('--dataset', default='none', choices=['none', 'semeval', 'wiki80', 'tacred'], 30 | help='Dataset. 
If not none, the following args can be ignored') 31 | parser.add_argument('--train_file', default='', type=str, 32 | help='Training data file') 33 | parser.add_argument('--val_file', default='', type=str, 34 | help='Validation data file') 35 | parser.add_argument('--test_file', default='', type=str, 36 | help='Test data file') 37 | parser.add_argument('--rel2id_file', default='', type=str, 38 | help='Relation to ID file') 39 | 40 | # Hyper-parameters 41 | parser.add_argument('--batch_size', default=160, type=int, 42 | help='Batch size') 43 | parser.add_argument('--lr', default=1e-1, type=float, 44 | help='Learning rate') 45 | parser.add_argument('--weight_decay', default=1e-5, type=float, 46 | help='Weight decay') 47 | parser.add_argument('--max_length', default=128, type=int, 48 | help='Maximum sentence length') 49 | parser.add_argument('--max_epoch', default=100, type=int, 50 | help='Max number of training epochs') 51 | 52 | # Others 53 | parser.add_argument('--seed', default=42, type=int, 54 | help='Random seed') 55 | 56 | args = parser.parse_args() 57 | 58 | # Set random seed 59 | set_seed(args.seed) 60 | 61 | # Some basic settings 62 | root_path = '.' 63 | sys.path.append(root_path) 64 | if not os.path.exists('ckpt'): 65 | os.mkdir('ckpt') 66 | if len(args.ckpt) == 0: 67 | args.ckpt = '{}_{}'.format(args.dataset, 'cnn') 68 | ckpt = 'ckpt/{}.pth.tar'.format(args.ckpt) 69 | 70 | if args.dataset != 'none': 71 | opennre.download(args.dataset, root_path=root_path) 72 | args.train_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_train.txt'.format(args.dataset)) 73 | args.val_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_val.txt'.format(args.dataset)) 74 | args.test_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_test.txt'.format(args.dataset)) 75 | args.rel2id_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_rel2id.json'.format(args.dataset)) 76 | if args.dataset == 'wiki80': 77 | args.metric = 'acc' 78 | else: 79 | args.metric = 'micro_f1' 80 | else: 81 | if not (os.path.exists(args.train_file) and os.path.exists(args.val_file) and os.path.exists(args.test_file) and os.path.exists(args.rel2id_file)): 82 | raise Exception('--train_file, --val_file, --test_file and --rel2id_file are not specified or files do not exist. 
Or specify --dataset') 83 | 84 | logging.info('Arguments:') 85 | for arg in vars(args): 86 | logging.info(' {}: {}'.format(arg, getattr(args, arg))) 87 | 88 | rel2id = json.load(open(args.rel2id_file)) 89 | 90 | # Download glove 91 | opennre.download('glove', root_path=root_path) 92 | word2id = json.load(open(os.path.join(root_path, 'pretrain/glove/glove.6B.50d_word2id.json'))) 93 | word2vec = np.load(os.path.join(root_path, 'pretrain/glove/glove.6B.50d_mat.npy')) 94 | 95 | # Define the sentence encoder 96 | if args.encoder == 'pcnn': 97 | sentence_encoder = opennre.encoder.PCNNEncoder( 98 | token2id=word2id, 99 | max_length=args.max_length, 100 | word_size=50, 101 | position_size=5, 102 | hidden_size=230, 103 | blank_padding=True, 104 | kernel_size=3, 105 | padding_size=1, 106 | word2vec=word2vec, 107 | dropout=0.5 108 | ) 109 | elif args.encoder == 'cnn': 110 | sentence_encoder = opennre.encoder.CNNEncoder( 111 | token2id=word2id, 112 | max_length=args.max_length, 113 | word_size=50, 114 | position_size=5, 115 | hidden_size=230, 116 | blank_padding=True, 117 | kernel_size=3, 118 | padding_size=1, 119 | word2vec=word2vec, 120 | dropout=0.5 121 | ) 122 | else: 123 | raise NotImplementedError 124 | 125 | # Define the model 126 | model = opennre.model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id) 127 | 128 | # Define the whole training framework 129 | framework = opennre.framework.SentenceRE( 130 | train_path=args.train_file, 131 | val_path=args.val_file, 132 | test_path=args.test_file, 133 | model=model, 134 | ckpt=ckpt, 135 | batch_size=args.batch_size, 136 | max_epoch=args.max_epoch, 137 | lr=args.lr, 138 | weight_decay=args.weight_decay, 139 | opt='sgd' 140 | ) 141 | 142 | # Train the model 143 | if not args.only_test: 144 | framework.train_model(args.metric) 145 | 146 | # Test 147 | framework.load_state_dict(torch.load(ckpt)['state_dict']) 148 | result = framework.eval_model(framework.test_loader) 149 | 150 | # Print the result 151 | logging.info('Test set results:') 152 | if args.metric == 'acc': 153 | logging.info('Accuracy: {}'.format(result['acc'])) 154 | else: 155 | logging.info('Micro precision: {}'.format(result['micro_p'])) 156 | logging.info('Micro recall: {}'.format(result['micro_r'])) 157 | logging.info('Micro F1: {}'.format(result['micro_f1'])) 158 | -------------------------------------------------------------------------------- /opennre/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .pretrain import check_root, get_model, download, download_pretrain 6 | import logging 7 | import os 8 | 9 | logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=os.environ.get("LOGLEVEL", "INFO")) 10 | 11 | def fix_seed(seed=12345): 12 | import torch 13 | import numpy as np 14 | import random 15 | torch.manual_seed(seed) # cpu 16 | torch.cuda.manual_seed(seed) # gpu 17 | np.random.seed(seed) # numpy 18 | random.seed(seed) # random and transforms 19 | torch.backends.cudnn.deterministic=True # cudnn 20 | -------------------------------------------------------------------------------- /opennre/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .cnn_encoder import CNNEncoder 6 | from .pcnn_encoder 
import PCNNEncoder 7 | from .bert_encoder import BERTEncoder, BERTEntityEncoder 8 | 9 | __all__ = [ 10 | 'CNNEncoder', 11 | 'PCNNEncoder', 12 | 'BERTEncoder', 13 | 'BERTEntityEncoder' 14 | ] -------------------------------------------------------------------------------- /opennre/encoder/base_encoder.py: -------------------------------------------------------------------------------- 1 | import math, logging 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from ..tokenization import WordTokenizer 7 | 8 | class BaseEncoder(nn.Module): 9 | 10 | def __init__(self, 11 | token2id, 12 | max_length=128, 13 | hidden_size=230, 14 | word_size=50, 15 | position_size=5, 16 | blank_padding=True, 17 | word2vec=None, 18 | mask_entity=False): 19 | """ 20 | Args: 21 | token2id: dictionary of token->idx mapping 22 | max_length: max length of sentence, used for postion embedding 23 | hidden_size: hidden size 24 | word_size: size of word embedding 25 | position_size: size of position embedding 26 | blank_padding: padding for CNN 27 | word2vec: pretrained word2vec numpy 28 | """ 29 | # Hyperparameters 30 | super().__init__() 31 | 32 | self.token2id = token2id 33 | self.max_length = max_length 34 | self.num_token = len(token2id) 35 | self.num_position = max_length * 2 36 | self.mask_entity = mask_entity 37 | 38 | if word2vec is None: 39 | self.word_size = word_size 40 | else: 41 | self.word_size = word2vec.shape[-1] 42 | 43 | self.position_size = position_size 44 | self.hidden_size = hidden_size 45 | self.input_size = word_size + position_size * 2 46 | self.blank_padding = blank_padding 47 | 48 | if not '[UNK]' in self.token2id: 49 | self.token2id['[UNK]'] = len(self.token2id) 50 | self.num_token += 1 51 | if not '[PAD]' in self.token2id: 52 | self.token2id['[PAD]'] = len(self.token2id) 53 | self.num_token += 1 54 | 55 | # Word embedding 56 | self.word_embedding = nn.Embedding(self.num_token, self.word_size) 57 | if word2vec is not None: 58 | logging.info("Initializing word embedding with word2vec.") 59 | word2vec = torch.from_numpy(word2vec) 60 | if self.num_token == len(word2vec) + 2: 61 | unk = torch.randn(1, self.word_size) / math.sqrt(self.word_size) 62 | blk = torch.zeros(1, self.word_size) 63 | self.word_embedding.weight.data.copy_(torch.cat([word2vec, unk, blk], 0)) 64 | else: 65 | self.word_embedding.weight.data.copy_(word2vec) 66 | 67 | # Position Embedding 68 | self.pos1_embedding = nn.Embedding(2 * max_length, self.position_size, padding_idx=0) 69 | self.pos2_embedding = nn.Embedding(2 * max_length, self.position_size, padding_idx=0) 70 | self.tokenizer = WordTokenizer(vocab=self.token2id, unk_token="[UNK]") 71 | 72 | def forward(self, token, pos1, pos2): 73 | """ 74 | Args: 75 | token: (B, L), index of tokens 76 | pos1: (B, L), relative position to head entity 77 | pos2: (B, L), relative position to tail entity 78 | Return: 79 | (B, H), representations for sentences 80 | """ 81 | # Check size of tensors 82 | pass 83 | 84 | def tokenize(self, item): 85 | """ 86 | Args: 87 | item: input instance, including sentence, entity positions, etc. 
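# A minimal sketch of the expected `item` (hypothetical values; only the fields this
# method reads are shown). With a raw 'text' string the entity positions are character
# offsets, e.g.:
#     {'text': 'Bill Gates founded Microsoft.',
#      'h': {'pos': [0, 10]}, 't': {'pos': [19, 28]}}
# With a pre-tokenized 'token' list the positions are token indices instead.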
88 | Return: 89 | index number of tokens and positions 90 | """ 91 | if 'text' in item: 92 | sentence = item['text'] 93 | is_token = False 94 | else: 95 | sentence = item['token'] 96 | is_token = True 97 | pos_head = item['h']['pos'] 98 | pos_tail = item['t']['pos'] 99 | 100 | # Sentence -> token 101 | if not is_token: 102 | if pos_head[0] > pos_tail[0]: 103 | pos_min, pos_max = [pos_tail, pos_head] 104 | rev = True 105 | else: 106 | pos_min, pos_max = [pos_head, pos_tail] 107 | rev = False 108 | sent_0 = self.tokenizer.tokenize(sentence[:pos_min[0]]) 109 | sent_1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]]) 110 | sent_2 = self.tokenizer.tokenize(sentence[pos_max[1]:]) 111 | ent_0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]]) 112 | ent_1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]]) 113 | if self.mask_entity: 114 | ent_0 = ['[UNK]'] 115 | ent_1 = ['[UNK]'] 116 | tokens = sent_0 + ent_0 + sent_1 + ent_1 + sent_2 117 | if rev: 118 | pos_tail = [len(sent_0), len(sent_0) + len(ent_0)] 119 | pos_head = [len(sent_0) + len(ent_0) + len(sent_1), len(sent_0) + len(ent_0) + len(sent_1) + len(ent_1)] 120 | else: 121 | pos_head = [len(sent_0), len(sent_0) + len(ent_0)] 122 | pos_tail = [len(sent_0) + len(ent_0) + len(sent_1), len(sent_0) + len(ent_0) + len(sent_1) + len(ent_1)] 123 | else: 124 | tokens = sentence 125 | 126 | # Token -> index 127 | if self.blank_padding: 128 | indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens, self.max_length, self.token2id['[PAD]'], self.token2id['[UNK]']) 129 | else: 130 | indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens, unk_id = self.token2id['[UNK]']) 131 | 132 | # Position -> index 133 | pos1 = [] 134 | pos2 = [] 135 | pos1_in_index = min(pos_head[0], self.max_length) 136 | pos2_in_index = min(pos_tail[0], self.max_length) 137 | for i in range(len(tokens)): 138 | pos1.append(min(i - pos1_in_index + self.max_length, 2 * self.max_length - 1)) 139 | pos2.append(min(i - pos2_in_index + self.max_length, 2 * self.max_length - 1)) 140 | 141 | if self.blank_padding: 142 | while len(pos1) < self.max_length: 143 | pos1.append(0) 144 | while len(pos2) < self.max_length: 145 | pos2.append(0) 146 | indexed_tokens = indexed_tokens[:self.max_length] 147 | pos1 = pos1[:self.max_length] 148 | pos2 = pos2[:self.max_length] 149 | 150 | indexed_tokens = torch.tensor(indexed_tokens).long().unsqueeze(0) # (1, L) 151 | pos1 = torch.tensor(pos1).long().unsqueeze(0) # (1, L) 152 | pos2 = torch.tensor(pos2).long().unsqueeze(0) # (1, L) 153 | 154 | return indexed_tokens, pos1, pos2 155 | -------------------------------------------------------------------------------- /opennre/encoder/bert_encoder.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import torch.nn as nn 4 | from transformers import BertModel, BertTokenizer 5 | from .base_encoder import BaseEncoder 6 | 7 | class BERTEncoder(nn.Module): 8 | def __init__(self, max_length, pretrain_path, blank_padding=True, mask_entity=False): 9 | """ 10 | Args: 11 | max_length: max length of sentence 12 | pretrain_path: path of pretrain model 13 | """ 14 | super().__init__() 15 | self.max_length = max_length 16 | self.blank_padding = blank_padding 17 | self.hidden_size = 768 18 | self.mask_entity = mask_entity 19 | logging.info('Loading BERT pre-trained checkpoint.') 20 | self.bert = BertModel.from_pretrained(pretrain_path) 21 | self.tokenizer = BertTokenizer.from_pretrained(pretrain_path) 22 | 23 | def 
forward(self, token, att_mask, pos1, pos2): 24 | """ 25 | Args: 26 | token: (B, L), index of tokens 27 | att_mask: (B, L), attention mask (1 for contents and 0 for padding) 28 | Return: 29 | (B, H), representations for sentences 30 | """ 31 | _, x = self.bert(token, attention_mask=att_mask, return_dict=False) 32 | #return_dict=fault is set to adapt to the new version of transformers 33 | return x 34 | 35 | def tokenize(self, item): 36 | """ 37 | Args: 38 | item: data instance containing 'text' / 'token', 'h' and 't' 39 | Return: 40 | Name of the relation of the sentence 41 | """ 42 | # Sentence -> token 43 | if 'text' in item: 44 | sentence = item['text'] 45 | is_token = False 46 | else: 47 | sentence = item['token'] 48 | is_token = True 49 | pos_head = item['h']['pos'] 50 | pos_tail = item['t']['pos'] 51 | 52 | pos_min = pos_head 53 | pos_max = pos_tail 54 | if pos_head[0] > pos_tail[0]: 55 | pos_min = pos_tail 56 | pos_max = pos_head 57 | rev = True 58 | else: 59 | rev = False 60 | 61 | if not is_token: 62 | sent0 = self.tokenizer.tokenize(sentence[:pos_min[0]]) 63 | ent0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]]) 64 | sent1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]]) 65 | ent1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]]) 66 | sent2 = self.tokenizer.tokenize(sentence[pos_max[1]:]) 67 | else: 68 | sent0 = self.tokenizer.tokenize(' '.join(sentence[:pos_min[0]])) 69 | ent0 = self.tokenizer.tokenize(' '.join(sentence[pos_min[0]:pos_min[1]])) 70 | sent1 = self.tokenizer.tokenize(' '.join(sentence[pos_min[1]:pos_max[0]])) 71 | ent1 = self.tokenizer.tokenize(' '.join(sentence[pos_max[0]:pos_max[1]])) 72 | sent2 = self.tokenizer.tokenize(' '.join(sentence[pos_max[1]:])) 73 | 74 | if self.mask_entity: 75 | ent0 = ['[unused4]'] if not rev else ['[unused5]'] 76 | ent1 = ['[unused5]'] if not rev else ['[unused4]'] 77 | else: 78 | ent0 = ['[unused0]'] + ent0 + ['[unused1]'] if not rev else ['[unused2]'] + ent0 + ['[unused3]'] 79 | ent1 = ['[unused2]'] + ent1 + ['[unused3]'] if not rev else ['[unused0]'] + ent1 + ['[unused1]'] 80 | 81 | re_tokens = ['[CLS]'] + sent0 + ent0 + sent1 + ent1 + sent2 + ['[SEP]'] 82 | 83 | pos1 = 1 + len(sent0) if not rev else 1 + len(sent0 + ent0 + sent1) 84 | pos2 = 1 + len(sent0 + ent0 + sent1) if not rev else 1 + len(sent0) 85 | pos1 = min(self.max_length - 1, pos1) 86 | pos2 = min(self.max_length - 1, pos2) 87 | indexed_tokens = self.tokenizer.convert_tokens_to_ids(re_tokens) 88 | avai_len = len(indexed_tokens) 89 | pos1 = torch.tensor([[pos1]]).long() 90 | pos2 = torch.tensor([[pos2]]).long() 91 | 92 | # Padding 93 | if self.blank_padding: 94 | while len(indexed_tokens) < self.max_length: 95 | indexed_tokens.append(0) # 0 is id for [PAD] 96 | indexed_tokens = indexed_tokens[:self.max_length] 97 | indexed_tokens = torch.tensor(indexed_tokens).long().unsqueeze(0) # (1, L) 98 | 99 | # Attention mask 100 | att_mask = torch.zeros(indexed_tokens.size()).long() # (1, L) 101 | att_mask[0, :avai_len] = 1 102 | 103 | return indexed_tokens, att_mask, pos1, pos2 104 | 105 | 106 | class BERTEntityEncoder(nn.Module): 107 | def __init__(self, max_length, pretrain_path, blank_padding=True, mask_entity=False): 108 | """ 109 | Args: 110 | max_length: max length of sentence 111 | pretrain_path: path of pretrain model 112 | """ 113 | super().__init__() 114 | self.max_length = max_length 115 | self.blank_padding = blank_padding 116 | self.hidden_size = 768 * 2 117 | self.mask_entity = mask_entity 118 | logging.info('Loading BERT pre-trained 
checkpoint.') 119 | self.bert = BertModel.from_pretrained(pretrain_path) 120 | self.tokenizer = BertTokenizer.from_pretrained(pretrain_path) 121 | self.linear = nn.Linear(self.hidden_size, self.hidden_size) 122 | 123 | def forward(self, token, att_mask, pos1, pos2): 124 | """ 125 | Args: 126 | token: (B, L), index of tokens 127 | att_mask: (B, L), attention mask (1 for contents and 0 for padding) 128 | pos1: (B, 1), position of the head entity starter 129 | pos2: (B, 1), position of the tail entity starter 130 | Return: 131 | (B, 2H), representations for sentences 132 | """ 133 | hidden, _ = self.bert(token, attention_mask=att_mask, return_dict=False) 134 | # Get entity start hidden state 135 | onehot_head = torch.zeros(hidden.size()[:2]).float().to(hidden.device) # (B, L) 136 | onehot_tail = torch.zeros(hidden.size()[:2]).float().to(hidden.device) # (B, L) 137 | onehot_head = onehot_head.scatter_(1, pos1, 1) 138 | onehot_tail = onehot_tail.scatter_(1, pos2, 1) 139 | head_hidden = (onehot_head.unsqueeze(2) * hidden).sum(1) # (B, H) 140 | tail_hidden = (onehot_tail.unsqueeze(2) * hidden).sum(1) # (B, H) 141 | x = torch.cat([head_hidden, tail_hidden], 1) # (B, 2H) 142 | x = self.linear(x) 143 | return x 144 | 145 | def tokenize(self, item): 146 | """ 147 | Args: 148 | item: data instance containing 'text' / 'token', 'h' and 't' 149 | Return: 150 | Name of the relation of the sentence 151 | """ 152 | # Sentence -> token 153 | if 'text' in item: 154 | sentence = item['text'] 155 | is_token = False 156 | else: 157 | sentence = item['token'] 158 | is_token = True 159 | pos_head = item['h']['pos'] 160 | pos_tail = item['t']['pos'] 161 | 162 | pos_min = pos_head 163 | pos_max = pos_tail 164 | if pos_head[0] > pos_tail[0]: 165 | pos_min = pos_tail 166 | pos_max = pos_head 167 | rev = True 168 | else: 169 | rev = False 170 | 171 | if not is_token: 172 | sent0 = self.tokenizer.tokenize(sentence[:pos_min[0]]) 173 | ent0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]]) 174 | sent1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]]) 175 | ent1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]]) 176 | sent2 = self.tokenizer.tokenize(sentence[pos_max[1]:]) 177 | else: 178 | sent0 = self.tokenizer.tokenize(' '.join(sentence[:pos_min[0]])) 179 | ent0 = self.tokenizer.tokenize(' '.join(sentence[pos_min[0]:pos_min[1]])) 180 | sent1 = self.tokenizer.tokenize(' '.join(sentence[pos_min[1]:pos_max[0]])) 181 | ent1 = self.tokenizer.tokenize(' '.join(sentence[pos_max[0]:pos_max[1]])) 182 | sent2 = self.tokenizer.tokenize(' '.join(sentence[pos_max[1]:])) 183 | 184 | if self.mask_entity: 185 | ent0 = ['[unused4]'] if not rev else ['[unused5]'] 186 | ent1 = ['[unused5]'] if not rev else ['[unused4]'] 187 | else: 188 | ent0 = ['[unused0]'] + ent0 + ['[unused1]'] if not rev else ['[unused2]'] + ent0 + ['[unused3]'] 189 | ent1 = ['[unused2]'] + ent1 + ['[unused3]'] if not rev else ['[unused0]'] + ent1 + ['[unused1]'] 190 | 191 | re_tokens = ['[CLS]'] + sent0 + ent0 + sent1 + ent1 + sent2 + ['[SEP]'] 192 | pos1 = 1 + len(sent0) if not rev else 1 + len(sent0 + ent0 + sent1) 193 | pos2 = 1 + len(sent0 + ent0 + sent1) if not rev else 1 + len(sent0) 194 | pos1 = min(self.max_length - 1, pos1) 195 | pos2 = min(self.max_length - 1, pos2) 196 | 197 | indexed_tokens = self.tokenizer.convert_tokens_to_ids(re_tokens) 198 | avai_len = len(indexed_tokens) 199 | 200 | # Position 201 | pos1 = torch.tensor([[pos1]]).long() 202 | pos2 = torch.tensor([[pos2]]).long() 203 | 204 | # Padding 205 | if 
self.blank_padding: 206 | while len(indexed_tokens) < self.max_length: 207 | indexed_tokens.append(0) # 0 is id for [PAD] 208 | indexed_tokens = indexed_tokens[:self.max_length] 209 | indexed_tokens = torch.tensor(indexed_tokens).long().unsqueeze(0) # (1, L) 210 | 211 | # Attention mask 212 | att_mask = torch.zeros(indexed_tokens.size()).long() # (1, L) 213 | att_mask[0, :avai_len] = 1 214 | 215 | return indexed_tokens, att_mask, pos1, pos2 216 | -------------------------------------------------------------------------------- /opennre/encoder/cnn_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ..module.nn import CNN 6 | from ..module.pool import MaxPool 7 | from .base_encoder import BaseEncoder 8 | 9 | class CNNEncoder(BaseEncoder): 10 | 11 | def __init__(self, 12 | token2id, 13 | max_length=128, 14 | hidden_size=230, 15 | word_size=50, 16 | position_size=5, 17 | blank_padding=True, 18 | word2vec=None, 19 | kernel_size=3, 20 | padding_size=1, 21 | dropout=0, 22 | activation_function=F.relu, 23 | mask_entity=False): 24 | """ 25 | Args: 26 | token2id: dictionary of token->idx mapping 27 | max_length: max length of sentence, used for postion embedding 28 | hidden_size: hidden size 29 | word_size: size of word embedding 30 | position_size: size of position embedding 31 | blank_padding: padding for CNN 32 | word2vec: pretrained word2vec numpy 33 | kernel_size: kernel_size size for CNN 34 | padding_size: padding_size for CNN 35 | """ 36 | # Hyperparameters 37 | super(CNNEncoder, self).__init__(token2id, max_length, hidden_size, word_size, position_size, blank_padding, word2vec, mask_entity=mask_entity) 38 | self.drop = nn.Dropout(dropout) 39 | self.kernel_size = kernel_size 40 | self.padding_size = padding_size 41 | self.act = activation_function 42 | 43 | self.conv = nn.Conv1d(self.input_size, self.hidden_size, self.kernel_size, padding=self.padding_size) 44 | self.pool = nn.MaxPool1d(self.max_length) 45 | 46 | def forward(self, token, pos1, pos2): 47 | """ 48 | Args: 49 | token: (B, L), index of tokens 50 | pos1: (B, L), relative position to head entity 51 | pos2: (B, L), relative position to tail entity 52 | Return: 53 | (B, EMBED), representations for sentences 54 | """ 55 | # Check size of tensors 56 | if len(token.size()) != 2 or token.size() != pos1.size() or token.size() != pos2.size(): 57 | raise Exception("Size of token, pos1 ans pos2 should be (B, L)") 58 | x = torch.cat([self.word_embedding(token), 59 | self.pos1_embedding(pos1), 60 | self.pos2_embedding(pos2)], 2) # (B, L, EMBED) 61 | x = x.transpose(1, 2) # (B, EMBED, L) 62 | x = self.act(self.conv(x)) # (B, H, L) 63 | x = self.pool(x).squeeze(-1) 64 | x = self.drop(x) 65 | return x 66 | 67 | def tokenize(self, item): 68 | return super().tokenize(item) 69 | -------------------------------------------------------------------------------- /opennre/encoder/pcnn_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ..module.nn import CNN 6 | from ..module.pool import MaxPool 7 | from .base_encoder import BaseEncoder 8 | 9 | from nltk import word_tokenize 10 | 11 | class PCNNEncoder(BaseEncoder): 12 | 13 | def __init__(self, 14 | token2id, 15 | max_length=128, 16 | hidden_size=230, 17 | word_size=50, 18 | position_size=5, 19 | blank_padding=True, 20 | word2vec=None, 21 | 
kernel_size=3, 22 | padding_size=1, 23 | dropout=0.0, 24 | activation_function=F.relu, 25 | mask_entity=False): 26 | """ 27 | Args: 28 | token2id: dictionary of token->idx mapping 29 | max_length: max length of sentence, used for postion embedding 30 | hidden_size: hidden size 31 | word_size: size of word embedding 32 | position_size: size of position embedding 33 | blank_padding: padding for CNN 34 | word2vec: pretrained word2vec numpy 35 | kernel_size: kernel_size size for CNN 36 | padding_size: padding_size for CNN 37 | """ 38 | # hyperparameters 39 | super().__init__(token2id, max_length, hidden_size, word_size, position_size, blank_padding, word2vec, mask_entity=mask_entity) 40 | self.drop = nn.Dropout(dropout) 41 | self.kernel_size = kernel_size 42 | self.padding_size = padding_size 43 | self.act = activation_function 44 | 45 | self.conv = nn.Conv1d(self.input_size, self.hidden_size, self.kernel_size, padding=self.padding_size) 46 | self.pool = nn.MaxPool1d(self.max_length) 47 | self.mask_embedding = nn.Embedding(4, 3) 48 | self.mask_embedding.weight.data.copy_(torch.FloatTensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])) 49 | self.mask_embedding.weight.requires_grad = False 50 | self._minus = -100 51 | 52 | self.hidden_size *= 3 53 | 54 | def forward(self, token, pos1, pos2, mask): 55 | """ 56 | Args: 57 | token: (B, L), index of tokens 58 | pos1: (B, L), relative position to head entity 59 | pos2: (B, L), relative position to tail entity 60 | Return: 61 | (B, EMBED), representations for sentences 62 | """ 63 | # Check size of tensors 64 | if len(token.size()) != 2 or token.size() != pos1.size() or token.size() != pos2.size(): 65 | raise Exception("Size of token, pos1 ans pos2 should be (B, L)") 66 | x = torch.cat([self.word_embedding(token), 67 | self.pos1_embedding(pos1), 68 | self.pos2_embedding(pos2)], 2) # (B, L, EMBED) 69 | x = x.transpose(1, 2) # (B, EMBED, L) 70 | x = self.conv(x) # (B, H, L) 71 | 72 | mask = 1 - self.mask_embedding(mask).transpose(1, 2) # (B, L) -> (B, L, 3) -> (B, 3, L) 73 | pool1 = self.pool(self.act(x + self._minus * mask[:, 0:1, :])) # (B, H, 1) 74 | pool2 = self.pool(self.act(x + self._minus * mask[:, 1:2, :])) 75 | pool3 = self.pool(self.act(x + self._minus * mask[:, 2:3, :])) 76 | x = torch.cat([pool1, pool2, pool3], 1) # (B, 3H, 1) 77 | x = x.squeeze(2) # (B, 3H) 78 | x = self.drop(x) 79 | 80 | return x 81 | 82 | def tokenize(self, item): 83 | """ 84 | Args: 85 | sentence: string, the input sentence 86 | pos_head: [start, end], position of the head entity 87 | pos_end: [start, end], position of the tail entity 88 | is_token: if is_token == True, sentence becomes an array of token 89 | Return: 90 | Name of the relation of the sentence 91 | """ 92 | if 'text' in item: 93 | sentence = item['text'] 94 | is_token = False 95 | else: 96 | sentence = item['token'] 97 | is_token = True 98 | pos_head = item['h']['pos'] 99 | pos_tail = item['t']['pos'] 100 | 101 | # Sentence -> token 102 | if not is_token: 103 | if pos_head[0] > pos_tail[0]: 104 | pos_min, pos_max = [pos_tail, pos_head] 105 | rev = True 106 | else: 107 | pos_min, pos_max = [pos_head, pos_tail] 108 | rev = False 109 | sent_0 = self.tokenizer.tokenize(sentence[:pos_min[0]]) 110 | sent_1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]]) 111 | sent_2 = self.tokenizer.tokenize(sentence[pos_max[1]:]) 112 | ent_0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]]) 113 | ent_1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]]) 114 | if self.mask_entity: 115 | ent_0 = 
['[UNK]'] 116 | ent_1 = ['[UNK]'] 117 | tokens = sent_0 + ent_0 + sent_1 + ent_1 + sent_2 118 | if rev: 119 | pos_tail = [len(sent_0), len(sent_0) + len(ent_0)] 120 | pos_head = [len(sent_0) + len(ent_0) + len(sent_1), len(sent_0) + len(ent_0) + len(sent_1) + len(ent_1)] 121 | else: 122 | pos_head = [len(sent_0), len(sent_0) + len(ent_0)] 123 | pos_tail = [len(sent_0) + len(ent_0) + len(sent_1), len(sent_0) + len(ent_0) + len(sent_1) + len(ent_1)] 124 | else: 125 | tokens = sentence 126 | 127 | # Token -> index 128 | if self.blank_padding: 129 | indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens, self.max_length, self.token2id['[PAD]'], self.token2id['[UNK]']) 130 | else: 131 | indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens, unk_id = self.token2id['[UNK]']) 132 | 133 | # Position -> index 134 | pos1 = [] 135 | pos2 = [] 136 | pos1_in_index = min(pos_head[0], self.max_length) 137 | pos2_in_index = min(pos_tail[0], self.max_length) 138 | for i in range(len(tokens)): 139 | pos1.append(min(i - pos1_in_index + self.max_length, 2 * self.max_length - 1)) 140 | pos2.append(min(i - pos2_in_index + self.max_length, 2 * self.max_length - 1)) 141 | 142 | if self.blank_padding: 143 | while len(pos1) < self.max_length: 144 | pos1.append(0) 145 | while len(pos2) < self.max_length: 146 | pos2.append(0) 147 | indexed_tokens = indexed_tokens[:self.max_length] 148 | pos1 = pos1[:self.max_length] 149 | pos2 = pos2[:self.max_length] 150 | 151 | indexed_tokens = torch.tensor(indexed_tokens).long().unsqueeze(0) # (1, L) 152 | pos1 = torch.tensor(pos1).long().unsqueeze(0) # (1, L) 153 | pos2 = torch.tensor(pos2).long().unsqueeze(0) # (1, L) 154 | 155 | # Mask 156 | mask = [] 157 | pos_min = min(pos1_in_index, pos2_in_index) 158 | pos_max = max(pos1_in_index, pos2_in_index) 159 | for i in range(len(tokens)): 160 | if i <= pos_min: 161 | mask.append(1) 162 | elif i <= pos_max: 163 | mask.append(2) 164 | else: 165 | mask.append(3) 166 | # Padding 167 | if self.blank_padding: 168 | while len(mask) < self.max_length: 169 | mask.append(0) 170 | mask = mask[:self.max_length] 171 | 172 | mask = torch.tensor(mask).long().unsqueeze(0) # (1, L) 173 | return indexed_tokens, pos1, pos2, mask 174 | -------------------------------------------------------------------------------- /opennre/framework/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .data_loader import SentenceREDataset, SentenceRELoader, BagREDataset, BagRELoader, MultiLabelSentenceREDataset, MultiLabelSentenceRELoader 6 | from .sentence_re import SentenceRE 7 | from .bag_re import BagRE 8 | from .multi_label_sentence_re import MultiLabelSentenceRE 9 | 10 | __all__ = [ 11 | 'SentenceREDataset', 12 | 'SentenceRELoader', 13 | 'SentenceRE', 14 | 'BagRE', 15 | 'BagREDataset', 16 | 'BagRELoader', 17 | 'MultiLabelSentenceREDataset', 18 | 'MultiLabelSentenceRELoader', 19 | 'MultiLabelSentenceRE' 20 | ] 21 | -------------------------------------------------------------------------------- /opennre/framework/bag_re.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, optim 3 | import json 4 | from .data_loader import SentenceRELoader, BagRELoader 5 | from .utils import AverageMeter 6 | from tqdm import tqdm 7 | import os 8 | 9 | class BagRE(nn.Module): 10 | 11 | def __init__(self, 12 | model, 13 | 
train_path, 14 | val_path, 15 | test_path, 16 | ckpt, 17 | batch_size=32, 18 | max_epoch=100, 19 | lr=0.1, 20 | weight_decay=1e-5, 21 | opt='sgd', 22 | bag_size=0, 23 | loss_weight=False): 24 | 25 | super().__init__() 26 | self.max_epoch = max_epoch 27 | self.bag_size = bag_size 28 | # Load data 29 | if train_path != None: 30 | self.train_loader = BagRELoader( 31 | train_path, 32 | model.rel2id, 33 | model.sentence_encoder.tokenize, 34 | batch_size, 35 | True, 36 | bag_size=bag_size, 37 | entpair_as_bag=False) 38 | 39 | if val_path != None: 40 | self.val_loader = BagRELoader( 41 | val_path, 42 | model.rel2id, 43 | model.sentence_encoder.tokenize, 44 | batch_size, 45 | False, 46 | bag_size=bag_size, 47 | entpair_as_bag=True) 48 | 49 | if test_path != None: 50 | self.test_loader = BagRELoader( 51 | test_path, 52 | model.rel2id, 53 | model.sentence_encoder.tokenize, 54 | batch_size, 55 | False, 56 | bag_size=bag_size, 57 | entpair_as_bag=True 58 | ) 59 | # Model 60 | self.model = nn.DataParallel(model) 61 | # Criterion 62 | if loss_weight: 63 | self.criterion = nn.CrossEntropyLoss(weight=self.train_loader.dataset.weight) 64 | else: 65 | self.criterion = nn.CrossEntropyLoss() 66 | # Params and optimizer 67 | params = self.model.parameters() 68 | self.lr = lr 69 | if opt == 'sgd': 70 | self.optimizer = optim.SGD(params, lr, weight_decay=weight_decay) 71 | elif opt == 'adam': 72 | self.optimizer = optim.Adam(params, lr, weight_decay=weight_decay) 73 | elif opt == 'adamw': 74 | from transformers import AdamW 75 | params = list(self.named_parameters()) 76 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 77 | grouped_params = [ 78 | { 79 | 'params': [p for n, p in params if not any(nd in n for nd in no_decay)], 80 | 'weight_decay': 0.01, 81 | 'lr': lr, 82 | 'ori_lr': lr 83 | }, 84 | { 85 | 'params': [p for n, p in params if any(nd in n for nd in no_decay)], 86 | 'weight_decay': 0.0, 87 | 'lr': lr, 88 | 'ori_lr': lr 89 | } 90 | ] 91 | self.optimizer = AdamW(grouped_params, correct_bias=False) 92 | else: 93 | raise Exception("Invalid optimizer. 
Must be 'sgd' or 'adam' or 'bert_adam'.") 94 | # Cuda 95 | if torch.cuda.is_available(): 96 | self.cuda() 97 | # Ckpt 98 | self.ckpt = ckpt 99 | 100 | def train_model(self, metric='auc'): 101 | best_metric = 0 102 | for epoch in range(self.max_epoch): 103 | # Train 104 | self.train() 105 | print("=== Epoch %d train ===" % epoch) 106 | avg_loss = AverageMeter() 107 | avg_acc = AverageMeter() 108 | avg_pos_acc = AverageMeter() 109 | t = tqdm(self.train_loader) 110 | for iter, data in enumerate(t): 111 | if torch.cuda.is_available(): 112 | for i in range(len(data)): 113 | try: 114 | data[i] = data[i].cuda() 115 | except: 116 | pass 117 | label = data[0] 118 | bag_name = data[1] 119 | scope = data[2] 120 | args = data[3:] 121 | logits = self.model(label, scope, *args, bag_size=self.bag_size) 122 | loss = self.criterion(logits, label) 123 | score, pred = logits.max(-1) # (B) 124 | acc = float((pred == label).long().sum()) / label.size(0) 125 | pos_total = (label != 0).long().sum() 126 | pos_correct = ((pred == label).long() * (label != 0).long()).sum() 127 | if pos_total > 0: 128 | pos_acc = float(pos_correct) / float(pos_total) 129 | else: 130 | pos_acc = 0 131 | 132 | # Log 133 | avg_loss.update(loss.item(), 1) 134 | avg_acc.update(acc, 1) 135 | avg_pos_acc.update(pos_acc, 1) 136 | t.set_postfix(loss=avg_loss.avg, acc=avg_acc.avg, pos_acc=avg_pos_acc.avg) 137 | 138 | # Optimize 139 | loss.backward() 140 | self.optimizer.step() 141 | self.optimizer.zero_grad() 142 | 143 | # Val 144 | print("=== Epoch %d val ===" % epoch) 145 | result = self.eval_model(self.val_loader) 146 | print("AUC: %.4f" % result['auc']) 147 | print("Micro F1: %.4f" % (result['max_micro_f1'])) 148 | if result[metric] > best_metric: 149 | print("Best ckpt and saved.") 150 | torch.save({'state_dict': self.model.module.state_dict()}, self.ckpt) 151 | best_metric = result[metric] 152 | print("Best %s on val set: %f" % (metric, best_metric)) 153 | 154 | def eval_model(self, eval_loader): 155 | self.model.eval() 156 | with torch.no_grad(): 157 | t = tqdm(eval_loader) 158 | pred_result = [] 159 | for iter, data in enumerate(t): 160 | if torch.cuda.is_available(): 161 | for i in range(len(data)): 162 | try: 163 | data[i] = data[i].cuda() 164 | except: 165 | pass 166 | label = data[0] 167 | bag_name = data[1] 168 | scope = data[2] 169 | args = data[3:] 170 | logits = self.model(None, scope, *args, train=False, bag_size=self.bag_size) # results after softmax 171 | logits = logits.cpu().numpy() 172 | for i in range(len(logits)): 173 | for relid in range(self.model.module.num_class): 174 | if self.model.module.id2rel[relid] != 'NA': 175 | pred_result.append({ 176 | 'entpair': bag_name[i][:2], 177 | 'relation': self.model.module.id2rel[relid], 178 | 'score': logits[i][relid] 179 | }) 180 | result = eval_loader.dataset.eval(pred_result) 181 | return result 182 | 183 | def load_state_dict(self, state_dict): 184 | self.model.module.load_state_dict(state_dict) 185 | -------------------------------------------------------------------------------- /opennre/framework/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import os, random, json, logging 4 | import numpy as np 5 | import sklearn.metrics 6 | 7 | class SentenceREDataset(data.Dataset): 8 | """ 9 | Sentence-level relation extraction dataset 10 | """ 11 | def __init__(self, path, rel2id, tokenizer, kwargs): 12 | """ 13 | Args: 14 | path: path of the input file 15 | rel2id: dictionary 
of relation->id mapping 16 | tokenizer: function of tokenizing 17 | """ 18 | super().__init__() 19 | self.path = path 20 | self.tokenizer = tokenizer 21 | self.rel2id = rel2id 22 | self.kwargs = kwargs 23 | 24 | # Load the file 25 | f = open(path) 26 | self.data = [] 27 | for line in f.readlines(): 28 | line = line.rstrip() 29 | if len(line) > 0: 30 | self.data.append(eval(line)) 31 | f.close() 32 | logging.info("Loaded sentence RE dataset {} with {} lines and {} relations.".format(path, len(self.data), len(self.rel2id))) 33 | 34 | def __len__(self): 35 | return len(self.data) 36 | 37 | def __getitem__(self, index): 38 | item = self.data[index] 39 | seq = list(self.tokenizer(item, **self.kwargs)) 40 | res = [self.rel2id[item['relation']]] + seq 41 | return [self.rel2id[item['relation']]] + seq # label, seq1, seq2, ... 42 | 43 | def collate_fn(data): 44 | data = list(zip(*data)) 45 | labels = data[0] 46 | seqs = data[1:] 47 | batch_labels = torch.tensor(labels).long() # (B) 48 | batch_seqs = [] 49 | for seq in seqs: 50 | batch_seqs.append(torch.cat(seq, 0)) # (B, L) 51 | return [batch_labels] + batch_seqs 52 | 53 | def eval(self, pred_result, use_name=False): 54 | """ 55 | Args: 56 | pred_result: a list of predicted label (id) 57 | Make sure that the `shuffle` param is set to `False` when getting the loader. 58 | use_name: if True, `pred_result` contains predicted relation names instead of ids 59 | Return: 60 | {'acc': xx} 61 | """ 62 | correct = 0 63 | total = len(self.data) 64 | correct_positive = 0 65 | pred_positive = 0 66 | gold_positive = 0 67 | neg = -1 68 | for name in ['NA', 'na', 'no_relation', 'Other', 'Others']: 69 | if name in self.rel2id: 70 | if use_name: 71 | neg = name 72 | else: 73 | neg = self.rel2id[name] 74 | break 75 | for i in range(total): 76 | if use_name: 77 | golden = self.data[i]['relation'] 78 | else: 79 | golden = self.rel2id[self.data[i]['relation']] 80 | if golden == pred_result[i]: 81 | correct += 1 82 | if golden != neg: 83 | correct_positive += 1 84 | if golden != neg: 85 | gold_positive +=1 86 | if pred_result[i] != neg: 87 | pred_positive += 1 88 | acc = float(correct) / float(total) 89 | try: 90 | micro_p = float(correct_positive) / float(pred_positive) 91 | except: 92 | micro_p = 0 93 | try: 94 | micro_r = float(correct_positive) / float(gold_positive) 95 | except: 96 | micro_r = 0 97 | try: 98 | micro_f1 = 2 * micro_p * micro_r / (micro_p + micro_r) 99 | except: 100 | micro_f1 = 0 101 | result = {'acc': acc, 'micro_p': micro_p, 'micro_r': micro_r, 'micro_f1': micro_f1} 102 | logging.info('Evaluation result: {}.'.format(result)) 103 | return result 104 | 105 | def SentenceRELoader(path, rel2id, tokenizer, batch_size, 106 | shuffle, num_workers=8, collate_fn=SentenceREDataset.collate_fn, **kwargs): 107 | dataset = SentenceREDataset(path = path, rel2id = rel2id, tokenizer = tokenizer, kwargs=kwargs) 108 | data_loader = data.DataLoader(dataset=dataset, 109 | batch_size=batch_size, 110 | shuffle=shuffle, 111 | pin_memory=True, 112 | num_workers=num_workers, 113 | collate_fn=collate_fn) 114 | return data_loader 115 | 116 | class BagREDataset(data.Dataset): 117 | """ 118 | Bag-level relation extraction dataset. Note that relation of NA should be named as 'NA'. 
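# A minimal sketch of how bags are formed (hypothetical ids and relation name).
# Two input lines such as
#     {'token': [...], 'h': {'id': 'e1'}, 't': {'id': 'e2'}, 'relation': 'founder'}
#     {'token': [...], 'h': {'id': 'e1'}, 't': {'id': 'e2'}, 'relation': 'founder'}
# share the same (head id, tail id, relation) fact, so they land in the same bag;
# with entpair_as_bag=True the grouping key is just the (head id, tail id) pair.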
119 | """ 120 | def __init__(self, path, rel2id, tokenizer, entpair_as_bag=False, bag_size=0, mode=None): 121 | """ 122 | Args: 123 | path: path of the input file 124 | rel2id: dictionary of relation->id mapping 125 | tokenizer: function of tokenizing 126 | entpair_as_bag: if True, bags are constructed based on same 127 | entity pairs instead of same relation facts (ignoring 128 | relation labels) 129 | """ 130 | super().__init__() 131 | self.tokenizer = tokenizer 132 | self.rel2id = rel2id 133 | self.entpair_as_bag = entpair_as_bag 134 | self.bag_size = bag_size 135 | 136 | # Load the file 137 | f = open(path) 138 | self.data = [] 139 | for line in f: 140 | line = line.rstrip() 141 | if len(line) > 0: 142 | self.data.append(eval(line)) 143 | f.close() 144 | 145 | # Construct bag-level dataset (a bag contains instances sharing the same relation fact) 146 | if mode == None: 147 | self.weight = np.ones((len(self.rel2id)), dtype=np.float32) 148 | self.bag_scope = [] 149 | self.name2id = {} 150 | self.bag_name = [] 151 | self.facts = {} 152 | for idx, item in enumerate(self.data): 153 | # Annotated test set 154 | if 'anno_relation_list' in item: 155 | for r in item['anno_relation_list']: 156 | fact = (item['h']['id'], item['t']['id'], r) 157 | if r != 'NA': 158 | self.facts[fact] = 1 159 | assert entpair_as_bag 160 | name = (item['h']['id'], item['t']['id']) 161 | else: 162 | fact = (item['h']['id'], item['t']['id'], item['relation']) 163 | if item['relation'] != 'NA': 164 | self.facts[fact] = 1 165 | if entpair_as_bag: 166 | name = (item['h']['id'], item['t']['id']) 167 | else: 168 | name = fact 169 | if name not in self.name2id: 170 | self.name2id[name] = len(self.name2id) 171 | self.bag_scope.append([]) 172 | self.bag_name.append(name) 173 | self.bag_scope[self.name2id[name]].append(idx) 174 | self.weight[self.rel2id[item['relation']]] += 1.0 175 | self.weight = 1.0 / (self.weight ** 0.05) 176 | self.weight = torch.from_numpy(self.weight) 177 | else: 178 | pass 179 | 180 | def __len__(self): 181 | return len(self.bag_scope) 182 | 183 | def __getitem__(self, index): 184 | bag = self.bag_scope[index] 185 | if self.bag_size > 0: 186 | if self.bag_size <= len(bag): 187 | resize_bag = random.sample(bag, self.bag_size) 188 | else: 189 | resize_bag = bag + list(np.random.choice(bag, self.bag_size - len(bag))) 190 | bag = resize_bag 191 | 192 | seqs = None 193 | rel = self.rel2id[self.data[bag[0]]['relation']] 194 | for sent_id in bag: 195 | item = self.data[sent_id] 196 | seq = list(self.tokenizer(item)) 197 | if seqs is None: 198 | seqs = [] 199 | for i in range(len(seq)): 200 | seqs.append([]) 201 | for i in range(len(seq)): 202 | seqs[i].append(seq[i]) 203 | for i in range(len(seqs)): 204 | seqs[i] = torch.cat(seqs[i], 0) # (n, L), n is the size of bag 205 | return [rel, self.bag_name[index], len(bag)] + seqs 206 | 207 | def collate_fn(data): 208 | data = list(zip(*data)) 209 | label, bag_name, count = data[:3] 210 | seqs = data[3:] 211 | for i in range(len(seqs)): 212 | seqs[i] = torch.cat(seqs[i], 0) # (sumn, L) 213 | seqs[i] = seqs[i].expand((torch.cuda.device_count() if torch.cuda.device_count() > 0 else 1, ) + seqs[i].size()) 214 | scope = [] # (B, 2) 215 | start = 0 216 | for c in count: 217 | scope.append((start, start + c)) 218 | start += c 219 | assert(start == seqs[0].size(1)) 220 | scope = torch.tensor(scope).long() 221 | label = torch.tensor(label).long() # (B) 222 | return [label, bag_name, scope] + seqs 223 | 224 | def collate_bag_size_fn(data): 225 | data = list(zip(*data)) 226 | 
label, bag_name, count = data[:3] 227 | seqs = data[3:] 228 | for i in range(len(seqs)): 229 | seqs[i] = torch.stack(seqs[i], 0) # (batch, bag, L) 230 | scope = [] # (B, 2) 231 | start = 0 232 | for c in count: 233 | scope.append((start, start + c)) 234 | start += c 235 | label = torch.tensor(label).long() # (B) 236 | return [label, bag_name, scope] + seqs 237 | 238 | 239 | def eval(self, pred_result, threshold=0.5): 240 | """ 241 | Args: 242 | pred_result: a list with dict {'entpair': (head_id, tail_id), 'relation': rel, 'score': score}. 243 | Note that relation of NA should be excluded. 244 | Return: 245 | {'prec': narray[...], 'rec': narray[...], 'mean_prec': xx, 'f1': xx, 'auc': xx} 246 | prec (precision) and rec (recall) are in micro style. 247 | prec (precision) and rec (recall) are sorted in the decreasing order of the score. 248 | f1 is the max f1 score of those precison-recall points 249 | """ 250 | sorted_pred_result = sorted(pred_result, key=lambda x: x['score'], reverse=True) 251 | prec = [] 252 | rec = [] 253 | correct = 0 254 | total = len(self.facts) 255 | 256 | entpair = {} 257 | 258 | for i, item in enumerate(sorted_pred_result): 259 | # Save entpair label and result for later calculating F1 260 | idtf = item['entpair'][0] + '#' + item['entpair'][1] 261 | if idtf not in entpair: 262 | entpair[idtf] = { 263 | 'label': np.zeros((len(self.rel2id)), dtype=np.int), 264 | 'pred': np.zeros((len(self.rel2id)), dtype=np.int), 265 | 'score': np.zeros((len(self.rel2id)), dtype=np.float) 266 | } 267 | if (item['entpair'][0], item['entpair'][1], item['relation']) in self.facts: 268 | correct += 1 269 | entpair[idtf]['label'][self.rel2id[item['relation']]] = 1 270 | if item['score'] >= threshold: 271 | entpair[idtf]['pred'][self.rel2id[item['relation']]] = 1 272 | entpair[idtf]['score'][self.rel2id[item['relation']]] = item['score'] 273 | 274 | prec.append(float(correct) / float(i + 1)) 275 | rec.append(float(correct) / float(total)) 276 | 277 | auc = sklearn.metrics.auc(x=rec, y=prec) 278 | np_prec = np.array(prec) 279 | np_rec = np.array(rec) 280 | max_micro_f1 = (2 * np_prec * np_rec / (np_prec + np_rec + 1e-20)).max() 281 | best_threshold = sorted_pred_result[(2 * np_prec * np_rec / (np_prec + np_rec + 1e-20)).argmax()]['score'] 282 | mean_prec = np_prec.mean() 283 | 284 | label_vec = [] 285 | pred_result_vec = [] 286 | score_vec = [] 287 | for ep in entpair: 288 | label_vec.append(entpair[ep]['label']) 289 | pred_result_vec.append(entpair[ep]['pred']) 290 | score_vec.append(entpair[ep]['score']) 291 | label_vec = np.stack(label_vec, 0) 292 | pred_result_vec = np.stack(pred_result_vec, 0) 293 | score_vec = np.stack(score_vec, 0) 294 | 295 | micro_p = sklearn.metrics.precision_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='micro') 296 | micro_r = sklearn.metrics.recall_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='micro') 297 | micro_f1 = sklearn.metrics.f1_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='micro') 298 | 299 | macro_p = sklearn.metrics.precision_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro') 300 | macro_r = sklearn.metrics.recall_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro') 301 | macro_f1 = sklearn.metrics.f1_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro') 302 | 303 | pred_result_vec = score_vec >= best_threshold 304 | 
max_macro_f1 = sklearn.metrics.f1_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro') 305 | max_micro_f1_each_relation = {} 306 | for rel in self.rel2id: 307 | if rel != 'NA': 308 | max_micro_f1_each_relation[rel] = sklearn.metrics.f1_score(label_vec, pred_result_vec, labels=[self.rel2id[rel]], average='micro') 309 | 310 | return {'np_prec': np_prec, 'np_rec': np_rec, 'max_micro_f1': max_micro_f1, 'max_macro_f1': max_macro_f1, 'auc': auc, 'p@100': np_prec[99], 'p@200': np_prec[199], 'p@300': np_prec[299], 'avg_p300': (np_prec[99] + np_prec[199] + np_prec[299]) / 3, 'micro_f1': micro_f1, 'macro_f1': macro_f1, 'max_micro_f1_each_relation': max_micro_f1_each_relation} 311 | 312 | def BagRELoader(path, rel2id, tokenizer, batch_size, 313 | shuffle, entpair_as_bag=False, bag_size=0, num_workers=8, 314 | collate_fn=BagREDataset.collate_fn): 315 | if bag_size == 0: 316 | collate_fn = BagREDataset.collate_fn 317 | else: 318 | collate_fn = BagREDataset.collate_bag_size_fn 319 | dataset = BagREDataset(path, rel2id, tokenizer, entpair_as_bag=entpair_as_bag, bag_size=bag_size) 320 | data_loader = data.DataLoader(dataset=dataset, 321 | batch_size=batch_size, 322 | shuffle=shuffle, 323 | pin_memory=True, 324 | num_workers=num_workers, 325 | collate_fn=collate_fn) 326 | return data_loader 327 | 328 | 329 | class MultiLabelSentenceREDataset(data.Dataset): 330 | """ 331 | Sentence-level relation extraction dataset 332 | """ 333 | def __init__(self, path, rel2id, tokenizer, kwargs): 334 | """ 335 | Args: 336 | path: path of the input file 337 | rel2id: dictionary of relation->id mapping 338 | tokenizer: function of tokenizing 339 | """ 340 | super().__init__() 341 | self.path = path 342 | self.tokenizer = tokenizer 343 | self.rel2id = rel2id 344 | self.kwargs = kwargs 345 | 346 | # Load the file 347 | f = open(path) 348 | self.data = [] 349 | for line in f.readlines(): 350 | line = line.rstrip() 351 | if len(line) > 0: 352 | self.data.append(eval(line)) 353 | f.close() 354 | logging.info("Loaded sentence RE dataset {} with {} lines and {} relations.".format(path, len(self.data), len(self.rel2id))) 355 | 356 | def __len__(self): 357 | return len(self.data) 358 | 359 | def __getitem__(self, index): 360 | item = self.data[index] 361 | seq = list(self.tokenizer(item, **self.kwargs)) 362 | res = [self.rel2id[item['relation']]] + seq 363 | return [self.rel2id[item['relation']]] + seq # label, seq1, seq2, ... 
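# Usage sketch (hypothetical file path; assumes a sentence encoder exposing .tokenize,
# e.g. the BERT encoders above). The MultiLabelSentenceRELoader defined at the bottom
# of this file wraps this dataset, and collate_fn below turns the per-item lists into
# [labels (B,)] plus the per-field tensors concatenated along dim 0, e.g. token ids
# of shape (B, L) and entity positions of shape (B, 1) for the BERT encoders:
#
#     loader = MultiLabelSentenceRELoader('benchmark/nyt10m/nyt10m_test.txt', rel2id,
#                                         sentence_encoder.tokenize, 64, False)
#     labels, *seqs = next(iter(loader))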
364 | 365 | def collate_fn(data): 366 | data = list(zip(*data)) 367 | labels = data[0] 368 | seqs = data[1:] 369 | batch_labels = torch.tensor(labels).long() # (B) 370 | batch_seqs = [] 371 | for seq in seqs: 372 | batch_seqs.append(torch.cat(seq, 0)) # (B, L) 373 | return [batch_labels] + batch_seqs 374 | 375 | def eval(self, pred_score, threshold=0.5, use_name=False): 376 | """ 377 | Args: 378 | pred_score: [sent_num, label_num] 379 | use_name: if True, `pred_result` contains predicted relation names instead of ids 380 | Return: 381 | {'acc': xx} 382 | """ 383 | assert len(self.data) == len(pred_score) 384 | pred_score = np.array(pred_score) 385 | 386 | # Calculate AUC 387 | sorted_result = [] 388 | total = 0 389 | for sent_id in range(len(self.data)): 390 | for rel in self.rel2id: 391 | if rel not in ['NA', 'na', 'N/A', 'None', 'none', 'n/a', 'no_relation']: 392 | sorted_result.append({'sent_id': sent_id, 'relation': rel, 'score': pred_score[sent_id][self.rel2id[rel]]}) 393 | if 'anno_relation_list' in self.data[sent_id]: 394 | if rel in self.data[sent_id]['anno_relation_list']: 395 | total += 1 396 | else: 397 | if rel == self.data[sent_id]['relation']: 398 | total += 1 399 | 400 | sorted_result.sort(key=lambda x: x['score'], reverse=True) 401 | prec = [] 402 | rec = [] 403 | correct = 0 404 | for i, item in enumerate(sorted_result): 405 | if 'anno_relation_list' in self.data[item['sent_id']]: 406 | if item['relation'] in self.data[item['sent_id']]['anno_relation_list']: 407 | correct += 1 408 | else: 409 | if item['relation'] == self.data[item['sent_id']]['relation']: 410 | correct += 1 411 | prec.append(float(correct) / float(i + 1)) 412 | rec.append(float(correct) / float(total)) 413 | auc = sklearn.metrics.auc(x=rec, y=prec) 414 | np_prec = np.array(prec) 415 | np_rec = np.array(rec) 416 | max_micro_f1 = (2 * np_prec * np_rec / (np_prec + np_rec + 1e-20)).max() 417 | max_micro_f1_threshold = sorted_result[(2 * np_prec * np_rec / (np_prec + np_rec + 1e-20)).argmax()]['score'] 418 | mean_prec = np_prec.mean() 419 | 420 | # Calculate F1 421 | pred_result_vec = np.zeros((len(self.data), len(self.rel2id)), dtype=np.int) 422 | pred_result_vec[pred_score >= threshold] = 1 423 | label_vec = [] 424 | for item in self.data: 425 | if 'anno_relation_list' in item: 426 | label_vec.append(np.array(item['anno_relation_vec'], dtype=np.int)) 427 | else: 428 | one_hot = np.zeros((len(self.rel2id)), dtype=np.int) 429 | one_hot[self.rel2id[item['relation']]] = 1 430 | label_vec.append(one_hot) 431 | label_vec = np.stack(label_vec, 0) 432 | assert label_vec.shape == pred_result_vec.shape 433 | 434 | micro_p = sklearn.metrics.precision_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='micro') 435 | micro_r = sklearn.metrics.recall_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='micro') 436 | micro_f1 = sklearn.metrics.f1_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='micro') 437 | 438 | macro_p = sklearn.metrics.precision_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro') 439 | macro_r = sklearn.metrics.recall_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro') 440 | macro_f1 = sklearn.metrics.f1_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro') 441 | 442 | acc = (label_vec == pred_result_vec).mean() 443 | 444 | result = {'acc': acc, 'micro_p': micro_p, 'micro_r': micro_r, 
'micro_f1': micro_f1, 'macro_p': macro_p, 'macro_r': macro_r, 'macro_f1': macro_f1, 'np_prec': np_prec, 'np_rec': np_rec, 'max_micro_f1': max_micro_f1, 'max_micro_f1_threshold': max_micro_f1_threshold, 'auc': auc, 'p@100': np_prec[99], 'p@200': np_prec[199], 'p@300': np_prec[299]} 445 | logging.info('Evaluation result: {}.'.format(result)) 446 | return result 447 | 448 | def MultiLabelSentenceRELoader(path, rel2id, tokenizer, batch_size, 449 | shuffle, num_workers=8, collate_fn=SentenceREDataset.collate_fn, **kwargs): 450 | dataset = MultiLabelSentenceREDataset(path = path, rel2id = rel2id, tokenizer = tokenizer, kwargs=kwargs) 451 | data_loader = data.DataLoader(dataset=dataset, 452 | batch_size=batch_size, 453 | shuffle=shuffle, 454 | pin_memory=True, 455 | num_workers=num_workers, 456 | collate_fn=collate_fn) 457 | return data_loader 458 | 459 | 460 | -------------------------------------------------------------------------------- /opennre/framework/multi_label_sentence_re.py: -------------------------------------------------------------------------------- 1 | import os, logging, json 2 | from tqdm import tqdm 3 | import torch 4 | from torch import nn, optim 5 | from .data_loader import MultiLabelSentenceRELoader 6 | from .utils import AverageMeter 7 | import numpy as np 8 | 9 | class MultiLabelSentenceRE(nn.Module): 10 | 11 | def __init__(self, 12 | model, 13 | train_path, 14 | val_path, 15 | test_path, 16 | ckpt, 17 | batch_size=32, 18 | max_epoch=100, 19 | lr=0.1, 20 | weight_decay=1e-5, 21 | warmup_step=300, 22 | opt='sgd'): 23 | 24 | super().__init__() 25 | self.max_epoch = max_epoch 26 | # Load data 27 | if train_path != None: 28 | self.train_loader = MultiLabelSentenceRELoader( 29 | train_path, 30 | model.rel2id, 31 | model.sentence_encoder.tokenize, 32 | batch_size, 33 | True) 34 | 35 | if val_path != None: 36 | self.val_loader = MultiLabelSentenceRELoader( 37 | val_path, 38 | model.rel2id, 39 | model.sentence_encoder.tokenize, 40 | batch_size, 41 | False) 42 | 43 | if test_path != None: 44 | self.test_loader = MultiLabelSentenceRELoader( 45 | test_path, 46 | model.rel2id, 47 | model.sentence_encoder.tokenize, 48 | batch_size, 49 | False 50 | ) 51 | # Model 52 | self.model = model 53 | self.parallel_model = nn.DataParallel(self.model) 54 | # Criterion 55 | self.criterion = nn.BCEWithLogitsLoss() 56 | # Params and optimizer 57 | params = self.parameters() 58 | self.lr = lr 59 | if opt == 'sgd': 60 | self.optimizer = optim.SGD(params, lr, weight_decay=weight_decay) 61 | elif opt == 'adam': 62 | self.optimizer = optim.Adam(params, lr, weight_decay=weight_decay) 63 | elif opt == 'adamw': # Optimizer for BERT 64 | from transformers import AdamW 65 | params = list(self.named_parameters()) 66 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 67 | grouped_params = [ 68 | { 69 | 'params': [p for n, p in params if not any(nd in n for nd in no_decay)], 70 | 'weight_decay': 0.01, 71 | 'lr': lr, 72 | 'ori_lr': lr 73 | }, 74 | { 75 | 'params': [p for n, p in params if any(nd in n for nd in no_decay)], 76 | 'weight_decay': 0.0, 77 | 'lr': lr, 78 | 'ori_lr': lr 79 | } 80 | ] 81 | self.optimizer = AdamW(grouped_params, correct_bias=False) 82 | else: 83 | raise Exception("Invalid optimizer. 
Must be 'sgd' or 'adam' or 'adamw'.") 84 | # Warmup 85 | if warmup_step > 0: 86 | from transformers import get_linear_schedule_with_warmup 87 | training_steps = self.train_loader.dataset.__len__() // batch_size * self.max_epoch 88 | self.scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=warmup_step, num_training_steps=training_steps) 89 | else: 90 | self.scheduler = None 91 | # Cuda 92 | if torch.cuda.is_available(): 93 | self.cuda() 94 | # Ckpt 95 | self.ckpt = ckpt 96 | 97 | def train_model(self, metric='acc'): 98 | best_metric = 0 99 | global_step = 0 100 | for epoch in range(self.max_epoch): 101 | self.train() 102 | logging.info("=== Epoch %d train ===" % epoch) 103 | avg_loss = AverageMeter() 104 | avg_acc = AverageMeter() 105 | t = tqdm(self.train_loader) 106 | for iter, data in enumerate(t): 107 | if torch.cuda.is_available(): 108 | for i in range(len(data)): 109 | try: 110 | data[i] = data[i].cuda() 111 | except: 112 | pass 113 | label = data[0] 114 | args = data[1:] 115 | logits = self.parallel_model(*args) 116 | 117 | label_vec = torch.zeros_like(logits).cuda() 118 | label_vec[torch.arange(label_vec.size(0)), label] = 1 119 | label_vec = label_vec[:, 1:] 120 | logits = logits[:, 1:] 121 | 122 | loss = self.criterion(logits.reshape(-1), label_vec.reshape(-1)) 123 | pred = (torch.sigmoid(logits) >= 0.5).long() 124 | acc = float((pred == label_vec).long().sum()) / (label_vec.size(0) * label_vec.size(1)) 125 | 126 | # Log 127 | avg_loss.update(loss.item(), 1) 128 | avg_acc.update(acc, 1) 129 | t.set_postfix(loss=avg_loss.avg, acc=avg_acc.avg) 130 | # Optimize 131 | loss.backward() 132 | self.optimizer.step() 133 | if self.scheduler is not None: 134 | self.scheduler.step() 135 | self.optimizer.zero_grad() 136 | global_step += 1 137 | # Val 138 | logging.info("=== Epoch %d val ===" % epoch) 139 | result = self.eval_model(self.val_loader) 140 | logging.info('Metric {} current / best: {} / {}'.format(metric, result[metric], best_metric)) 141 | if result[metric] > best_metric: 142 | logging.info("Best ckpt and saved.") 143 | folder_path = '/'.join(self.ckpt.split('/')[:-1]) 144 | if not os.path.exists(folder_path): 145 | os.mkdir(folder_path) 146 | torch.save({'state_dict': self.model.state_dict()}, self.ckpt) 147 | best_metric = result[metric] 148 | logging.info("Best %s on val set: %f" % (metric, best_metric)) 149 | 150 | def eval_model(self, eval_loader): 151 | self.eval() 152 | pred_score = [] 153 | with torch.no_grad(): 154 | t = tqdm(eval_loader) 155 | for iter, data in enumerate(t): 156 | if torch.cuda.is_available(): 157 | for i in range(len(data)): 158 | try: 159 | data[i] = data[i].cuda() 160 | except: 161 | pass 162 | label = data[0] 163 | args = data[1:] 164 | logits = self.parallel_model(*args) 165 | score = self.parallel_model.module.logit_to_score(logits).cpu().numpy() 166 | # Save result 167 | pred_score.append(score) 168 | # Log 169 | pred_score = np.concatenate(pred_score, 0) 170 | result = eval_loader.dataset.eval(pred_score) 171 | return result 172 | 173 | def load_state_dict(self, state_dict): 174 | self.model.load_state_dict(state_dict) 175 | 176 | -------------------------------------------------------------------------------- /opennre/framework/sentence_re.py: -------------------------------------------------------------------------------- 1 | import os, logging, json 2 | from tqdm import tqdm 3 | import torch 4 | from torch import nn, optim 5 | from .data_loader import SentenceRELoader 6 | from .utils import AverageMeter 7 | 8 | class 
SentenceRE(nn.Module): 9 | 10 | def __init__(self, 11 | model, 12 | train_path, 13 | val_path, 14 | test_path, 15 | ckpt, 16 | batch_size=32, 17 | max_epoch=100, 18 | lr=0.1, 19 | weight_decay=1e-5, 20 | warmup_step=300, 21 | opt='sgd'): 22 | 23 | super().__init__() 24 | self.max_epoch = max_epoch 25 | # Load data 26 | if train_path != None: 27 | self.train_loader = SentenceRELoader( 28 | train_path, 29 | model.rel2id, 30 | model.sentence_encoder.tokenize, 31 | batch_size, 32 | True) 33 | 34 | if val_path != None: 35 | self.val_loader = SentenceRELoader( 36 | val_path, 37 | model.rel2id, 38 | model.sentence_encoder.tokenize, 39 | batch_size, 40 | False) 41 | 42 | if test_path != None: 43 | self.test_loader = SentenceRELoader( 44 | test_path, 45 | model.rel2id, 46 | model.sentence_encoder.tokenize, 47 | batch_size, 48 | False 49 | ) 50 | # Model 51 | self.model = model 52 | self.parallel_model = nn.DataParallel(self.model) 53 | # Criterion 54 | self.criterion = nn.CrossEntropyLoss() 55 | # Params and optimizer 56 | params = self.parameters() 57 | self.lr = lr 58 | if opt == 'sgd': 59 | self.optimizer = optim.SGD(params, lr, weight_decay=weight_decay) 60 | elif opt == 'adam': 61 | self.optimizer = optim.Adam(params, lr, weight_decay=weight_decay) 62 | elif opt == 'adamw': # Optimizer for BERT 63 | from transformers import AdamW 64 | params = list(self.named_parameters()) 65 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 66 | grouped_params = [ 67 | { 68 | 'params': [p for n, p in params if not any(nd in n for nd in no_decay)], 69 | 'weight_decay': 0.01, 70 | 'lr': lr, 71 | 'ori_lr': lr 72 | }, 73 | { 74 | 'params': [p for n, p in params if any(nd in n for nd in no_decay)], 75 | 'weight_decay': 0.0, 76 | 'lr': lr, 77 | 'ori_lr': lr 78 | } 79 | ] 80 | self.optimizer = AdamW(grouped_params, correct_bias=False) 81 | else: 82 | raise Exception("Invalid optimizer. 
Must be 'sgd' or 'adam' or 'adamw'.") 83 | # Warmup 84 | if warmup_step > 0: 85 | from transformers import get_linear_schedule_with_warmup 86 | training_steps = self.train_loader.dataset.__len__() // batch_size * self.max_epoch 87 | self.scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=warmup_step, num_training_steps=training_steps) 88 | else: 89 | self.scheduler = None 90 | # Cuda 91 | if torch.cuda.is_available(): 92 | self.cuda() 93 | # Ckpt 94 | self.ckpt = ckpt 95 | 96 | def train_model(self, metric='acc'): 97 | best_metric = 0 98 | global_step = 0 99 | for epoch in range(self.max_epoch): 100 | self.train() 101 | logging.info("=== Epoch %d train ===" % epoch) 102 | avg_loss = AverageMeter() 103 | avg_acc = AverageMeter() 104 | t = tqdm(self.train_loader) 105 | for iter, data in enumerate(t): 106 | if torch.cuda.is_available(): 107 | for i in range(len(data)): 108 | try: 109 | data[i] = data[i].cuda() 110 | except: 111 | pass 112 | label = data[0] 113 | args = data[1:] 114 | logits = self.parallel_model(*args) 115 | loss = self.criterion(logits, label) 116 | score, pred = logits.max(-1) # (B) 117 | acc = float((pred == label).long().sum()) / label.size(0) 118 | # Log 119 | avg_loss.update(loss.item(), 1) 120 | avg_acc.update(acc, 1) 121 | t.set_postfix(loss=avg_loss.avg, acc=avg_acc.avg) 122 | # Optimize 123 | loss.backward() 124 | self.optimizer.step() 125 | if self.scheduler is not None: 126 | self.scheduler.step() 127 | self.optimizer.zero_grad() 128 | global_step += 1 129 | # Val 130 | logging.info("=== Epoch %d val ===" % epoch) 131 | result = self.eval_model(self.val_loader) 132 | logging.info('Metric {} current / best: {} / {}'.format(metric, result[metric], best_metric)) 133 | if result[metric] > best_metric: 134 | logging.info("Best ckpt and saved.") 135 | folder_path = '/'.join(self.ckpt.split('/')[:-1]) 136 | if not os.path.exists(folder_path): 137 | os.mkdir(folder_path) 138 | torch.save({'state_dict': self.model.state_dict()}, self.ckpt) 139 | best_metric = result[metric] 140 | logging.info("Best %s on val set: %f" % (metric, best_metric)) 141 | 142 | def eval_model(self, eval_loader): 143 | self.eval() 144 | avg_acc = AverageMeter() 145 | pred_result = [] 146 | with torch.no_grad(): 147 | t = tqdm(eval_loader) 148 | for iter, data in enumerate(t): 149 | if torch.cuda.is_available(): 150 | for i in range(len(data)): 151 | try: 152 | data[i] = data[i].cuda() 153 | except: 154 | pass 155 | label = data[0] 156 | args = data[1:] 157 | logits = self.parallel_model(*args) 158 | score, pred = logits.max(-1) # (B) 159 | # Save result 160 | for i in range(pred.size(0)): 161 | pred_result.append(pred[i].item()) 162 | # Log 163 | acc = float((pred == label).long().sum()) / label.size(0) 164 | avg_acc.update(acc, pred.size(0)) 165 | t.set_postfix(acc=avg_acc.avg) 166 | result = eval_loader.dataset.eval(pred_result) 167 | return result 168 | 169 | def load_state_dict(self, state_dict): 170 | self.model.load_state_dict(state_dict) 171 | 172 | -------------------------------------------------------------------------------- /opennre/framework/utils.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """ 3 | Computes and stores the average and current value of metrics. 
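    A minimal usage sketch (mirroring how the framework classes call it):
        avg_loss = AverageMeter()
        avg_loss.update(loss.item(), 1)
        # the running average is then available as avg_loss.avg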
4 | """ 5 | 6 | def __init__(self): 7 | self.reset() 8 | 9 | def reset(self): 10 | self.val = 0 11 | self.avg = 0 12 | self.sum = 0 13 | self.count = 0 14 | 15 | def update(self, val, n=0): 16 | self.val = val 17 | self.sum += val * n 18 | self.count += n 19 | self.avg = self.sum / (.0001 + self.count) 20 | 21 | def __str__(self): 22 | """ 23 | String representation for logging 24 | """ 25 | # for values that should be recorded exactly e.g. iteration number 26 | if self.count == 0: 27 | return str(self.val) 28 | # for stats 29 | return '%.4f (%.4f)' % (self.val, self.avg) 30 | -------------------------------------------------------------------------------- /opennre/model/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .base_model import SentenceRE, BagRE, FewShotRE, NER 6 | from .softmax_nn import SoftmaxNN 7 | from .sigmoid_nn import SigmoidNN 8 | from .bag_attention import BagAttention 9 | from .bag_average import BagAverage 10 | from .bag_one import BagOne 11 | 12 | __all__ = [ 13 | 'SentenceRE', 14 | 'BagRE', 15 | 'FewShotRE', 16 | 'NER', 17 | 'SoftmaxNN', 18 | 'BagAttention', 19 | 'BagAverage', 20 | 'BagOne' 21 | ] -------------------------------------------------------------------------------- /opennre/model/bag_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, optim 3 | from .base_model import BagRE 4 | 5 | class BagAttention(BagRE): 6 | """ 7 | Instance attention for bag-level relation extraction. 8 | """ 9 | 10 | def __init__(self, sentence_encoder, num_class, rel2id, use_diag=True): 11 | """ 12 | Args: 13 | sentence_encoder: encoder for sentences 14 | num_class: number of classes 15 | id2rel: dictionary of id -> relation name mapping 16 | """ 17 | super().__init__() 18 | self.sentence_encoder = sentence_encoder 19 | self.num_class = num_class 20 | self.fc = nn.Linear(self.sentence_encoder.hidden_size, num_class) 21 | self.softmax = nn.Softmax(-1) 22 | self.rel2id = rel2id 23 | self.id2rel = {} 24 | self.drop = nn.Dropout() 25 | for rel, id in rel2id.items(): 26 | self.id2rel[id] = rel 27 | if use_diag: 28 | self.use_diag = True 29 | self.diag = nn.Parameter(torch.ones(self.sentence_encoder.hidden_size)) 30 | else: 31 | self.use_diag = False 32 | 33 | def infer(self, bag): 34 | """ 35 | Args: 36 | bag: bag of sentences with the same entity pair 37 | [{ 38 | 'text' or 'token': ..., 39 | 'h': {'pos': [start, end], ...}, 40 | 't': {'pos': [start, end], ...} 41 | }] 42 | Return: 43 | (relation, score) 44 | """ 45 | self.eval() 46 | tokens = [] 47 | pos1s = [] 48 | pos2s = [] 49 | masks = [] 50 | for item in bag: 51 | token, pos1, pos2, mask = self.sentence_encoder.tokenize(item) 52 | tokens.append(token) 53 | pos1s.append(pos1) 54 | pos2s.append(pos2) 55 | masks.append(mask) 56 | tokens = torch.cat(tokens, 0).unsqueeze(0) # (n, L) 57 | pos1s = torch.cat(pos1s, 0).unsqueeze(0) 58 | pos2s = torch.cat(pos2s, 0).unsqueeze(0) 59 | masks = torch.cat(masks, 0).unsqueeze(0) 60 | scope = torch.tensor([[0, len(bag)]]).long() # (1, 2) 61 | bag_logits = self.forward(None, scope, tokens, pos1s, pos2s, masks, train=False).squeeze(0) # (N) after softmax 62 | score, pred = bag_logits.max(0) 63 | score = score.item() 64 | pred = pred.item() 65 | rel = self.id2rel[pred] 66 | return (rel, score) 67 | 68 | def forward(self, label, scope, 
token, pos1, pos2, mask=None, train=True, bag_size=0): 69 | """ 70 | Args: 71 | label: (B), label of the bag 72 | scope: (B), scope for each bag 73 | token: (nsum, L), index of tokens 74 | pos1: (nsum, L), relative position to head entity 75 | pos2: (nsum, L), relative position to tail entity 76 | mask: (nsum, L), used for piece-wise CNN 77 | Return: 78 | logits, (B, N) 79 | 80 | Dirty hack: 81 | When the encoder is BERT, the input is actually token, att_mask, pos1, pos2, but 82 | since the arguments are then fed into BERT encoder with the original order, 83 | the encoder can actually work out correclty. 84 | """ 85 | if bag_size > 0: 86 | token = token.view(-1, token.size(-1)) 87 | pos1 = pos1.view(-1, pos1.size(-1)) 88 | pos2 = pos2.view(-1, pos2.size(-1)) 89 | if mask is not None: 90 | mask = mask.view(-1, mask.size(-1)) 91 | else: 92 | begin, end = scope[0][0], scope[-1][1] 93 | token = token[:, begin:end, :].view(-1, token.size(-1)) 94 | pos1 = pos1[:, begin:end, :].view(-1, pos1.size(-1)) 95 | pos2 = pos2[:, begin:end, :].view(-1, pos2.size(-1)) 96 | if mask is not None: 97 | mask = mask[:, begin:end, :].view(-1, mask.size(-1)) 98 | scope = torch.sub(scope, torch.zeros_like(scope).fill_(begin)) 99 | 100 | # Attention 101 | if train: 102 | if mask is not None: 103 | rep = self.sentence_encoder(token, pos1, pos2, mask) # (nsum, H) 104 | else: 105 | rep = self.sentence_encoder(token, pos1, pos2) # (nsum, H) 106 | 107 | if bag_size == 0: 108 | bag_rep = [] 109 | query = torch.zeros((rep.size(0))).long() 110 | if torch.cuda.is_available(): 111 | query = query.cuda() 112 | for i in range(len(scope)): 113 | query[scope[i][0]:scope[i][1]] = label[i] 114 | att_mat = self.fc.weight[query] # (nsum, H) 115 | if self.use_diag: 116 | att_mat = att_mat * self.diag.unsqueeze(0) 117 | att_score = (rep * att_mat).sum(-1) # (nsum) 118 | 119 | for i in range(len(scope)): 120 | bag_mat = rep[scope[i][0]:scope[i][1]] # (n, H) 121 | softmax_att_score = self.softmax(att_score[scope[i][0]:scope[i][1]]) # (n) 122 | bag_rep.append((softmax_att_score.unsqueeze(-1) * bag_mat).sum(0)) # (n, 1) * (n, H) -> (n, H) -> (H) 123 | bag_rep = torch.stack(bag_rep, 0) # (B, H) 124 | else: 125 | batch_size = label.size(0) 126 | query = label.unsqueeze(1) # (B, 1) 127 | att_mat = self.fc.weight[query] # (B, 1, H) 128 | if self.use_diag: 129 | att_mat = att_mat * self.diag.unsqueeze(0) 130 | rep = rep.view(batch_size, bag_size, -1) 131 | att_score = (rep * att_mat).sum(-1) # (B, bag) 132 | softmax_att_score = self.softmax(att_score) # (B, bag) 133 | bag_rep = (softmax_att_score.unsqueeze(-1) * rep).sum(1) # (B, bag, 1) * (B, bag, H) -> (B, bag, H) -> (B, H) 134 | bag_rep = self.drop(bag_rep) 135 | bag_logits = self.fc(bag_rep) # (B, N) 136 | else: 137 | 138 | if bag_size == 0: 139 | rep = [] 140 | bs = 256 141 | total_bs = len(token) // bs + (1 if len(token) % bs != 0 else 0) 142 | for b in range(total_bs): 143 | with torch.no_grad(): 144 | left = bs * b 145 | right = min(bs * (b + 1), len(token)) 146 | if mask is not None: 147 | rep.append(self.sentence_encoder(token[left:right], pos1[left:right], pos2[left:right], mask[left:right]).detach()) # (nsum, H) 148 | else: 149 | rep.append(self.sentence_encoder(token[left:right], pos1[left:right], pos2[left:right]).detach()) # (nsum, H) 150 | rep = torch.cat(rep, 0) 151 | 152 | bag_logits = [] 153 | att_mat = self.fc.weight.transpose(0, 1) 154 | if self.use_diag: 155 | att_mat = att_mat * self.diag.unsqueeze(1) 156 | att_score = torch.matmul(rep, att_mat) # (nsum, H) * (H, N) -> 
(nsum, N) 157 | for i in range(len(scope)): 158 | bag_mat = rep[scope[i][0]:scope[i][1]] # (n, H) 159 | softmax_att_score = self.softmax(att_score[scope[i][0]:scope[i][1]].transpose(0, 1)) # (N, (softmax)n) 160 | rep_for_each_rel = torch.matmul(softmax_att_score, bag_mat) # (N, n) * (n, H) -> (N, H) 161 | logit_for_each_rel = self.softmax(self.fc(rep_for_each_rel)) # ((each rel)N, (logit)N) 162 | logit_for_each_rel = logit_for_each_rel.diag() # (N) 163 | bag_logits.append(logit_for_each_rel) 164 | bag_logits = torch.stack(bag_logits, 0) # after **softmax** 165 | else: 166 | if mask is not None: 167 | rep = self.sentence_encoder(token, pos1, pos2, mask) # (nsum, H) 168 | else: 169 | rep = self.sentence_encoder(token, pos1, pos2) # (nsum, H) 170 | 171 | batch_size = rep.size(0) // bag_size 172 | att_mat = self.fc.weight.transpose(0, 1) 173 | if self.use_diag: 174 | att_mat = att_mat * self.diag.unsqueeze(1) 175 | att_score = torch.matmul(rep, att_mat) # (nsum, H) * (H, N) -> (nsum, N) 176 | att_score = att_score.view(batch_size, bag_size, -1) # (B, bag, N) 177 | rep = rep.view(batch_size, bag_size, -1) # (B, bag, H) 178 | softmax_att_score = self.softmax(att_score.transpose(1, 2)) # (B, N, (softmax)bag) 179 | rep_for_each_rel = torch.matmul(softmax_att_score, rep) # (B, N, bag) * (B, bag, H) -> (B, N, H) 180 | bag_logits = self.softmax(self.fc(rep_for_each_rel)).diagonal(dim1=1, dim2=2) # (B, (each rel)N) 181 | return bag_logits 182 | 183 | -------------------------------------------------------------------------------- /opennre/model/bag_average.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, optim 3 | from .base_model import BagRE 4 | 5 | class BagAverage(BagRE): 6 | """ 7 | Average policy for bag-level relation extraction. 
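    Each bag representation is the plain (unweighted) mean of the encoded sentences
    in the bag, followed by dropout and a linear layer; schematically:
        bag_rep = mean(encoder(s) for s in bag)
        bag_logits = fc(dropout(bag_rep))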
8 | """ 9 | 10 | def __init__(self, sentence_encoder, num_class, rel2id): 11 | """ 12 | Args: 13 | sentence_encoder: encoder for sentences 14 | num_class: number of classes 15 | id2rel: dictionary of id -> relation name mapping 16 | """ 17 | super().__init__() 18 | self.sentence_encoder = sentence_encoder 19 | self.num_class = num_class 20 | self.fc = nn.Linear(self.sentence_encoder.hidden_size, num_class) 21 | self.softmax = nn.Softmax(-1) 22 | self.rel2id = rel2id 23 | self.id2rel = {} 24 | self.drop = nn.Dropout() 25 | for rel, id in rel2id.items(): 26 | self.id2rel[id] = rel 27 | 28 | def infer(self, bag): 29 | """ 30 | Args: 31 | bag: bag of sentences with the same entity pair 32 | [{ 33 | 'text' or 'token': ..., 34 | 'h': {'pos': [start, end], ...}, 35 | 't': {'pos': [start, end], ...} 36 | }] 37 | Return: 38 | (relation, score) 39 | """ 40 | pass 41 | 42 | """ 43 | tokens = [] 44 | pos1s = [] 45 | pos2s = [] 46 | masks = [] 47 | for item in bag: 48 | if 'text' in item: 49 | token, pos1, pos2, mask = self.tokenizer(item['text'], 50 | item['h']['pos'], item['t']['pos'], is_token=False, padding=True) 51 | else: 52 | token, pos1, pos2, mask = self.tokenizer(item['token'], 53 | item['h']['pos'], item['t']['pos'], is_token=True, padding=True) 54 | tokens.append(token) 55 | pos1s.append(pos1) 56 | pos2s.append(pos2) 57 | masks.append(mask) 58 | tokens = torch.cat(tokens, 0) # (n, L) 59 | pos1s = torch.cat(pos1s, 0) 60 | pos2s = torch.cat(pos2s, 0) 61 | masks = torch.cat(masks, 0) 62 | scope = torch.tensor([[0, len(bag)]]).long() # (1, 2) 63 | bag_logits = self.forward(None, scope, tokens, pos1s, pos2s, masks, train=False).squeeze(0) # (N) after softmax 64 | score, pred = bag_logits.max() 65 | score = score.item() 66 | pred = pred.item() 67 | rel = self.id2rel[pred] 68 | return (rel, score) 69 | """ 70 | 71 | def forward(self, label, scope, token, pos1, pos2, mask=None, train=True, bag_size=None): 72 | """ 73 | Args: 74 | label: (B), label of the bag 75 | scope: (B), scope for each bag 76 | token: (nsum, L), index of tokens 77 | pos1: (nsum, L), relative position to head entity 78 | pos2: (nsum, L), relative position to tail entity 79 | mask: (nsum, L), used for piece-wise CNN 80 | Return: 81 | logits, (B, N) 82 | """ 83 | if bag_size > 0: 84 | token = token.view(-1, token.size(-1)) 85 | pos1 = pos1.view(-1, pos1.size(-1)) 86 | pos2 = pos2.view(-1, pos2.size(-1)) 87 | if mask is not None: 88 | mask = mask.view(-1, mask.size(-1)) 89 | else: 90 | begin, end = scope[0][0], scope[-1][1] 91 | token = token[:, begin:end, :].view(-1, token.size(-1)) 92 | pos1 = pos1[:, begin:end, :].view(-1, pos1.size(-1)) 93 | pos2 = pos2[:, begin:end, :].view(-1, pos2.size(-1)) 94 | if mask is not None: 95 | mask = mask[:, begin:end, :].view(-1, mask.size(-1)) 96 | scope = torch.sub(scope, torch.zeros_like(scope).fill_(begin)) 97 | 98 | if train or bag_size > 0: 99 | if mask is not None: 100 | rep = self.sentence_encoder(token, pos1, pos2, mask) # (nsum, H) 101 | else: 102 | rep = self.sentence_encoder(token, pos1, pos2) # (nsum, H) 103 | else: 104 | rep = [] 105 | bs = 256 106 | total_bs = len(token) // bs + (1 if len(token) % bs != 0 else 0) 107 | for b in range(total_bs): 108 | with torch.no_grad(): 109 | left = bs * b 110 | right = min(bs * (b + 1), len(token)) 111 | if mask is not None: 112 | rep.append(self.sentence_encoder(token[left:right], pos1[left:right], pos2[left:right], mask[left:right]).detach()) # (nsum, H) 113 | else: 114 | rep.append(self.sentence_encoder(token[left:right], pos1[left:right], 
pos2[left:right]).detach()) # (nsum, H) 115 | rep = torch.cat(rep, 0) 116 | 117 | # Average 118 | bag_rep = [] 119 | if bag_size is None or bag_size == 0: 120 | for i in range(len(scope)): 121 | bag_rep.append(rep[scope[i][0]:scope[i][1]].mean(0)) 122 | bag_rep = torch.stack(bag_rep, 0) # (B, H) 123 | else: 124 | batch_size = len(scope) 125 | rep = rep.view(batch_size, bag_size, -1) # (B, bag, H) 126 | bag_rep = rep.mean(1) # (B, H) 127 | bag_rep = self.drop(bag_rep) 128 | bag_logits = self.fc(bag_rep) # (B, N) 129 | 130 | if not train: 131 | bag_logits = torch.softmax(bag_logits, -1) 132 | 133 | return bag_logits 134 | 135 | -------------------------------------------------------------------------------- /opennre/model/bag_one.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, optim 3 | from .base_model import BagRE 4 | 5 | class BagOne(BagRE): 6 | """ 7 | Instance one(max) for bag-level relation extraction. 8 | """ 9 | 10 | def __init__(self, sentence_encoder, num_class, rel2id): 11 | """ 12 | Args: 13 | sentence_encoder: encoder for sentences 14 | num_class: number of classes 15 | id2rel: dictionary of id -> relation name mapping 16 | """ 17 | super().__init__() 18 | self.sentence_encoder = sentence_encoder 19 | self.num_class = num_class 20 | self.fc = nn.Linear(self.sentence_encoder.hidden_size, num_class) 21 | self.softmax = nn.Softmax(-1) 22 | self.rel2id = rel2id 23 | self.id2rel = {} 24 | self.drop = nn.Dropout() 25 | for rel, id in rel2id.items(): 26 | self.id2rel[id] = rel 27 | 28 | def infer(self, bag): 29 | """ 30 | Args: 31 | bag: bag of sentences with the same entity pair 32 | [{ 33 | 'text' or 'token': ..., 34 | 'h': {'pos': [start, end], ...}, 35 | 't': {'pos': [start, end], ...} 36 | }] 37 | Return: 38 | (relation, score) 39 | """ 40 | self.eval() 41 | tokens = [] 42 | pos1s = [] 43 | pos2s = [] 44 | masks = [] 45 | for item in bag: 46 | token, pos1, pos2, mask = self.sentence_encoder.tokenize(item) 47 | tokens.append(token) 48 | pos1s.append(pos1) 49 | pos2s.append(pos2) 50 | masks.append(mask) 51 | tokens = torch.cat(tokens, 0).unsqueeze(0) # (n, L) 52 | pos1s = torch.cat(pos1s, 0).unsqueeze(0) 53 | pos2s = torch.cat(pos2s, 0).unsqueeze(0) 54 | masks = torch.cat(masks, 0).unsqueeze(0) 55 | scope = torch.tensor([[0, len(bag)]]).long() # (1, 2) 56 | bag_logits = self.forward(None, scope, tokens, pos1s, pos2s, masks, train=False).squeeze(0) # (N) after softmax 57 | score, pred = bag_logits.max(0) 58 | score = score.item() 59 | pred = pred.item() 60 | rel = self.id2rel[pred] 61 | return (rel, score) 62 | 63 | def forward(self, label, scope, token, pos1, pos2, mask=None, train=True, bag_size=0): 64 | """ 65 | Args: 66 | label: (B), label of the bag 67 | scope: (B), scope for each bag 68 | token: (nsum, L), index of tokens 69 | pos1: (nsum, L), relative position to head entity 70 | pos2: (nsum, L), relative position to tail entity 71 | mask: (nsum, L), used for piece-wise CNN 72 | Return: 73 | logits, (B, N) 74 | """ 75 | # Encode 76 | if bag_size > 0: 77 | token = token.view(-1, token.size(-1)) 78 | pos1 = pos1.view(-1, pos1.size(-1)) 79 | pos2 = pos2.view(-1, pos2.size(-1)) 80 | if mask is not None: 81 | mask = mask.view(-1, mask.size(-1)) 82 | else: 83 | begin, end = scope[0][0], scope[-1][1] 84 | token = token[:, begin:end, :].view(-1, token.size(-1)) 85 | pos1 = pos1[:, begin:end, :].view(-1, pos1.size(-1)) 86 | pos2 = pos2[:, begin:end, :].view(-1, pos2.size(-1)) 87 | if mask is not None: 88 | 
mask = mask[:, begin:end, :].view(-1, mask.size(-1)) 89 | scope = torch.sub(scope, torch.zeros_like(scope).fill_(begin)) 90 | 91 | if train or bag_size > 0: 92 | if mask is not None: 93 | rep = self.sentence_encoder(token, pos1, pos2, mask) # (nsum, H) 94 | else: 95 | rep = self.sentence_encoder(token, pos1, pos2) # (nsum, H) 96 | else: 97 | rep = [] 98 | bs = 256 99 | total_bs = len(token) // bs + (1 if len(token) % bs != 0 else 0) 100 | for b in range(total_bs): 101 | with torch.no_grad(): 102 | left = bs * b 103 | right = min(bs * (b + 1), len(token)) 104 | if mask is not None: 105 | rep.append(self.sentence_encoder(token[left:right], pos1[left:right], pos2[left:right], mask[left:right]).detach()) # (nsum, H) 106 | else: 107 | rep.append(self.sentence_encoder(token[left:right], pos1[left:right], pos2[left:right]).detach()) # (nsum, H) 108 | rep = torch.cat(rep, 0) 109 | 110 | # Max 111 | if train: 112 | if bag_size == 0: 113 | bag_rep = [] 114 | query = torch.zeros((rep.size(0))).long() 115 | if torch.cuda.is_available(): 116 | query = query.cuda() 117 | for i in range(len(scope)): 118 | query[scope[i][0]:scope[i][1]] = label[i] 119 | 120 | for i in range(len(scope)): # iterate over bags 121 | bag_mat = rep[scope[i][0]:scope[i][1]] # (n, H) 122 | instance_logit = self.softmax(self.fc(bag_mat)) # (n, N) 123 | # select j* which scores highest on the known label 124 | max_index = instance_logit[:, query[i]].argmax() # (1) 125 | bag_rep.append(bag_mat[max_index]) # (n, H) -> (H) 126 | bag_rep = torch.stack(bag_rep, 0) # (B, H) 127 | bag_rep = self.drop(bag_rep) 128 | bag_logits = self.fc(bag_rep) # (B, N) 129 | else: 130 | batch_size = label.size(0) 131 | query = label # (B) 132 | rep = rep.view(batch_size, bag_size, -1) 133 | instance_logit = self.softmax(self.fc(rep)) 134 | max_index = instance_logit[torch.arange(batch_size), :, query].argmax(-1) 135 | bag_rep = rep[torch.arange(batch_size), max_index] 136 | 137 | bag_rep = self.drop(bag_rep) 138 | bag_logits = self.fc(bag_rep) # (B, N) 139 | 140 | else: 141 | if bag_size == 0: 142 | bag_logits = [] 143 | for i in range(len(scope)): 144 | bag_mat = rep[scope[i][0]:scope[i][1]] # (n, H) 145 | instance_logit = self.softmax(self.fc(bag_mat)) # (n, N) 146 | logit_for_each_rel = instance_logit.max(dim=0)[0] # (N) 147 | bag_logits.append(logit_for_each_rel) 148 | bag_logits = torch.stack(bag_logits, 0) # after **softmax** 149 | else: 150 | batch_size = rep.size(0) // bag_size 151 | rep = rep.view(batch_size, bag_size, -1) 152 | bag_logits = self.softmax(self.fc(rep)).max(1)[0] 153 | 154 | return bag_logits 155 | 156 | -------------------------------------------------------------------------------- /opennre/model/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import json 4 | 5 | class SentenceRE(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def infer(self, item): 10 | """ 11 | Args: 12 | item: {'text' or 'token', 'h': {'pos': [start, end]}, 't': ...} 13 | Return: 14 | (Name of the relation of the sentence, score) 15 | """ 16 | raise NotImplementedError 17 | 18 | 19 | class BagRE(nn.Module): 20 | def __init__(self): 21 | super().__init__() 22 | 23 | def infer(self, bag): 24 | """ 25 | Args: 26 | bag: bag of sentences with the same entity pair 27 | [{ 28 | 'text' or 'token': ..., 29 | 'h': {'pos': [start, end], ...}, 30 | 't': {'pos': [start, end], ...} 31 | }] 32 | Return: 33 | (relation, score) 34 | """ 35 | raise 
NotImplementedError 36 | 37 | class FewShotRE(nn.Module): 38 | def __init__(self): 39 | super().__init__() 40 | 41 | def infer(self, support, query): 42 | """ 43 | Args: 44 | support: supporting set. 45 | [{'text' or 'token': ..., 46 | 'h': {'pos': [start, end], ...}, 47 | 't': {'pos': [start, end], ...}, 48 | 'relation': ...}] 49 | query: same format as support 50 | Return: 51 | [(relation, score), ...] 52 | 53 | 54 | For few-shot relation extraction, please refer to FewRel 55 | https://github.com/thunlp/FewRel 56 | """ 57 | raise NotImplementedError 58 | 59 | class NER(nn.Module): 60 | def __init__(self): 61 | super().__init__() 62 | 63 | def ner(self, sentence, is_token=False): 64 | """ 65 | Args: 66 | sentence: string, the input sentence 67 | is_token: if is_token == True, senetence becomes an array of token 68 | Return: 69 | [{name: xx, pos: [start, end]}], a list of named entities 70 | """ 71 | raise NotImplementedError 72 | -------------------------------------------------------------------------------- /opennre/model/sigmoid_nn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, optim 3 | from .base_model import SentenceRE 4 | 5 | class SigmoidNN(SentenceRE): 6 | """ 7 | Sigmoid (binary) classifier for sentence-level relation extraction. 8 | """ 9 | 10 | def __init__(self, sentence_encoder, num_class, rel2id): 11 | """ 12 | Args: 13 | sentence_encoder: encoder for sentences 14 | num_class: number of classes 15 | id2rel: dictionary of id -> relation name mapping 16 | """ 17 | super().__init__() 18 | self.sentence_encoder = sentence_encoder 19 | self.num_class = num_class 20 | self.fc = nn.Linear(self.sentence_encoder.hidden_size, num_class) 21 | self.rel2id = rel2id 22 | self.id2rel = {} 23 | self.drop = nn.Dropout() 24 | for rel, id in rel2id.items(): 25 | self.id2rel[id] = rel 26 | 27 | def forward(self, *args): 28 | """ 29 | Args: 30 | args: depends on the encoder 31 | Return: 32 | logits, (B, N) 33 | """ 34 | rep = self.sentence_encoder(*args) # (B, H) 35 | rep = self.drop(rep) 36 | logits = self.fc(rep) # (B, N) 37 | return logits 38 | 39 | def logit_to_score(self, logits): 40 | return torch.sigmoid(logits) 41 | -------------------------------------------------------------------------------- /opennre/model/softmax_nn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, optim 3 | from .base_model import SentenceRE 4 | 5 | class SoftmaxNN(SentenceRE): 6 | """ 7 | Softmax classifier for sentence-level relation extraction. 
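    Illustrative inference sketch (the sentence and entity positions below are
    made-up values, not part of any benchmark):
        model = SoftmaxNN(sentence_encoder, len(rel2id), rel2id)
        rel, score = model.infer({'text': 'He was born in Paris.',
                                  'h': {'pos': [0, 2]}, 't': {'pos': [15, 20]}})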
8 | """ 9 | 10 | def __init__(self, sentence_encoder, num_class, rel2id): 11 | """ 12 | Args: 13 | sentence_encoder: encoder for sentences 14 | num_class: number of classes 15 | id2rel: dictionary of id -> relation name mapping 16 | """ 17 | super().__init__() 18 | self.sentence_encoder = sentence_encoder 19 | self.num_class = num_class 20 | self.fc = nn.Linear(self.sentence_encoder.hidden_size, num_class) 21 | self.softmax = nn.Softmax(-1) 22 | self.rel2id = rel2id 23 | self.id2rel = {} 24 | self.drop = nn.Dropout() 25 | for rel, id in rel2id.items(): 26 | self.id2rel[id] = rel 27 | 28 | def infer(self, item): 29 | self.eval() 30 | _item = self.sentence_encoder.tokenize(item) 31 | item = [] 32 | for x in _item: 33 | item.append(x.to(next(self.parameters()).device)) 34 | logits = self.forward(*item) 35 | logits = self.softmax(logits) 36 | score, pred = logits.max(-1) 37 | score = score.item() 38 | pred = pred.item() 39 | return self.id2rel[pred], score 40 | 41 | def forward(self, *args): 42 | """ 43 | Args: 44 | args: depends on the encoder 45 | Return: 46 | logits, (B, N) 47 | """ 48 | rep = self.sentence_encoder(*args) # (B, H) 49 | rep = self.drop(rep) 50 | logits = self.fc(rep) # (B, N) 51 | return logits 52 | 53 | def logit_to_score(self, logits): 54 | return torch.softmax(logits, -1) 55 | -------------------------------------------------------------------------------- /opennre/module/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | -------------------------------------------------------------------------------- /opennre/module/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .cnn import CNN 6 | from .rnn import RNN 7 | from .lstm import LSTM 8 | 9 | __all__ = [ 10 | 'CNN', 11 | 'RNN', 12 | 'LSTM', 13 | ] -------------------------------------------------------------------------------- /opennre/module/nn/cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class CNN(nn.Module): 6 | 7 | def __init__(self, input_size=50, hidden_size=256, dropout=0, kernel_size=3, padding=1, activation_function=F.relu): 8 | """ 9 | Args: 10 | input_size: dimention of input embedding 11 | kernel_size: kernel_size for CNN 12 | padding: padding for CNN 13 | hidden_size: hidden size 14 | """ 15 | super().__init__() 16 | self.conv = nn.Conv1d(input_size, hidden_size, kernel_size, padding=padding) 17 | self.act = activation_function 18 | self.dropout = nn.Dropout(dropout) 19 | 20 | def forward(self, x): 21 | """ 22 | Args: 23 | input features: (B, L, I_EMBED) 24 | Return: 25 | output features: (B, H_EMBED) 26 | """ 27 | # Check size of tensors 28 | x = x.transpose(1, 2) # (B, I_EMBED, L) 29 | x = self.conv(x) # (B, H_EMBED, L) 30 | x = self.act(x) # (B, H_EMBED, L) 31 | x = self.dropout(x) # (B, H_EMBED, L) 32 | x = x.transpose(1, 2) # (B, L, H_EMBED) 33 | return x 34 | -------------------------------------------------------------------------------- /opennre/module/nn/lstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class LSTM(nn.Module): 5 | 6 | def 
__init__(self, input_size=50, hidden_size=256, dropout=0, bidirectional=False, num_layers=1, activation_function="tanh"): 7 | """ 8 | Args: 9 | input_size: dimension of input embedding 10 | hidden_size: hidden size 11 | dropout: dropout layer on the outputs of each LSTM layer except the last layer 12 | bidirectional: if it is a bidirectional LSTM 13 | num_layers: number of recurrent layers 14 | activation_function: kept for interface compatibility; nn.LSTM has no nonlinearity option 15 | """ 16 | super().__init__() 17 | if bidirectional: 18 | hidden_size //= 2 19 | self.lstm = nn.LSTM(input_size, 20 | hidden_size, 21 | num_layers, 22 | # nn.LSTM does not accept a 'nonlinearity' argument 23 | dropout=dropout, 24 | bidirectional=bidirectional) 25 | 26 | def forward(self, x): 27 | """ 28 | Args: 29 | input features: (B, L, I_EMBED) 30 | Return: 31 | output features: (B, L, H_EMBED) 32 | """ 33 | # Check size of tensors 34 | x = x.transpose(0, 1) # (L, B, I_EMBED) 35 | x, (h, c) = self.lstm(x) # (L, B, H_EMBED) 36 | x = x.transpose(0, 1) # (B, L, H_EMBED) 37 | return x 38 | -------------------------------------------------------------------------------- /opennre/module/nn/rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class RNN(nn.Module): 5 | 6 | def __init__(self, input_size=50, hidden_size=256, dropout=0, bidirectional=False, num_layers=1, activation_function="tanh"): 7 | """ 8 | Args: 9 | input_size: dimension of input embedding 10 | hidden_size: hidden size 11 | dropout: dropout layer on the outputs of each RNN layer except the last layer 12 | bidirectional: if it is a bidirectional RNN 13 | num_layers: number of recurrent layers 14 | activation_function: the activation function of RNN, tanh/relu 15 | """ 16 | super().__init__() 17 | if bidirectional: 18 | hidden_size //= 2 19 | self.rnn = nn.RNN(input_size, 20 | hidden_size, 21 | num_layers, 22 | nonlinearity=activation_function, 23 | dropout=dropout, 24 | bidirectional=bidirectional) 25 | 26 | def forward(self, x): 27 | """ 28 | Args: 29 | input features: (B, L, I_EMBED) 30 | Return: 31 | output features: (B, L, H_EMBED) 32 | """ 33 | # Check size of tensors 34 | x = x.transpose(0, 1) # (L, B, I_EMBED) 35 | x, h = self.rnn(x) # (L, B, H_EMBED) 36 | x = x.transpose(0, 1) # (B, L, H_EMBED) 37 | return x 38 | -------------------------------------------------------------------------------- /opennre/module/pool/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .max_pool import MaxPool 6 | from .avg_pool import AvgPool 7 | 8 | __all__ = [ 9 | 'MaxPool', 10 | 'AvgPool' 11 | ] -------------------------------------------------------------------------------- /opennre/module/pool/avg_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | class AvgPool(nn.Module): 7 | 8 | def __init__(self, kernel_size, segment_num=None): 9 | """ 10 | Args: 11 | kernel_size: kernel size for the 1-d average pooling (typically the sequence length) 12 | segment_num: number of segments for piecewise pooling; 13 | None or 1 means plain average pooling over 14 | the whole sequence 15 | """ 16 | super().__init__() 17 | self.segment_num = segment_num 18 | if self.segment_num is not None: 19 | self.mask_embedding = nn.Embedding(segment_num + 1, segment_num) 20 | 
self.mask_embedding.weight.data.copy_(torch.FloatTensor(np.concatenate([np.zeros(segment_num), np.identity(segment_num)], axis = 0))) 21 | self.mask_embedding.weight.requires_grad = False 22 | self.pool = nn.AvgPool1d(kernel_size) 23 | 24 | def forward(self, x, mask=None): 25 | """ 26 | Args: 27 | input features: (B, L, I_EMBED) 28 | Return: 29 | output features: (B, H_EMBED) 30 | """ 31 | # Check size of tensors 32 | if mask == None or self.segment_num == None or self.segment_num == 1: 33 | x = x.transpose(1, 2) # (B, L, I_EMBED) -> (B, I_EMBED, L) 34 | x = self.pool(x).squeeze(-1) # (B, I_EMBED, 1) -> (B, I_EMBED) 35 | return x 36 | else: 37 | B, L, I_EMBED = x.size()[:2] 38 | mask = self.mask_embedding(mask).transpose(1, 2).unsqueeze(2) # (B, L) -> (B, L, S) -> (B, S, L) -> (B, S, 1, L) 39 | x = x.transpose(1, 2).unsqueeze(1) # (B, L, I_EMBED) -> (B, I_EMBED, L) -> (B, 1, I_EMBED, L) 40 | x = (x * mask).view([-1, I_EMBED, L]) # (B, S, I_EMBED, L) -> (B * S, I_EMBED, L) 41 | x = self.pool(x).squeeze(-1) # (B * S, I_EMBED, 1) -> (B * S, I_EMBED) 42 | x = x.view([B, -1]) - self._minus # (B, S * I_EMBED) 43 | return x -------------------------------------------------------------------------------- /opennre/module/pool/max_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | class MaxPool(nn.Module): 7 | 8 | def __init__(self, kernel_size, segment_num=None): 9 | """ 10 | Args: 11 | input_size: dimention of input embedding 12 | kernel_size: kernel_size for CNN 13 | padding: padding for CNN 14 | hidden_size: hidden size 15 | """ 16 | super().__init__() 17 | self.segment_num = segment_num 18 | if self.segment_num != None: 19 | self.mask_embedding = nn.Embedding(segment_num + 1, segment_num) 20 | self.mask_embedding.weight.data.copy_(torch.FloatTensor(np.concatenate([np.zeros((1, segment_num)), np.identity(segment_num)], axis=0))) 21 | self.mask_embedding.weight.requires_grad = False 22 | self._minus = -100 23 | self.pool = nn.MaxPool1d(kernel_size) 24 | 25 | def forward(self, x, mask=None): 26 | """ 27 | Args: 28 | input features: (B, L, I_EMBED) 29 | Return: 30 | output features: (B, H_EMBED) 31 | """ 32 | # Check size of tensors 33 | if mask is None or self.segment_num is None or self.segment_num == 1: 34 | x = x.transpose(1, 2) # (B, L, I_EMBED) -> (B, I_EMBED, L) 35 | x = self.pool(x).squeeze(-1) # (B, I_EMBED, 1) -> (B, I_EMBED) 36 | return x 37 | else: 38 | B, L, I_EMBED = x.size()[:3] 39 | # mask = 1 - self.mask_embedding(mask).transpose(1, 2).unsqueeze(2) # (B, L) -> (B, L, S) -> (B, S, L) -> (B, S, 1, L) 40 | # x = x.transpose(1, 2).unsqueeze(1) # (B, L, I_EMBED) -> (B, I_EMBED, L) -> (B, 1, I_EMBED, L) 41 | # x = (x + self._minus * mask).contiguous().view([-1, I_EMBED, L]) # (B, S, I_EMBED, L) -> (B * S, I_EMBED, L) 42 | # x = self.pool(x).squeeze(-1) # (B * S, I_EMBED, 1) -> (B * S, I_EMBED) 43 | # x = x.view([B, -1]) # (B, S * I_EMBED) 44 | # return x 45 | mask = 1 - self.mask_embedding(mask).transpose(1, 2) 46 | x = x.transpose(1, 2) 47 | pool1 = self.pool(x + self._minus * mask[:, 0:1, :]) 48 | pool2 = self.pool(x + self._minus * mask[:, 1:2, :]) 49 | pool3 = self.pool(x + self._minus * mask[:, 2:3, :]) 50 | 51 | x = torch.cat([pool1, pool2, pool3], 1) 52 | # x = x.squeeze(-1) 53 | return x 54 | -------------------------------------------------------------------------------- /opennre/pretrain.py: 
-------------------------------------------------------------------------------- 1 | from . import encoder 2 | from . import model 3 | from . import framework 4 | import torch 5 | import os 6 | import sys 7 | import json 8 | import numpy as np 9 | import logging 10 | 11 | root_url = "https://thunlp.oss-cn-qingdao.aliyuncs.com/" 12 | default_root_path = os.path.join(os.getenv('HOME'), '.opennre') 13 | 14 | def check_root(root_path=default_root_path): 15 | if not os.path.exists(root_path): 16 | os.mkdir(root_path) 17 | os.mkdir(os.path.join(root_path, 'benchmark')) 18 | os.mkdir(os.path.join(root_path, 'pretrain')) 19 | os.mkdir(os.path.join(root_path, 'pretrain/nre')) 20 | 21 | def download_wiki80(root_path=default_root_path): 22 | check_root() 23 | if not os.path.exists(os.path.join(root_path, 'benchmark/wiki80')): 24 | os.mkdir(os.path.join(root_path, 'benchmark/wiki80')) 25 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki80') + ' ' + root_url + 'opennre/benchmark/wiki80/wiki80_rel2id.json') 26 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki80') + ' ' + root_url + 'opennre/benchmark/wiki80/wiki80_train.txt') 27 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki80') + ' ' + root_url + 'opennre/benchmark/wiki80/wiki80_val.txt') 28 | 29 | def download_tacred(root_path=default_root_path): 30 | check_root() 31 | if not os.path.exists(os.path.join(root_path, 'benchmark/tacred')): 32 | os.mkdir(os.path.join(root_path, 'benchmark/tacred')) 33 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/tacred') + ' ' + root_url + 'opennre/benchmark/tacred/tacred_rel2id.json') 34 | logging.info('Due to copyright limits, we only provide rel2id for TACRED. Please download TACRED manually and convert the data to OpenNRE format if needed.') 35 | 36 | def download_nyt10(root_path=default_root_path): 37 | check_root() 38 | if not os.path.exists(os.path.join(root_path, 'benchmark/nyt10')): 39 | os.mkdir(os.path.join(root_path, 'benchmark/nyt10')) 40 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/nyt10') + ' ' + root_url + 'opennre/benchmark/nyt10/nyt10_rel2id.json') 41 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/nyt10') + ' ' + root_url + 'opennre/benchmark/nyt10/nyt10_train.txt') 42 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/nyt10') + ' ' + root_url + 'opennre/benchmark/nyt10/nyt10_test.txt') 43 | 44 | def download_nyt10m(root_path=default_root_path): 45 | check_root() 46 | if not os.path.exists(os.path.join(root_path, 'benchmark/nyt10m')): 47 | os.mkdir(os.path.join(root_path, 'benchmark/nyt10m')) 48 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/nyt10m') + ' ' + root_url + 'opennre/benchmark/nyt10m/nyt10m_rel2id.json') 49 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/nyt10m') + ' ' + root_url + 'opennre/benchmark/nyt10m/nyt10m_train.txt') 50 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/nyt10m') + ' ' + root_url + 'opennre/benchmark/nyt10m/nyt10m_test.txt') 51 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/nyt10m') + ' ' + root_url + 'opennre/benchmark/nyt10m/nyt10m_val.txt') 52 | 53 | def download_wiki20m(root_path=default_root_path): 54 | check_root() 55 | if not os.path.exists(os.path.join(root_path, 'benchmark/wiki20m')): 56 | os.mkdir(os.path.join(root_path, 'benchmark/wiki20m')) 57 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki20m') + ' ' + root_url + 'opennre/benchmark/wiki20m/wiki20m_rel2id.json') 58 | os.system('wget -P ' + 
os.path.join(root_path, 'benchmark/wiki20m') + ' ' + root_url + 'opennre/benchmark/wiki20m/wiki20m_train.txt') 59 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki20m') + ' ' + root_url + 'opennre/benchmark/wiki20m/wiki20m_test.txt') 60 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki20m') + ' ' + root_url + 'opennre/benchmark/wiki20m/wiki20m_val.txt') 61 | 62 | def download_wiki_distant(root_path=default_root_path): 63 | check_root() 64 | if not os.path.exists(os.path.join(root_path, 'benchmark/wiki_distant')): 65 | os.mkdir(os.path.join(root_path, 'benchmark/wiki_distant')) 66 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki_distant') + ' ' + root_url + 'opennre/benchmark/wiki_distant/wiki_distant_rel2id.json') 67 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki_distant') + ' ' + root_url + 'opennre/benchmark/wiki_distant/wiki_distant_train.txt') 68 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki_distant') + ' ' + root_url + 'opennre/benchmark/wiki_distant/wiki_distant_test.txt') 69 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki_distant') + ' ' + root_url + 'opennre/benchmark/wiki_distant/wiki_distant_val.txt') 70 | 71 | def download_semeval(root_path=default_root_path): 72 | check_root() 73 | if not os.path.exists(os.path.join(root_path, 'benchmark/semeval')): 74 | os.mkdir(os.path.join(root_path, 'benchmark/semeval')) 75 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/semeval') + ' ' + root_url + 'opennre/benchmark/semeval/semeval_rel2id.json') 76 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/semeval') + ' ' + root_url + 'opennre/benchmark/semeval/semeval_train.txt') 77 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/semeval') + ' ' + root_url + 'opennre/benchmark/semeval/semeval_test.txt') 78 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/semeval') + ' ' + root_url + 'opennre/benchmark/semeval/semeval_val.txt') 79 | 80 | def download_glove(root_path=default_root_path): 81 | check_root() 82 | if not os.path.exists(os.path.join(root_path, 'pretrain/glove')): 83 | os.mkdir(os.path.join(root_path, 'pretrain/glove')) 84 | os.system('wget -P ' + os.path.join(root_path, 'pretrain/glove') + ' ' + root_url + 'opennre/pretrain/glove/glove.6B.50d_mat.npy') 85 | os.system('wget -P ' + os.path.join(root_path, 'pretrain/glove') + ' ' + root_url + 'opennre/pretrain/glove/glove.6B.50d_word2id.json') 86 | 87 | def download_bert_base_uncased(root_path=default_root_path): 88 | check_root() 89 | if not os.path.exists(os.path.join(root_path, 'pretrain/bert-base-uncased')): 90 | os.mkdir(os.path.join(root_path, 'pretrain/bert-base-uncased')) 91 | os.system('wget -P ' + os.path.join(root_path, 'pretrain/bert-base-uncased') + ' ' + root_url + 'opennre/pretrain/bert-base-uncased/config.json') 92 | os.system('wget -P ' + os.path.join(root_path, 'pretrain/bert-base-uncased') + ' ' + root_url + 'opennre/pretrain/bert-base-uncased/pytorch_model.bin') 93 | os.system('wget -P ' + os.path.join(root_path, 'pretrain/bert-base-uncased') + ' ' + root_url + 'opennre/pretrain/bert-base-uncased/vocab.txt') 94 | 95 | def download_pretrain(model_name, root_path=default_root_path): 96 | ckpt = os.path.join(root_path, 'pretrain/nre/' + model_name + '.pth.tar') 97 | if not os.path.exists(ckpt): 98 | os.system('wget -P ' + os.path.join(root_path, 'pretrain/nre') + ' ' + root_url + 'opennre/pretrain/nre/' + model_name + '.pth.tar') 99 | 100 | def download(name, 
root_path=default_root_path): 101 | if not os.path.exists(os.path.join(root_path, 'benchmark')): 102 | os.mkdir(os.path.join(root_path, 'benchmark')) 103 | if not os.path.exists(os.path.join(root_path, 'pretrain')): 104 | os.mkdir(os.path.join(root_path, 'pretrain')) 105 | if name == 'nyt10': 106 | download_nyt10(root_path=root_path) 107 | elif name == 'nyt10m': 108 | download_nyt10m(root_path=root_path) 109 | elif name == 'wiki20m': 110 | download_wiki20m(root_path=root_path) 111 | elif name == 'wiki_distant': 112 | download_wiki_distant(root_path=root_path) 113 | elif name == 'semeval': 114 | download_semeval(root_path=root_path) 115 | elif name == 'wiki80': 116 | download_wiki80(root_path=root_path) 117 | elif name == 'tacred': 118 | download_tacred(root_path=root_path) 119 | elif name == 'glove': 120 | download_glove(root_path=root_path) 121 | elif name == 'bert_base_uncased': 122 | download_bert_base_uncased(root_path=root_path) 123 | else: 124 | raise Exception('Cannot find corresponding data.') 125 | 126 | def get_model(model_name, root_path=default_root_path): 127 | check_root() 128 | ckpt = os.path.join(root_path, 'pretrain/nre/' + model_name + '.pth.tar') 129 | if model_name == 'wiki80_cnn_softmax': 130 | download_pretrain(model_name, root_path=root_path) 131 | download('glove', root_path=root_path) 132 | download('wiki80', root_path=root_path) 133 | wordi2d = json.load(open(os.path.join(root_path, 'pretrain/glove/glove.6B.50d_word2id.json'))) 134 | word2vec = np.load(os.path.join(root_path, 'pretrain/glove/glove.6B.50d_mat.npy')) 135 | rel2id = json.load(open(os.path.join(root_path, 'benchmark/wiki80/wiki80_rel2id.json'))) 136 | sentence_encoder = encoder.CNNEncoder(token2id=wordi2d, 137 | max_length=40, 138 | word_size=50, 139 | position_size=5, 140 | hidden_size=230, 141 | blank_padding=True, 142 | kernel_size=3, 143 | padding_size=1, 144 | word2vec=word2vec, 145 | dropout=0.5) 146 | m = model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id) 147 | m.load_state_dict(torch.load(ckpt, map_location='cpu')['state_dict']) 148 | return m 149 | elif model_name in ['wiki80_bert_softmax', 'wiki80_bertentity_softmax']: 150 | download_pretrain(model_name, root_path=root_path) 151 | download('bert_base_uncased', root_path=root_path) 152 | download('wiki80', root_path=root_path) 153 | rel2id = json.load(open(os.path.join(root_path, 'benchmark/wiki80/wiki80_rel2id.json'))) 154 | if 'entity' in model_name: 155 | sentence_encoder = encoder.BERTEntityEncoder( 156 | max_length=80, pretrain_path=os.path.join(root_path, 'pretrain/bert-base-uncased')) 157 | else: 158 | sentence_encoder = encoder.BERTEncoder( 159 | max_length=80, pretrain_path=os.path.join(root_path, 'pretrain/bert-base-uncased')) 160 | m = model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id) 161 | m.load_state_dict(torch.load(ckpt, map_location='cpu')['state_dict']) 162 | return m 163 | elif model_name in ['tacred_bert_softmax', 'tacred_bertentity_softmax']: 164 | download_pretrain(model_name, root_path=root_path) 165 | download('bert_base_uncased', root_path=root_path) 166 | download('tacred', root_path=root_path) 167 | rel2id = json.load(open(os.path.join(root_path, 'benchmark/tacred/tacred_rel2id.json'))) 168 | if 'entity' in model_name: 169 | sentence_encoder = encoder.BERTEntityEncoder( 170 | max_length=80, pretrain_path=os.path.join(root_path, 'pretrain/bert-base-uncased')) 171 | else: 172 | sentence_encoder = encoder.BERTEncoder( 173 | max_length=80, pretrain_path=os.path.join(root_path, 'pretrain/bert-base-uncased')) 174 
| m = model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id) 175 | m.load_state_dict(torch.load(ckpt, map_location='cpu')['state_dict']) 176 | return m 177 | else: 178 | raise NotImplementedError 179 | -------------------------------------------------------------------------------- /opennre/tokenization/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | from .basic_tokenizer import BasicTokenizer 8 | from .word_piece_tokenizer import WordpieceTokenizer 9 | from .word_tokenizer import WordTokenizer 10 | from .bert_tokenizer import BertTokenizer 11 | 12 | __all__ = [ 13 | 'BasicTokenizer', 14 | 'WordpieceTokenizer', 15 | 'WordTokenizer', 16 | 'BertTokenizer', 17 | 'is_whitespace', 18 | 'is_control', 19 | 'is_punctuation', 20 | 'is_chinese_char', 21 | 'convert_to_unicode', 22 | 'clean_text', 23 | 'split_on_whitespace', 24 | 'split_on_punctuation', 25 | 'tokenize_chinese_chars', 26 | 'strip_accents', 27 | 'printable_text', 28 | 'convert_by_vocab', 29 | 'convert_tokens_to_ids', 30 | 'convert_ids_to_tokens' 31 | ] 32 | 33 | 34 | -------------------------------------------------------------------------------- /opennre/tokenization/basic_tokenizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """BasicTokenizer classes.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from .utils import (convert_to_unicode, 23 | clean_text, 24 | split_on_whitespace, 25 | split_on_punctuation, 26 | tokenize_chinese_chars, 27 | strip_accents) 28 | 29 | class BasicTokenizer(object): 30 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 31 | 32 | def __init__(self, 33 | do_lower_case=True, 34 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): 35 | """Constructs a BasicTokenizer. 36 | Args: 37 | do_lower_case: Whether to lower case the input. 38 | """ 39 | self.do_lower_case = do_lower_case 40 | self.never_split = never_split 41 | 42 | def tokenize(self, text): 43 | """Tokenizes a piece of text.""" 44 | text = convert_to_unicode(text) 45 | text = clean_text(text) 46 | text = tokenize_chinese_chars(text) 47 | # This was added on November 1st, 2018 for the multilingual and Chinese 48 | # models. This is also applied to the English models now, but it doesn't 49 | # matter since the English models were not trained on any Chinese data 50 | # and generally don't have any Chinese data in them (there are Chinese 51 | # characters in the vocabulary because Wikipedia does have some Chinese 52 | # words in the English Wikipedia.). 
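        # The method returns (split_tokens, current_positions): current_positions[i]
        # holds the [start, end) span of sub-tokens in split_tokens that the i-th
        # whitespace-separated token was split into. Roughly, tokenize("Don't stop")
        # gives (['don', "'", 't', 'stop'], [[0, 3], [3, 4]]), assuming default lower casing.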
53 | orig_tokens = split_on_whitespace(text) 54 | split_tokens = [] 55 | current_positions = [] 56 | for token in orig_tokens: 57 | if self.do_lower_case and token not in self.never_split: 58 | token = token.lower() 59 | token = strip_accents(token) 60 | current_positions.append([]) 61 | current_positions[-1].append(len(split_tokens)) 62 | split_tokens.extend(split_on_punctuation(token)) 63 | current_positions[-1].append(len(split_tokens)) 64 | return split_tokens, current_positions 65 | -------------------------------------------------------------------------------- /opennre/tokenization/bert_tokenizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """BertTokenizer classes.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | import unicodedata 24 | import six 25 | 26 | from .basic_tokenizer import BasicTokenizer 27 | from .word_piece_tokenizer import WordpieceTokenizer 28 | from .utils import (convert_to_unicode, 29 | load_vocab, 30 | convert_by_vocab, 31 | convert_tokens_to_ids, 32 | convert_ids_to_tokens) 33 | 34 | class BertTokenizer(object): 35 | 36 | def __init__(self, 37 | vocab = None, 38 | do_lower_case = True, 39 | do_basic_tokenize = True, 40 | never_split=["[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"]): 41 | 42 | self.vocab = load_vocab(vocab) 43 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 44 | self.basic_tokenizer = BasicTokenizer(do_lower_case = do_lower_case, never_split = never_split) 45 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab = self.vocab) 46 | self.do_basic_tokenize = do_basic_tokenize 47 | 48 | def tokenize(self, text): 49 | split_tokens = [] 50 | if self.do_basic_tokenize: 51 | tokens, _ = self.basic_tokenizer.tokenize(text) 52 | text = " ".join(tokens) 53 | split_tokens, current_positions = self.wordpiece_tokenizer.tokenize(text) 54 | return split_tokens, current_positions 55 | 56 | def convert_tokens_to_ids(self, tokens): 57 | return convert_by_vocab(self.vocab, tokens) 58 | 59 | def convert_ids_to_tokens(self, ids): 60 | return convert_by_vocab(self.inv_vocab, ids) 61 | -------------------------------------------------------------------------------- /opennre/tokenization/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import collections 6 | import unicodedata 7 | import six 8 | 9 | def is_whitespace(char): 10 | """ Checks whether `chars` is a whitespace character. 11 | \t, \n, and \r are technically contorl characters but we treat them 12 | as whitespace since they are generally considered as such. 
13 | """ 14 | if char == " " or char == "\t" or char == "\n" or char == "\r": 15 | return True 16 | cat = unicodedata.category(char) 17 | if cat == "Zs": 18 | return True 19 | return False 20 | 21 | def is_control(char): 22 | """ Checks whether `chars` is a control character. 23 | These are technically control characters but we count them as whitespace characters. 24 | """ 25 | if char == "\t" or char == "\n" or char == "\r": 26 | return False 27 | cat = unicodedata.category(char) 28 | if cat.startswith("C"): 29 | return True 30 | return False 31 | 32 | def is_punctuation(char): 33 | """ Checks whether `chars` is a punctuation character. 34 | We treat all non-letter/number ASCII as punctuation. Characters such as "^", "$", and "`" are not in the Unicode. 35 | Punctuation class but we treat them as punctuation anyways, for consistency. 36 | """ 37 | cp = ord(char) 38 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 39 | return True 40 | cat = unicodedata.category(char) 41 | if cat.startswith("P"): 42 | return True 43 | return False 44 | 45 | def is_chinese_char(cp): 46 | """ Checks whether CP is the codepoint of a CJK character. 47 | This defines a "chinese character" as anything in the CJK Unicode block: 48 | https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 49 | Note that the CJK Unicode block is NOT all Japanese and Korean characters, 50 | despite its name. The modern Korean Hangul alphabet is a different block, 51 | as is Japanese Hiragana and Katakana. Those alphabets are used to write 52 | space-separated words, so they are not treated specially and handled 53 | like the all of the other languages. 54 | """ 55 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or 56 | (cp >= 0x3400 and cp <= 0x4DBF) or 57 | (cp >= 0x20000 and cp <= 0x2A6DF) or 58 | (cp >= 0x2A700 and cp <= 0x2B73F) or 59 | (cp >= 0x2B740 and cp <= 0x2B81F) or 60 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 61 | (cp >= 0xF900 and cp <= 0xFAFF) or 62 | (cp >= 0x2F800 and cp <= 0x2FA1F)): 63 | return True 64 | return False 65 | 66 | def convert_to_unicode(text): 67 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 68 | if six.PY3: 69 | if isinstance(text, str): 70 | return text 71 | elif isinstance(text, bytes): 72 | return text.decode("utf-8", "ignore") 73 | else: 74 | raise ValueError("Unsupported string type: %s" % (type(text))) 75 | elif six.PY2: 76 | if isinstance(text, str): 77 | return text.decode("utf-8", "ignore") 78 | elif isinstance(text, unicode): 79 | return text 80 | else: 81 | raise ValueError("Unsupported string type: %s" % (type(text))) 82 | else: 83 | raise ValueError("Not running on Python2 or Python 3?") 84 | 85 | def clean_text(text): 86 | output = [] 87 | for char in text: 88 | cp = ord(char) 89 | if cp == 0 or cp == 0xfffd or is_control(char): 90 | continue 91 | if is_whitespace(char): 92 | output.append(" ") 93 | else: 94 | output.append(char) 95 | return "".join(output) 96 | 97 | def split_on_whitespace(text): 98 | """ Runs basic whitespace cleaning and splitting on a peice of text. 
99 | e.g, 'a b c' -> ['a', 'b', 'c'] 100 | """ 101 | text = text.strip() 102 | if not text: 103 | return [] 104 | return text.split() 105 | 106 | def split_on_punctuation(text): 107 | """Splits punctuation on a piece of text.""" 108 | start_new_word = True 109 | output = [] 110 | for char in text: 111 | if is_punctuation(char): 112 | output.append([char]) 113 | start_new_word = True 114 | else: 115 | if start_new_word: 116 | output.append([]) 117 | start_new_word = False 118 | output[-1].append(char) 119 | return ["".join(x) for x in output] 120 | 121 | def tokenize_chinese_chars(text): 122 | """Adds whitespace around any CJK character.""" 123 | output = [] 124 | for char in text: 125 | cp = ord(char) 126 | if is_chinese_char(cp): 127 | output.append(" ") 128 | output.append(char) 129 | output.append(" ") 130 | else: 131 | output.append(char) 132 | return "".join(output) 133 | 134 | def strip_accents(text): 135 | """Strips accents from a piece of text.""" 136 | text = unicodedata.normalize("NFD", text) 137 | output = [] 138 | for char in text: 139 | cat = unicodedata.category(char) 140 | if cat == "Mn": 141 | continue 142 | output.append(char) 143 | return "".join(output) 144 | 145 | def load_vocab(vocab_file): 146 | """Loads a vocabulary file into a dictionary.""" 147 | if vocab_file == None: 148 | raise Exception("Vocab file not provided") 149 | if isinstance(vocab_file, str) or isinstance(vocab_file, bytes): 150 | vocab = collections.OrderedDict() 151 | index = 0 152 | with open(vocab_file, "r", encoding="utf-8") as reader: 153 | while True: 154 | token = reader.readline() 155 | if not token: 156 | break 157 | token = token.strip() 158 | vocab[token] = index 159 | index += 1 160 | return vocab 161 | else: 162 | return vocab_file 163 | 164 | def printable_text(text): 165 | """ Returns text encoded in a way suitable for print or `tf.logging`. 166 | These functions want `str` for both Python2 and Python3, but in one case 167 | it's a Unicode string and in the other it's a byte string. 
168 | """ 169 | if six.PY3: 170 | if isinstance(text, str): 171 | return text 172 | elif isinstance(text, bytes): 173 | return text.decode("utf-8", "ignore") 174 | else: 175 | raise ValueError("Unsupported string type: %s" % (type(text))) 176 | elif six.PY2: 177 | if isinstance(text, str): 178 | return text 179 | elif isinstance(text, unicode): 180 | return text.encode("utf-8") 181 | else: 182 | raise ValueError("Unsupported string type: %s" % (type(text))) 183 | else: 184 | raise ValueError("Not running on Python2 or Python 3?") 185 | 186 | def convert_by_vocab(vocab, items, max_seq_length = None, blank_id = 0, unk_id = 1, uncased = True): 187 | """Converts a sequence of [tokens|ids] using the vocab.""" 188 | output = [] 189 | for item in items: 190 | if uncased: 191 | item = item.lower() 192 | if item in vocab: 193 | output.append(vocab[item]) 194 | else: 195 | output.append(unk_id) 196 | if max_seq_length != None: 197 | if len(output) > max_seq_length: 198 | output = output[:max_seq_length] 199 | else: 200 | while len(output) < max_seq_length: 201 | output.append(blank_id) 202 | return output 203 | 204 | def convert_tokens_to_ids(vocab, tokens, max_seq_length = None, blank_id = 0, unk_id = 1): 205 | return convert_by_vocab(vocab, tokens, max_seq_length, blank_id, unk_id) 206 | 207 | def convert_ids_to_tokens(inv_vocab, ids): 208 | return convert_by_vocab(inv_vocab, ids, uncased=False) 209 | 210 | def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng): 211 | """Truncates a pair of sequences to a maximum sequence length.""" 212 | while True: 213 | total_length = len(tokens_a) + len(tokens_b) 214 | if total_length <= max_num_tokens: 215 | break 216 | trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b 217 | assert len(trunc_tokens) >= 1 218 | # We want to sometimes truncate from the front and sometimes from the 219 | # back to add more randomness and avoid biases. 220 | if rng.random() < 0.5: 221 | del trunc_tokens[0] 222 | else: 223 | trunc_tokens.pop() 224 | 225 | def create_int_feature(values): 226 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 227 | return feature 228 | 229 | def create_float_feature(values): 230 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) 231 | return feature 232 | 233 | def add_token(tokens_a, tokens_b = None): 234 | assert len(tokens_a) >= 1 235 | 236 | tokens = [] 237 | segment_ids = [] 238 | 239 | tokens.append("[CLS]") 240 | segment_ids.append(0) 241 | 242 | for token in tokens_a: 243 | tokens.append(token) 244 | segment_ids.append(0) 245 | 246 | tokens.append("[SEP]") 247 | segment_ids.append(0) 248 | 249 | if tokens_b != None: 250 | assert len(tokens_b) >= 1 251 | 252 | for token in tokens_b: 253 | tokens.append(token) 254 | segment_ids.append(1) 255 | 256 | tokens.append("[SEP]") 257 | segment_ids.append(1) 258 | 259 | return tokens, segment_ids 260 | -------------------------------------------------------------------------------- /opennre/tokenization/word_piece_tokenizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """WordpieceTokenizer classes.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import unicodedata 23 | 24 | from .utils import (load_vocab, 25 | convert_to_unicode, 26 | clean_text, 27 | split_on_whitespace, 28 | convert_by_vocab, 29 | tokenize_chinese_chars) 30 | 31 | class WordpieceTokenizer(object): 32 | """Runs WordPiece tokenziation.""" 33 | 34 | def __init__(self, vocab = None, unk_token="[UNK]", max_input_chars_per_word=200): 35 | self.vocab = load_vocab(vocab) 36 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 37 | self.unk_token = unk_token 38 | self.max_input_chars_per_word = max_input_chars_per_word 39 | 40 | def tokenize(self, text): 41 | """ Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform tokenization 42 | using the given vocabulary. 43 | 44 | For example: 45 | input = "unaffable" 46 | output = ["un", "##aff", "##able"] 47 | 48 | Args: 49 | text: A single token or whitespace separated tokens. This should have already been passed through `BasicTokenizer`. 50 | Returns: 51 | output_tokens: A list of wordpiece tokens. 52 | current_positions: A list of the current positions for the original words in text . 53 | """ 54 | text = convert_to_unicode(text) 55 | text = clean_text(text) 56 | text = tokenize_chinese_chars(text) 57 | output_tokens = [] 58 | current_positions = [] 59 | token_list = split_on_whitespace(text) 60 | for chars in token_list: 61 | if len(chars) > self.max_input_chars_per_word: 62 | output_tokens.append(self.unk_token) 63 | continue 64 | is_bad = False 65 | start = 0 66 | sub_tokens = [] 67 | while start < len(chars): 68 | end = len(chars) 69 | if start > 0: 70 | substr = "##" + chars[start:end] 71 | else: 72 | substr = chars[start:end] 73 | cur_substr = None 74 | while start < end: 75 | if substr in self.vocab: 76 | cur_substr = substr 77 | break 78 | end -= 1 79 | substr = substr[:-1] 80 | if cur_substr is None: 81 | is_bad = True 82 | break 83 | else: 84 | sub_tokens.append(cur_substr) 85 | start = end 86 | current_positions.append([]) 87 | if is_bad: 88 | current_positions[-1].append(len(output_tokens)) 89 | output_tokens.append(self.unk_token) 90 | current_positions[-1].append(len(output_tokens)) 91 | else: 92 | current_positions[-1].append(len(output_tokens)) 93 | output_tokens.extend(sub_tokens) 94 | current_positions[-1].append(len(output_tokens)) 95 | 96 | return output_tokens, current_positions 97 | 98 | def convert_tokens_to_ids(self, tokens): 99 | return convert_by_vocab(self.vocab, tokens) 100 | 101 | def convert_ids_to_tokens(self, ids): 102 | return convert_by_vocab(self.inv_vocab, ids) 103 | -------------------------------------------------------------------------------- /opennre/tokenization/word_tokenizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """WordTokenizer classes.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import unicodedata 23 | 24 | from .utils import (load_vocab, 25 | convert_to_unicode, 26 | clean_text, 27 | split_on_whitespace, 28 | convert_by_vocab, 29 | tokenize_chinese_chars) 30 | 31 | class WordTokenizer(object): 32 | """Runs whitespace word tokenization.""" 33 | 34 | def __init__(self, vocab = None, unk_token="[UNK]"): 35 | self.vocab = load_vocab(vocab) 36 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 37 | self.unk_token = unk_token 38 | 39 | def tokenize(self, text): 40 | """ Tokenizes a piece of text into whitespace-separated word tokens. The text is first 41 | converted to Unicode, cleaned, and CJK characters are padded with spaces before splitting. 42 | 43 | For example: 44 | input = "He was the son" 45 | output = ["He", "was", "the", "son"] 46 | 47 | Args: 48 | text: A piece of raw text. 49 | Returns: 50 | token_list: A list of word tokens split on whitespace. No vocabulary lookup is done here; 51 | unknown words are only mapped to ids later in `convert_tokens_to_ids`.
52 | """ 53 | text = convert_to_unicode(text) 54 | text = clean_text(text) 55 | text = tokenize_chinese_chars(text) 56 | # output_tokens = [] 57 | token_list = split_on_whitespace(text) 58 | # for chars in token_list: 59 | # # current_positions.append([]) 60 | # if chars in self.vocab: 61 | # output_tokens.append(chars) 62 | # else: 63 | # output_tokens.append(self.unk_token) 64 | return token_list 65 | 66 | def convert_tokens_to_ids(self, tokens, max_seq_length = None, blank_id = 0, unk_id = 1, uncased = True): 67 | return convert_by_vocab(self.vocab, tokens, max_seq_length, blank_id, unk_id, uncased=uncased) 68 | 69 | def convert_ids_to_tokens(self, ids): 70 | return convert_by_vocab(self.inv_vocab, ids) 71 | -------------------------------------------------------------------------------- /pretrain/download_bert.sh: -------------------------------------------------------------------------------- 1 | mkdir bert-base-uncased 2 | wget -P bert-base-uncased https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/pretrain/bert-base-uncased/config.json 3 | wget -P bert-base-uncased https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/pretrain/bert-base-uncased/pytorch_model.bin 4 | wget -P bert-base-uncased https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/pretrain/bert-base-uncased/vocab.txt 5 | -------------------------------------------------------------------------------- /pretrain/download_glove.sh: -------------------------------------------------------------------------------- 1 | mkdir glove 2 | wget -P glove https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/pretrain/glove/glove.6B.50d_mat.npy 3 | wget -P glove https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/pretrain/glove/glove.6B.50d_word2id.json 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.6.0 2 | transformers==3.4.0 3 | pytest==5.3.2 4 | scikit-learn==0.22.1 5 | scipy==1.4.1 6 | nltk>=3.6.4 7 | 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | with open("README.md", "r") as fh: 3 | setuptools.setup( 4 | name='open-nre', 5 | version='0.1.1', 6 | author="Tianyu Gao", 7 | author_email="gaotianyu1350@126.com", 8 | description="An open source toolkit for relation extraction", 9 | url="https://github.com/thunlp/opennre", 10 | packages=setuptools.find_packages(), 11 | classifiers=[ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: MIT License", 14 | "Operating System :: POSIX :: Linux", 15 | ], 16 | install_requires=['nltk', 'transformers', 'scikit-learn'], 17 | setup_requires=['wheel'] 18 | ) 19 | -------------------------------------------------------------------------------- /tests/test_inference.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import opennre 3 | 4 | class TestInference(unittest.TestCase): 5 | 6 | def test_wiki80_cnn_softmax(self): 7 | model = opennre.get_model('wiki80_cnn_softmax') 8 | result = model.infer({'text': 'He was the son of Máel Dúin mac Máele Fithrich, and grandson of the high king Áed Uaridnach (died 612).', 'h': {'pos': (18, 46)}, 't': {'pos': (78, 91)}}) 9 | print(result) 10 | self.assertEqual(result[0], 'father') 11 | self.assertTrue(abs(result[1] - 0.7500484585762024) < 1e-6) 12 | 13 | if __name__ == '__main__': 14 | 
unittest.main() 15 | --------------------------------------------------------------------------------
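Usage note (editor's sketch): the pretrained-model interface defined in opennre/pretrain.py and exercised in tests/test_inference.py can be driven end to end roughly as below. This is a minimal example assuming the package is installed; the first call to get_model may download the wiki80 benchmark, the GloVe vectors, and the released checkpoint into the configured root path. The example sentence and expected relation are taken from the test above.

import opennre

# Downloads (if missing) and loads the released CNN model trained on Wiki80.
model = opennre.get_model('wiki80_cnn_softmax')

# Head ('h') and tail ('t') entities are given as character spans into 'text'.
result = model.infer({
    'text': 'He was the son of Máel Dúin mac Máele Fithrich, and grandson of the high king Áed Uaridnach (died 612).',
    'h': {'pos': (18, 46)},
    't': {'pos': (78, 91)},
})

print(result)  # expected roughly ('father', 0.75): the relation label and its confidence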