├── .gitignore
├── LICENSE
├── README.md
├── benchmark
│   ├── download_fewrel.sh
│   ├── download_nyt10.sh
│   ├── download_nyt10m.sh
│   ├── download_semeval.sh
│   ├── download_wiki20m.sh
│   ├── download_wiki80.sh
│   └── rel4wiki80.md
├── example
│   ├── test_multilabel_bert.py
│   ├── test_multilabel_cnn.py
│   ├── train_bag_bert.py
│   ├── train_bag_cnn.py
│   ├── train_supervised_bert.py
│   └── train_supervised_cnn.py
├── opennre
│   ├── __init__.py
│   ├── encoder
│   │   ├── __init__.py
│   │   ├── base_encoder.py
│   │   ├── bert_encoder.py
│   │   ├── cnn_encoder.py
│   │   └── pcnn_encoder.py
│   ├── framework
│   │   ├── __init__.py
│   │   ├── bag_re.py
│   │   ├── data_loader.py
│   │   ├── multi_label_sentence_re.py
│   │   ├── sentence_re.py
│   │   └── utils.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── bag_attention.py
│   │   ├── bag_average.py
│   │   ├── bag_one.py
│   │   ├── base_model.py
│   │   ├── sigmoid_nn.py
│   │   └── softmax_nn.py
│   ├── module
│   │   ├── __init__.py
│   │   ├── nn
│   │   │   ├── __init__.py
│   │   │   ├── cnn.py
│   │   │   ├── lstm.py
│   │   │   └── rnn.py
│   │   └── pool
│   │       ├── __init__.py
│   │       ├── avg_pool.py
│   │       └── max_pool.py
│   ├── pretrain.py
│   └── tokenization
│       ├── __init__.py
│       ├── basic_tokenizer.py
│       ├── bert_tokenizer.py
│       ├── utils.py
│       ├── word_piece_tokenizer.py
│       └── word_tokenizer.py
├── pretrain
│   ├── download_bert.sh
│   └── download_glove.sh
├── requirements.txt
├── setup.py
└── tests
    └── test_inference.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | # test
107 | test.py
108 |
109 | # ckpt
110 | ckpt
111 |
112 | # vscode
113 | .vscode
114 |
115 | # tacred
116 | benchmark/tacred
117 | *.swp
118 |
119 |
120 | # data and pretrain
121 | pretrain
122 | benchmark
123 | !benchmark/*.sh
124 | !pretrain/*.sh
125 |
126 | # test env
127 | .test
128 |
129 | # package
130 | opennre.egg-info
131 |
132 | *.sh
133 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Tianyu Gao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OpenNRE (sub-project of OpenSKL)
2 |
3 | OpenNRE is a sub-project of OpenSKL, providing an **Open**-source **N**eural **R**elation **E**xtraction toolkit for extracting structured knowledge from plain text, with [ATT](https://aclanthology.org/P16-1200.pdf) as a key feature to consider relation-associated text information.
4 |
5 | ## Overview
6 |
7 | OpenNRE is an open-source and extensible toolkit that provides a unified framework to implement relation extraction models. We unify the input and output interfaces of different relation extraction models and provide scalable options for each model. The toolkit covers both supervised and distant supervised settings, and is compatible with both conventional neural networks and pre-trained language models.
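
As a concrete illustration of this unified interface, the sketch below composes a CNN sentence encoder with a softmax classifier, mirroring the example scripts shipped in this repository. The GloVe and `rel2id` paths are assumptions that presume the corresponding data have already been downloaded (see the Installation and Data sections).

```python
import json
import numpy as np
import opennre

# Assumes GloVe and the wiki80 benchmark have been downloaded, so these files exist locally.
word2id = json.load(open('pretrain/glove/glove.6B.50d_word2id.json'))
word2vec = np.load('pretrain/glove/glove.6B.50d_mat.npy')
rel2id = json.load(open('benchmark/wiki80/wiki80_rel2id.json'))

# Sentence encoder: maps a sentence plus entity positions to a fixed-size vector.
sentence_encoder = opennre.encoder.CNNEncoder(
    token2id=word2id, max_length=128, word_size=50, position_size=5,
    hidden_size=230, blank_padding=True, kernel_size=3, padding_size=1,
    word2vec=word2vec, dropout=0.5)

# Classification model: a softmax layer over the relation vocabulary.
model = opennre.model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id)
```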
8 |
9 | Relation extraction is a natural language processing (NLP) task aiming at extracting relations (e.g., *founder of*) between entities (e.g., **Bill Gates** and **Microsoft**). For example, from the sentence *Bill Gates founded Microsoft*, we can extract the relation triple (**Bill Gates**, *founder of*, **Microsoft**).
10 |
11 | Relation extraction is a crucial technique in automatic knowledge graph construction. By using relation extraction, we can accumulatively extract new relation facts and expand the knowledge graph, which, as a way for machines to understand the human world, has many downstream applications like question answering, recommender system and search engine. If you want to learn more about neural relation extraction, visit another project of ours ([NREPapers](https://github.com/thunlp/NREPapers)).
12 |
13 | It's our honor to help you better explore relation extraction with our OpenNRE toolkit! You can refer to our [document](https://opennre-docs.readthedocs.io/en/latest/) for more details about this project.
14 |
15 |
24 |
25 | ## Models
26 | In this toolkit, we support CNN-based relation extraction models, including the standard CNN and our proposed [CNN+ATT](https://aclanthology.org/P16-1200v2.pdf). We also implement methods based on pre-trained language models (BERT).
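
In code, CNN+ATT is simply a bag-level attention head (`BagAttention`) stacked on a CNN or PCNN encoder. The sketch below assumes `sentence_encoder` and `rel2id` were built as in `example/train_bag_cnn.py`:

```python
# Minimal sketch: CNN+ATT = (P)CNN encoder + instance-level attention over the bag.
# `sentence_encoder` and `rel2id` are assumed to come from example/train_bag_cnn.py.
model = opennre.model.BagAttention(sentence_encoder, len(rel2id), rel2id)

# Other bag aggregators implemented in opennre/model/:
#   opennre.model.BagAverage(sentence_encoder, len(rel2id), rel2id)  # average the bag
#   opennre.model.BagOne(sentence_encoder, len(rel2id), rel2id)      # select one instance
```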
27 |
28 | ## Evaluation
29 |
30 | To validate the effectiveness of this toolkit, we employ the Bag-Level Relation Extraction task for evaluation.
31 |
32 | ### Settings
33 |
34 | We utilize the NYT10 dataset, a distantly supervised collection derived from the New York Times corpus and Freebase. We mainly experiment with the CNN-ATT model, which employs instance-level attention and shows superior performance compared with the vanilla CNN.
35 |
36 | ### Results
37 |
38 | We report AUC and F1 scores of the two models. The right two columns marked with (\*) indicate results sourced from [Gao et al. (2021)](https://aclanthology.org/2021.findings-acl.112.pdf) and [Lin et al. (2016)](https://aclanthology.org/P16-1200v2.pdf). The results show that our implementation of the CNN-ATT model is slightly better than the original paper, and also confirm that CNN-ATT outperforms the standard CNN model.
39 |
40 | | Model | AUC | F1 | AUC(Paper \*) | F1(Paper \*) |
41 | | :-: | :-: | :-: | :-: | :-: |
42 | | CNN | - | - | 0.212 | 0.318 |
43 | | CNN-ATT | 0.333 | 0.397 | 0.318 | 0.380 |
44 |
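A comparable CNN-ATT model can be trained on NYT10 with the bundled example script. The command below is only a sketch of such a run; the exact hyper-parameters behind the reported numbers are not pinned down here.

```bash
python example/train_bag_cnn.py \
    --dataset nyt10 \
    --encoder cnn \
    --aggr att \
    --metric auc
```
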
45 | ## Usage
46 |
47 | ### Installation
48 |
49 | #### Install as A Python Package
50 |
51 | We are now working on deploying OpenNRE as a Python package. Coming soon!
52 |
53 | #### Using Git Repository
54 |
55 | Clone the repository from our GitHub page (don't forget to star us!):
56 |
57 | ```bash
58 | git clone https://github.com/thunlp/OpenNRE.git
59 | ```
60 |
61 | If it is too slow, you can try a shallow clone:
62 | ```bash
63 | git clone https://github.com/thunlp/OpenNRE.git --depth 1
64 | ```
65 |
66 | Then install all the requirements:
67 |
68 | ```bash
69 | pip install -r requirements.txt
70 | ```
71 |
72 | **Note**: Please choose an appropriate PyTorch version based on your machine (related to your CUDA version). For details, refer to https://pytorch.org/.
73 |
74 | Then install the package with
75 | ```bash
76 | python setup.py install
77 | ```
78 |
79 | If you also want to modify the code, run this:
80 | ```bash
81 | python setup.py develop
82 | ```
83 |
84 | Note that we have excluded all data and pretrain files for fast deployment. You can manually download them by running the scripts in the ``benchmark`` and ``pretrain`` folders. For example, if you want to download the FewRel dataset, you can run
85 |
86 | ```bash
87 | bash benchmark/download_fewrel.sh
88 | ```
89 | ### Data
90 |
91 | You can go into the `benchmark` folder and download datasets using our scripts. We also list some of the information about the datasets in [this document](https://opennre-docs.readthedocs.io/en/latest/get_started/benchmark.html#bag-level-relation-extraction). We provide two distantly-supervised datasets with human-annotated test sets, **NYT10m** and **Wiki20m**. Check the [datasets](#datasets) section for details.
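
For example, the following commands (a sketch; run from the repository root) fetch the NYT10m benchmark and the GloVe vectors used by the CNN-based examples:

```bash
cd benchmark && bash download_nyt10m.sh && cd ..
cd pretrain && bash download_glove.sh && cd ..
```

Note that the example training scripts can also fetch these automatically via `opennre.download` when `--dataset` is given.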
92 |
93 | ### Easy Start
94 |
95 | Make sure you have installed OpenNRE as instructed above. Then import our package and load pre-trained models.
96 |
97 | ```python
98 | >>> import opennre
99 | >>> model = opennre.get_model('wiki80_cnn_softmax')
100 | ```
101 |
102 | Note that it may take a few minutes to download the checkpoint and data the first time. Then use `infer` to do sentence-level relation extraction:
103 |
104 | ```python
105 | >>> model.infer({'text': 'He was the son of Máel Dúin mac Máele Fithrich, and grandson of the high king Áed Uaridnach (died 612).', 'h': {'pos': (18, 46)}, 't': {'pos': (78, 91)}})
106 | ('father', 0.5108704566955566)
107 | ```
108 |
109 | You will get the relation result and its confidence score.
110 |
111 | If you want to use the model on your GPU, just run
112 | ```python
113 | >>> model = model.cuda()
114 | ```
115 | before calling the inference function.
116 |
117 | For now, we have the following available models (a short usage example follows the list):
118 |
119 | * `wiki80_cnn_softmax`: trained on `wiki80` dataset with a CNN encoder.
120 | * `wiki80_bert_softmax`: trained on `wiki80` dataset with a BERT encoder.
121 | * `wiki80_bertentity_softmax`: trained on `wiki80` dataset with a BERT encoder (using entity representation concatenation).
122 | * `tacred_bert_softmax`: trained on `TACRED` dataset with a BERT encoder.
123 | * `tacred_bertentity_softmax`: trained on `TACRED` dataset with a BERT encoder (using entity representation concatenation).
124 |
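Each of these can be loaded by name with `opennre.get_model` and used exactly like the CNN model above. In the sketch below, the input sentence and its entity spans (character offsets) are hypothetical, chosen only for illustration:

```python
>>> import opennre
>>> model = opennre.get_model('wiki80_bert_softmax')   # any name from the list above
>>> model.infer({'text': 'Bill Gates founded Microsoft.',
...              'h': {'pos': (0, 10)},    # character span of "Bill Gates"
...              't': {'pos': (19, 28)}})  # character span of "Microsoft"
```

As before, `infer` returns a pair of the predicted relation name and its confidence score.
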
125 | ### Training
126 |
127 | You can train your own models on your own data with OpenNRE. In the `example` folder we provide example training code for supervised RE models and bag-level RE models. You can either use our provided datasets or your own datasets. For example, you can use the following script to train a PCNN-ATT bag-level model on the NYT10m dataset (NYT10 with a manually annotated test set). The ATT algorithm is a typical method for combining a bag of sentences to extract relations between entities.
128 |
129 | ```bash
130 | python example/train_bag_cnn.py \
131 | --metric auc \
132 | --dataset nyt10m \
133 | --batch_size 160 \
134 | --lr 0.1 \
135 | --weight_decay 1e-5 \
136 | --max_epoch 100 \
137 | --max_length 128 \
138 | --seed 42 \
139 | --encoder pcnn \
140 | --aggr att
141 | ```
142 |
143 | Or use the following script to train a BERT model on the Wiki80 dataset:
144 | ```bash
145 | python example/train_supervised_bert.py \
146 | --pretrain_path bert-base-uncased \
147 | --dataset wiki80
148 | ```
149 |
150 | We provide many options in the example training code and you can check them out for detailed instructions.
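
All example scripts use `argparse`, so you can print the full list of options for any of them with `--help`, e.g.:

```bash
python example/train_bag_cnn.py --help
python example/train_supervised_bert.py --help
```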
151 |
152 | ## Citation
153 |
154 | If you find OpenNRE useful for your research, please consider citing the following paper:
155 |
156 | ```
157 | @inproceedings{han-etal-2019-opennre,
158 | title = "{O}pen{NRE}: An Open and Extensible Toolkit for Neural Relation Extraction",
159 | author = "Han, Xu and Gao, Tianyu and Yao, Yuan and Ye, Deming and Liu, Zhiyuan and Sun, Maosong",
160 | booktitle = "Proceedings of EMNLP-IJCNLP: System Demonstrations",
161 | year = "2019",
162 | url = "https://www.aclweb.org/anthology/D19-3029",
163 | doi = "10.18653/v1/D19-3029",
164 | pages = "169--174"
165 | }
166 | ```
167 |
168 | This package is mainly contributed by [Tianyu Gao](https://github.com/gaotianyu1350), [Xu Han](https://github.com/THUCSTHanxu13), [Shulin Cao](https://github.com/ShulinCao), [Lumin Tang](https://github.com/Tsingularity), [Yankai Lin](https://github.com/Mrlyk423) and [Zhiyuan Liu](http://nlp.csai.tsinghua.edu.cn/~lzy/).
169 |
170 |
171 | ******************
172 | ## About OpenSKL
173 | The OpenSKL project aims to harness the power of both structured knowledge and natural languages via representation learning. All sub-projects of OpenSKL, under the categories of **Algorithm**, **Resource** and **Application**, are listed as follows.
174 |
175 | - **Algorithm**:
176 | - [OpenKE](https://www.github.com/thunlp/OpenKE)
177 | - An effective and efficient toolkit for representing structured knowledge in large-scale knowledge graphs as embeddings, with TransR and PTransE as key features to handle complex relations and relational paths.
178 | - This toolkit also includes three repositories:
179 | - [KB2E](https://www.github.com/thunlp/KB2E)
180 | - [TensorFlow-Transx](https://www.github.com/thunlp/TensorFlow-Transx)
181 | - [Fast-TransX](https://www.github.com/thunlp/Fast-TransX)
182 | - [ERNIE](https://github.com/thunlp/ERNIE)
183 | - An effective and efficient toolkit for augmenting pre-trained language models with knowledge graph representations.
184 | - [OpenNE](https://www.github.com/thunlp/OpenNE)
185 | - An effective and efficient toolkit for representing nodes in large-scale graphs as embeddings, with [TADW](https://www.ijcai.org/Proceedings/15/Papers/299.pdf) as key features to incorporate text attributes of nodes.
186 | - [OpenNRE](https://www.github.com/thunlp/OpenNRE)
187 | - An effective and efficient toolkit for implementing neural networks for extracting structured knowledge from text, with [ATT](https://aclanthology.org/P16-1200.pdf) as key features to consider relation-associated text information.
188 | - This toolkit also includes two repositories:
189 | - [JointNRE](https://www.github.com/thunlp/JointNRE)
190 | - [NRE](https://github.com/thunlp/NRE)
191 | - **Resource**:
192 | - The embeddings of large-scale knowledge graphs pre-trained by OpenKE, covering three typical large-scale knowledge graphs: Wikidata, Freebase, and XLORE. The embeddings are free to use under the [MIT license](https://opensource.org/license/mit/), and please click the following link to submit [download requests](http://139.129.163.161/download/wikidata).
193 | - OpenKE-Wikidata
194 | - Wikidata is a free and collaborative database, collecting structured data to provide support for Wikipedia. The original Wikidata contains 20,982,733 entities, 594 relations and 68,904,773 triplets. In particular, Wikidata-5M is the core subgraph of Wikidata, containing 5,040,986 high-frequency entities from Wikidata with their corresponding 927 relations and 24,267,796 triplets.
195 | - [TransE version](http://139.129.163.161/download/wikidata): Knowledge embeddings of Wikidata pre-trained by OpenKE.
196 | - [TransR version](http://139.129.163.161/download/wikidata) of Wikidata-5M: Knowledge embeddings of Wikidata-5M pre-trained by OpenKE.
197 | - OpenKE-Freebase
198 | - Freebase was a large collaborative knowledge base consisting of data composed mainly by its community members. It was an online collection of structured data harvested from many sources. Freebase contains 86,054,151 entities, 14,824 relations and 338,586,276 triplets.
199 | - [TransE version](http://139.129.163.161/download/wikidata): Knowledge embeddings of Freebase pre-trained by OpenKE.
200 | - OpenKE-XLORE
201 | - XLORE is one of the most popular Chinese knowledge graphs developed by THUKEG. XLORE contains 10,572,209 entities, 138,581 relations and 35,954,249 triplets.
202 | - [TransE version](http://139.129.163.161/download/wikidata): Knowledge embeddings of XLORE pre-trained by OpenKE.
203 | - **Application**:
204 | - [Knowledge-Plugin](https://github.com/THUNLP/Knowledge-Plugin)
205 | - An effective and efficient toolkit of plug-and-play knowledge injection for pre-trained language models. Knowledge-Plugin is general for all kinds of knowledge graph embeddings mentioned above. In the toolkit, we plug the TransR version of Wikidata-5M into BERT as an example of applications. With the TransR embedding, we enhance the knowledge ability of BERT without fine-tuning the original model, e.g., up to 8% improvement on question answering.
206 |
--------------------------------------------------------------------------------
/benchmark/download_fewrel.sh:
--------------------------------------------------------------------------------
1 | mkdir fewrel
2 | wget -P fewrel https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/fewrel/fewrel_train.txt
3 | wget -P fewrel https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/fewrel/fewrel_train_rel2id.json
4 | wget -P fewrel https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/fewrel/fewrel_val.txt
5 | wget -P fewrel https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/fewrel/fewrel_val_rel2id.json
6 |
--------------------------------------------------------------------------------
/benchmark/download_nyt10.sh:
--------------------------------------------------------------------------------
1 | mkdir nyt10
2 | wget -P nyt10 https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/nyt10/nyt10_rel2id.json
3 | wget -P nyt10 https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/nyt10/nyt10_train.txt
4 | wget -P nyt10 https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/nyt10/nyt10_test.txt
5 |
--------------------------------------------------------------------------------
/benchmark/download_nyt10m.sh:
--------------------------------------------------------------------------------
1 | mkdir nyt10m
2 | wget -P nyt10m https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/nyt10m/nyt10m_rel2id.json
3 | wget -P nyt10m https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/nyt10m/nyt10m_train.txt
4 | wget -P nyt10m https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/nyt10m/nyt10m_val.txt
5 | wget -P nyt10m https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/nyt10m/nyt10m_test.txt
6 |
--------------------------------------------------------------------------------
/benchmark/download_semeval.sh:
--------------------------------------------------------------------------------
1 | mkdir semeval
2 | wget -P semeval https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/semeval/semeval_rel2id.json
3 | wget -P semeval https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/semeval/semeval_train.txt
4 | wget -P semeval https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/semeval/semeval_val.txt
5 | wget -P semeval https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/semeval/semeval_test.txt
6 |
--------------------------------------------------------------------------------
/benchmark/download_wiki20m.sh:
--------------------------------------------------------------------------------
1 | mkdir wiki20m
2 | wget -P wiki20m https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/wiki20m/wiki20m_rel2id.json
3 | wget -P wiki20m https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/wiki20m/wiki20m_train.txt
4 | wget -P wiki20m https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/wiki20m/wiki20m_val.txt
5 | wget -P wiki20m https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/wiki20m/wiki20m_test.txt
6 |
--------------------------------------------------------------------------------
/benchmark/download_wiki80.sh:
--------------------------------------------------------------------------------
1 | mkdir wiki80
2 | wget -P wiki80 https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/wiki80/wiki80_rel2id.json
3 | wget -P wiki80 https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/wiki80/wiki80_train.txt
4 | wget -P wiki80 https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/benchmark/wiki80/wiki80_val.txt
5 |
--------------------------------------------------------------------------------
/benchmark/rel4wiki80.md:
--------------------------------------------------------------------------------
1 |
2 | | Wikidata ID | Name | Description|
3 | | ---- | ---- | ---- |
4 | |P6 | head of government | head of the executive power of this town, city, municipality, state, country, or other governmental body |
5 | |P17 | country | sovereign state of this item; don't use on humans |
6 | |P22 | father | male parent of the subject. For stepfather, use "stepparent" (P3448) |
7 | |P27 | country of citizenship | the object is a country that recognizes the subject as its citizen |
8 | |P31 | instance of | that class of which this subject is a particular example and member. (Subject typically an individual member with Proper Name label.) Different from P279 (subclass of) |
9 | |P39 | position held | subject currently or formerly holds the object position or public office |
10 | |P57 | director | director(s) of this motion picture, TV-series, stageplay, video game or similar |
11 | |P58 | screenwriter | author(s) of the screenplay or script for this work |
12 | |P84 | architect | person or architectural firm that designed this building |
13 | |P86 | composer | person(s) who wrote the music; also use P676 for lyricist |
14 | |P101 | field of work | specialization of a person or organization; see P106 for the occupation |
15 | |P102 | member of political party | the political party of which this politician is or has been a member |
16 | |P105 | taxon rank | level in a taxonomic hierarchy |
17 | |P106 | occupation | occupation of a person; see also "field of work" (Property:P101), "position held" (Property:P39) |
18 | |P118 | league | league in which team or player plays or has played in |
19 | |P123 | publisher | organization or person responsible for publishing books, periodicals, games or software |
20 | |P127 | owned by | owner of the subject |
21 | |P131 | located in the administrative territorial entity | the item is located on the territory of the following administrative entity. Use P276 (location) for specifying the location of non-administrative places and for items about events |
22 | |P135 | movement | literary, artistic, scientific or philosophical movement associated with this person or work |
23 | |P136 | genre | creative work's genre or an artist's field of work (P101). Use main subject (P921) to relate creative works to their topic |
24 | |P137 | operator | person or organization that operates the equipment, facility, or service; use country for diplomatic missions |
25 | |P140 | religion | religion of a person, organization or religious building, or associated with this subject |
26 | |P150 | contains administrative territorial entity | (list of) direct subdivisions of an administrative territorial entity |
27 | |P156 | followed by | immediately following item in some series of which the subject is part. Use P1366 (replaced by) if the item is replaced, e.g. political offices, states |
28 | |P159 | headquarters location | specific location where an organization's headquarters is or has been situated |
29 | |P175 | performer | performer involved in the performance or the recording of a work |
30 | |P176 | manufacturer | manufacturer or producer of this product |
31 | |P178 | developer | organisation or person that developed this item |
32 | |P241 | military branch | branch to which this military unit, award, office, or person belongs, e.g. Royal Navy |
33 | |P264 | record label | brand and trademark associated with the marketing of subject music recordings and music videos |
34 | |P276 | location | location of the item, physical object or event is within. In case of an administrative entity use P131. In case of a distinct terrain feature use P706. |
35 | |P306 | operating system | operating system (OS) on which a software works or the OS installed on hardware |
36 | |P355 | subsidiary | subsidiary of a company or organization, opposite of parent company (P749) |
37 | |P400 | platform | platform for which a work has been developed or released / specific platform version of a software developed |
38 | |P403 | mouth of the watercourse | the body of water to which the watercourse drains |
39 | |P407 | language of work or name | language associated with this work or name (for persons use P103 and P1412) |
40 | |P449 | original network | network(s) the radio or television show was originally aired on, including |
41 | |P460 | said to be the same as | this item is said to be the same as that item, but the statement is disputed |
42 | |P466 | occupant | person or organization occupying a building or facility |
43 | |P495 | country of origin | country of origin of the creative work or subject item |
44 | |P527 | has part | part of this subject. Inverse property of "part of" (P361). |
45 | |P551 | residence | the place where the person is, or has been, resident |
46 | |P674 | characters | characters which appear in this item (like plays, operas, operettas, books, comics, films, TV series, video games) |
47 | |P706 | located on terrain feature | located on the specified landform. Should not be used when the value is only political/administrative (provinces, states, countries, etc.). Use P131 for administrative entity. |
48 | |P710 | participant | person, group of people or organization (object) that actively takes/took part in the event (subject). Preferably qualify with "object has role" (P3831). Use P1923 for team participants. |
49 | |P740 | location of formation | location where a group or organization was formed |
50 | |P750 | distributor | distributor of a creative work |
51 | |P800 | notable work | notable scientific, artistic or literary work, or other work of significance among subject's works |
52 | |P931 | place served by transport hub | city or region served by this transport hub (airport, train station, etc.) |
53 | |P937 | work location | location where persons were active |
54 | |P974 | tributary | stream or river that flows into this main stem (or parent) river |
55 | |P991 | successful candidate | person(s) elected after the election |
56 | |P1001 | applies to jurisdiction | the item (an institution, law, public office ...) belongs to or has power over or applies to the value (a territorial jurisdiction: a country, state, municipality, ...) |
57 | |P1303 | instrument | musical instrument that a person plays |
58 | |P1344 | participant of | event a person or an organization was a participant in, inverse of P710 or P1923 |
59 | |P1346 | winner | winner of an event - do not use for wars or battles |
60 | |P1408 | licensed to broadcast to | place that a radio/TV station is licensed/required to broadcast to |
61 | |P1411 | nominated for | award nomination received by a person, organisation or creative work (inspired from "award received" (Property:P166)) |
62 | |P1435 | heritage designation | heritage designation of a historical site |
63 | |P1877 | after a work by | artist whose work strongly inspired/ was copied in this item |
64 | |P1923 | participating teams | Like 'Participant' (P710) but for teams. For an event like a cycle race or a football match you can use this property to list the teams and P710 to list the individuals (with 'member of sports team' (P54)' as a qualifier for the individuals) |
65 | |P3373 | sibling | the subject has the object as their sibling (brother, sister, etc.). Use "relative" (P1038) for siblings-in-law (brother-in-law, sister-in-law, etc.) and step-siblings (step-brothers, step-sisters, etc.) |
66 | |P3450 | sports season of | property that shows the competition of which the item is a season |
67 | |P4552 | mountain range | range or subrange to which the geographical item belongs |
68 | |P25 | mother | female parent of the subject. For stepmother, use "stepparent" (P3448) |
69 | |P26 | spouse | the subject has the object as their spouse (husband, wife, partner, etc.). Use "unmarried partner" (P451) for non-married companions |
70 | |P40 | child | subject has the object in their family as their offspring son or daughter (independently of their age) |
71 | |P59 | constellation | the area of the celestial sphere of which the subject is a part (from a scientific standpoint, not an astrological one) |
72 | |P155 | follows | immediately prior item in some series of which the subject is part. Use P1365 (replaces) if the preceding item was replaced, e.g. political offices, states and there is no identity between precedent and following geographic unit |
73 | |P177 | crosses | obstacle (body of water, road, ...) which this bridge crosses over or this tunnel goes under |
74 | |P206 | located in or next to body of water | sea, lake or river |
75 | |P361 | part of | object of which the subject is a part. Inverse property of "has part" (P527). See also "has parts of the class" (P2670). |
76 | |P364 | original language of work | language in which a film or a performance work was originally created. Deprecated for written works; use P407 ("language of work or name") instead. |
77 | |P410 | military rank | military rank achieved by a person (should usually have a "start date" qualifier), or military rank associated with a position |
78 | |P412 | voice type | person's voice type. expected values: soprano, mezzo-soprano, contralto, countertenor, tenor, baritone, bass (and derivatives) |
79 | |P413 | position played on team / speciality | position or specialism of a player on a team, e.g. Small Forward |
80 | |P463 | member of | organization or club to which the subject belongs. Do not use for membership in ethnic or social groups, nor for holding a position such as a member of parliament (use P39 for that). |
81 | |P641 | sport | sport in which the entity participates or belongs to |
82 | |P921 | main subject | primary topic of a work (see also P180: depicts) |
83 | |P2094 | competition class | official classification by a regulating body under which the subject (events, teams, participants, or equipment) qualifies for inclusion |
84 |
--------------------------------------------------------------------------------
/example/test_multilabel_bert.py:
--------------------------------------------------------------------------------
1 | # coding:utf-8
2 | import torch
3 | import numpy as np
4 | import json
5 | import opennre
6 | from opennre import encoder, model, framework
7 | import sys
8 | import os
9 | import argparse
10 | import logging
11 | import random
12 |
13 | def set_seed(seed):
14 | random.seed(seed)
15 | np.random.seed(seed)
16 | torch.manual_seed(seed)
17 | torch.cuda.manual_seed_all(seed)
18 |
19 |
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument('--pretrain_path', default='bert-base-uncased',
22 |                     help='Pre-trained ckpt path / model name (huggingface)')
23 | parser.add_argument('--ckpt', default='',
24 | help='Checkpoint name')
25 | parser.add_argument('--pooler', default='entity', choices=['cls', 'entity'],
26 | help='Sentence representation pooler')
27 | parser.add_argument('--only_test', action='store_true',
28 | help='Only run test')
29 | parser.add_argument('--mask_entity', action='store_true',
30 | help='Mask entity mentions')
31 |
32 | # Data
33 | parser.add_argument('--metric', default='micro_f1', choices=['micro_f1', 'acc'],
34 | help='Metric for picking up best checkpoint')
35 | parser.add_argument('--dataset', default='none', choices=['none', 'semeval', 'wiki80', 'tacred'],
36 | help='Dataset. If not none, the following args can be ignored')
37 | parser.add_argument('--train_file', default='', type=str,
38 | help='Training data file')
39 | parser.add_argument('--val_file', default='', type=str,
40 | help='Validation data file')
41 | parser.add_argument('--test_file', default='', type=str,
42 | help='Test data file')
43 | parser.add_argument('--rel2id_file', default='', type=str,
44 | help='Relation to ID file')
45 |
46 | # Hyper-parameters
47 | parser.add_argument('--batch_size', default=64, type=int,
48 | help='Batch size')
49 | parser.add_argument('--lr', default=2e-5, type=float,
50 | help='Learning rate')
51 | parser.add_argument('--max_length', default=128, type=int,
52 | help='Maximum sentence length')
53 | parser.add_argument('--max_epoch', default=3, type=int,
54 | help='Max number of training epochs')
55 |
56 | # Seed
57 | parser.add_argument('--seed', default=42, type=int,
58 | help='Seed')
59 |
60 | args = parser.parse_args()
61 |
62 | # Set random seed
63 | set_seed(args.seed)
64 |
65 | # Some basic settings
66 | root_path = '.'
67 | sys.path.append(root_path)
68 | if not os.path.exists('ckpt'):
69 | os.mkdir('ckpt')
70 | if len(args.ckpt) == 0:
71 | args.ckpt = '{}_{}_{}'.format(args.dataset, args.pretrain_path, args.pooler)
72 | ckpt = 'ckpt/{}.pth.tar'.format(args.ckpt)
73 |
74 | if args.dataset != 'none':
75 | opennre.download(args.dataset, root_path=root_path)
76 | args.train_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_train.txt'.format(args.dataset))
77 | args.val_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_val.txt'.format(args.dataset))
78 | args.test_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_test.txt'.format(args.dataset))
79 | if not os.path.exists(args.test_file):
80 |         logging.warning("Test file {} does not exist! Use val file instead".format(args.test_file))
81 | args.test_file = args.val_file
82 | args.rel2id_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_rel2id.json'.format(args.dataset))
83 | if args.dataset == 'wiki80':
84 | args.metric = 'acc'
85 | else:
86 | args.metric = 'micro_f1'
87 | else:
88 | if not (os.path.exists(args.train_file) and os.path.exists(args.val_file) and os.path.exists(args.test_file) and os.path.exists(args.rel2id_file)):
89 | raise Exception('--train_file, --val_file, --test_file and --rel2id_file are not specified or files do not exist. Or specify --dataset')
90 |
91 | logging.info('Arguments:')
92 | for arg in vars(args):
93 | logging.info(' {}: {}'.format(arg, getattr(args, arg)))
94 |
95 | rel2id = json.load(open(args.rel2id_file))
96 |
97 | # Define the sentence encoder
98 | if args.pooler == 'entity':
99 | sentence_encoder = opennre.encoder.BERTEntityEncoder(
100 | max_length=args.max_length,
101 | pretrain_path=args.pretrain_path,
102 | mask_entity=args.mask_entity
103 | )
104 | elif args.pooler == 'cls':
105 | sentence_encoder = opennre.encoder.BERTEncoder(
106 | max_length=args.max_length,
107 | pretrain_path=args.pretrain_path,
108 | mask_entity=args.mask_entity
109 | )
110 | else:
111 | raise NotImplementedError
112 |
113 | # Define the model
114 | model = opennre.model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id)
115 |
116 | # Define the whole training framework
117 | framework = opennre.framework.MultiLabelSentenceRE(
118 | train_path=args.train_file,
119 | val_path=args.val_file,
120 | test_path=args.test_file,
121 | model=model,
122 | ckpt=ckpt,
123 | batch_size=args.batch_size,
124 | max_epoch=args.max_epoch,
125 | lr=args.lr,
126 | opt='adamw'
127 | )
128 |
129 | # Train the model
130 | if not args.only_test:
131 |     framework.train_model(args.metric)
132 |
133 | # Test
134 | framework.load_state_dict(torch.load(ckpt)['state_dict'])
135 | result = framework.eval_model(framework.test_loader)
136 |
137 | # Print the result
138 | logging.info('Test set results:')
139 | logging.info('Accuracy: {}'.format(result['acc']))
140 | logging.info('Micro precision: {}'.format(result['micro_p']))
141 | logging.info('Micro recall: {}'.format(result['micro_r']))
142 | logging.info('Micro F1: {}'.format(result['micro_f1']))
143 | logging.info('Macro precision: {}'.format(result['macro_p']))
144 | logging.info('Macro recall: {}'.format(result['macro_r']))
145 | logging.info('Macro F1: {}'.format(result['macro_f1']))
146 |
--------------------------------------------------------------------------------
/example/test_multilabel_cnn.py:
--------------------------------------------------------------------------------
1 | # coding:utf-8
2 | import torch
3 | import numpy as np
4 | import json
5 | import opennre
6 | from opennre import encoder, model, framework
7 | import sys
8 | import os
9 | import argparse
10 | import logging
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--ckpt', default='',
14 | help='Checkpoint name')
15 | parser.add_argument('--only_test', action='store_true',
16 | help='Only run test')
17 |
18 | # Data
19 | parser.add_argument('--metric', default='micro_f1', choices=['micro_f1', 'acc'],
20 | help='Metric for picking up best checkpoint')
21 | parser.add_argument('--dataset', default='none', choices=['none', 'semeval', 'wiki80', 'tacred'],
22 | help='Dataset. If not none, the following args can be ignored')
23 | parser.add_argument('--train_file', default='', type=str,
24 | help='Training data file')
25 | parser.add_argument('--val_file', default='', type=str,
26 | help='Validation data file')
27 | parser.add_argument('--test_file', default='', type=str,
28 | help='Test data file')
29 | parser.add_argument('--rel2id_file', default='', type=str,
30 | help='Relation to ID file')
31 |
32 | # Hyper-parameters
33 | parser.add_argument('--batch_size', default=32, type=int,
34 | help='Batch size')
35 | parser.add_argument('--lr', default=1e-1, type=float,
36 | help='Learning rate')
37 | parser.add_argument('--weight_decay', default=1e-5, type=float,
38 | help='Weight decay')
39 | parser.add_argument('--max_length', default=128, type=int,
40 | help='Maximum sentence length')
41 | parser.add_argument('--max_epoch', default=100, type=int,
42 | help='Max number of training epochs')
43 |
44 | parser.add_argument('--encoder', default='cnn', choices=['pcnn', 'cnn'])
45 |
46 | args = parser.parse_args()
47 |
48 | # Some basic settings
49 | root_path = '.'
50 | sys.path.append(root_path)
51 | if not os.path.exists('ckpt'):
52 | os.mkdir('ckpt')
53 | if len(args.ckpt) == 0:
54 | args.ckpt = '{}_{}'.format(args.dataset, 'cnn')
55 | ckpt = 'ckpt/{}.pth.tar'.format(args.ckpt)
56 |
57 | if args.dataset != 'none':
58 | opennre.download(args.dataset, root_path=root_path)
59 | args.train_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_train.txt'.format(args.dataset))
60 | args.val_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_val.txt'.format(args.dataset))
61 | args.test_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_test.txt'.format(args.dataset))
62 | args.rel2id_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_rel2id.json'.format(args.dataset))
63 | if args.dataset == 'wiki80':
64 | args.metric = 'acc'
65 | else:
66 | args.metric = 'micro_f1'
67 | else:
68 | if not (os.path.exists(args.train_file) and os.path.exists(args.val_file) and os.path.exists(args.test_file) and os.path.exists(args.rel2id_file)):
69 | raise Exception('--train_file, --val_file, --test_file and --rel2id_file are not specified or files do not exist. Or specify --dataset')
70 |
71 | logging.info('Arguments:')
72 | for arg in vars(args):
73 | logging.info(' {}: {}'.format(arg, getattr(args, arg)))
74 |
75 | rel2id = json.load(open(args.rel2id_file))
76 |
77 | # Download glove
78 | opennre.download('glove', root_path=root_path)
79 | word2id = json.load(open(os.path.join(root_path, 'pretrain/glove/glove.6B.50d_word2id.json')))
80 | word2vec = np.load(os.path.join(root_path, 'pretrain/glove/glove.6B.50d_mat.npy'))
81 |
82 | # Define the sentence encoder
83 | if args.encoder == 'cnn':
84 | sentence_encoder = opennre.encoder.CNNEncoder(
85 | token2id=word2id,
86 | max_length=args.max_length,
87 | word_size=50,
88 | position_size=5,
89 | hidden_size=230,
90 | blank_padding=True,
91 | kernel_size=3,
92 | padding_size=1,
93 | word2vec=word2vec,
94 | dropout=0.5
95 | )
96 | elif args.encoder == 'pcnn':
97 | sentence_encoder = opennre.encoder.PCNNEncoder(
98 | token2id=word2id,
99 | max_length=args.max_length,
100 | word_size=50,
101 | position_size=5,
102 | hidden_size=230,
103 | blank_padding=True,
104 | kernel_size=3,
105 | padding_size=1,
106 | word2vec=word2vec,
107 | dropout=0.5
108 | )
109 | else:
110 | raise NotImplementedError
111 |
112 | # Define the model
113 | model = opennre.model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id)
114 |
115 | # Define the whole training framework
116 | framework = opennre.framework.MultiLabelSentenceRE(
117 | train_path=args.train_file,
118 | val_path=args.val_file,
119 | test_path=args.test_file,
120 | model=model,
121 | ckpt=ckpt,
122 | batch_size=args.batch_size,
123 | max_epoch=args.max_epoch,
124 | lr=args.lr,
125 | weight_decay=args.weight_decay,
126 | opt='sgd'
127 | )
128 |
129 | # Train the model
130 | if not args.only_test:
131 | framework.train_model(args.metric)
132 |
133 | # Test
134 | framework.load_state_dict(torch.load(ckpt)['state_dict'])
135 | result = framework.eval_model(framework.test_loader)
136 |
137 | # Print the result
138 | logging.info('Test set results:')
139 | logging.info('Accuracy: %.5f' % (result['acc']))
140 | logging.info('Micro precision: %.5f' % (result['micro_p']))
141 | logging.info('Micro recall: %.5f' % (result['micro_r']))
142 | logging.info('Micro F1: %.5f' % (result['micro_f1']))
143 | logging.info('Macro precision: %.5f' % (result['macro_p']))
144 | logging.info('Macro recall: %.5f' % (result['macro_r']))
145 | logging.info('Macro F1: %.5f' % (result['macro_f1']))
146 | logging.info("AUC: %.5f" % (result['auc']))
147 | logging.info('P@100: %.5f' % (result['p@100']))
148 | logging.info('P@200: %.5f' % (result['p@200']))
149 | logging.info('P@300: %.5f' % (result['p@300']))
150 | logging.info("Max micro F1: %.5f" % (result['max_micro_f1']))
151 | logging.info("| threshold: %.5f" % (result['max_micro_f1_threshold']))
152 |
--------------------------------------------------------------------------------
/example/train_bag_bert.py:
--------------------------------------------------------------------------------
1 | # coding:utf-8
2 | import sys, json
3 | import torch
4 | import os
5 | import numpy as np
6 | import opennre
7 | import argparse
8 | import logging
9 | import random
10 |
11 | def set_seed(seed):
12 | random.seed(seed)
13 | np.random.seed(seed)
14 | torch.manual_seed(seed)
15 | torch.cuda.manual_seed_all(seed)
16 |
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('--pretrain_path', default='bert-base-uncased',
19 |                     help='Pre-trained ckpt path / model name (huggingface)')
20 | parser.add_argument('--ckpt', default='',
21 | help='Checkpoint name')
22 | parser.add_argument('--result', default='',
23 | help='Result name')
24 | parser.add_argument('--pooler', default='entity', choices=['cls', 'entity'],
25 | help='Sentence representation pooler')
26 | parser.add_argument('--only_test', action='store_true',
27 | help='Only run test')
28 | parser.add_argument('--mask_entity', action='store_true',
29 | help='Mask entity mentions')
30 |
31 | # Data
32 | parser.add_argument('--metric', default='auc', choices=['micro_f1', 'auc'],
33 | help='Metric for picking up best checkpoint')
34 | parser.add_argument('--dataset', default='none', choices=['none', 'wiki_distant', 'nyt10', 'nyt10m', 'wiki20m'],
35 | help='Dataset. If not none, the following args can be ignored')
36 | parser.add_argument('--train_file', default='', type=str,
37 | help='Training data file')
38 | parser.add_argument('--val_file', default='', type=str,
39 | help='Validation data file')
40 | parser.add_argument('--test_file', default='', type=str,
41 | help='Test data file')
42 | parser.add_argument('--rel2id_file', default='', type=str,
43 | help='Relation to ID file')
44 |
45 | # Bag related
46 | parser.add_argument('--bag_size', type=int, default=4,
47 | help='Fixed bag size. If set to 0, use original bag sizes')
48 |
49 | # Hyper-parameters
50 | parser.add_argument('--batch_size', default=16, type=int,
51 | help='Batch size')
52 | parser.add_argument('--lr', default=2e-5, type=float,
53 | help='Learning rate')
54 | parser.add_argument('--max_length', default=128, type=int,
55 | help='Maximum sentence length')
56 | parser.add_argument('--max_epoch', default=3, type=int,
57 | help='Max number of training epochs')
58 |
59 | # Exp
60 | parser.add_argument('--aggr', default='att', choices=['one', 'att', 'avg'])
61 |
62 |
63 | # Seed
64 | parser.add_argument('--seed', default=42, type=int,
65 | help='Seed')
66 |
67 | args = parser.parse_args()
68 |
69 | # Set random seed
70 | set_seed(args.seed)
71 |
72 |
73 | # Some basic settings
74 | root_path = '.'
75 | sys.path.append(root_path)
76 | if not os.path.exists('ckpt'):
77 | os.mkdir('ckpt')
78 | ckpt = 'ckpt/{}.pth.tar'.format(args.ckpt)
79 |
80 | if args.dataset != 'none':
81 | opennre.download(args.dataset, root_path=root_path)
82 | args.train_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_train.txt'.format(args.dataset))
83 | args.val_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_val.txt'.format(args.dataset))
84 | if not os.path.exists(args.val_file):
85 | logging.info("Cannot find the validation file. Use the test file instead.")
86 | args.val_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_test.txt'.format(args.dataset))
87 | args.test_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_test.txt'.format(args.dataset))
88 | args.rel2id_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_rel2id.json'.format(args.dataset))
89 | else:
90 | if not (os.path.exists(args.train_file) and os.path.exists(args.val_file) and os.path.exists(args.test_file) and os.path.exists(args.rel2id_file)):
91 | raise Exception('--train_file, --val_file, --test_file and --rel2id_file are not specified or files do not exist. Or specify --dataset')
92 |
93 | logging.info('Arguments:')
94 | for arg in vars(args):
95 | logging.info(' {}: {}'.format(arg, getattr(args, arg)))
96 |
97 | rel2id = json.load(open(args.rel2id_file))
98 |
99 |
100 | # Define the sentence encoder
101 | if args.pooler == 'entity':
102 | sentence_encoder = opennre.encoder.BERTEntityEncoder(
103 | max_length=args.max_length,
104 | pretrain_path=args.pretrain_path,
105 | mask_entity=args.mask_entity
106 | )
107 | elif args.pooler == 'cls':
108 | sentence_encoder = opennre.encoder.BERTEncoder(
109 | max_length=args.max_length,
110 | pretrain_path=args.pretrain_path,
111 | mask_entity=args.mask_entity
112 | )
113 | else:
114 | raise NotImplementedError
115 |
116 |
117 | # Define the model
118 | if args.aggr == 'att':
119 | model = opennre.model.BagAttention(sentence_encoder, len(rel2id), rel2id)
120 | elif args.aggr == 'avg':
121 | model = opennre.model.BagAverage(sentence_encoder, len(rel2id), rel2id)
122 | elif args.aggr == 'one':
123 | model = opennre.model.BagOne(sentence_encoder, len(rel2id), rel2id)
124 | else:
125 | raise NotImplementedError
126 |
127 | # Define the whole training framework
128 | framework = opennre.framework.BagRE(
129 | train_path=args.train_file,
130 | val_path=args.val_file,
131 | test_path=args.test_file,
132 | model=model,
133 | ckpt=ckpt,
134 | batch_size=args.batch_size,
135 | max_epoch=args.max_epoch,
136 | lr=args.lr,
137 | opt="adamw",
138 | bag_size=args.bag_size
139 | )
140 |
141 | # Train the model
142 | if not args.only_test:
143 | framework.train_model(args.metric)
144 |
145 | # Test the model
146 | framework.load_state_dict(torch.load(ckpt)['state_dict'])
147 | result = framework.eval_model(framework.test_loader)
148 |
149 | # Print the result
150 | logging.info('Test set results:')
151 | logging.info('AUC: %.5f' % (result['auc']))
152 | logging.info('Maximum micro F1: %.5f' % (result['max_micro_f1']))
153 | logging.info('Maximum macro F1: %.5f' % (result['max_macro_f1']))
154 | logging.info('Micro F1: %.5f' % (result['micro_f1']))
155 | logging.info('Macro F1: %.5f' % (result['macro_f1']))
156 | logging.info('P@100: %.5f' % (result['p@100']))
157 | logging.info('P@200: %.5f' % (result['p@200']))
158 | logging.info('P@300: %.5f' % (result['p@300']))
159 |
160 |
161 | # Save precision/recall points
162 | np.save('result/{}_p.npy'.format(args.result), result['np_prec'])
163 | np.save('result/{}_r.npy'.format(args.result), result['np_rec'])
164 | json.dump(result['max_micro_f1_each_relation'], open('result/{}_mmicrof1_rel.json'.format(args.result), 'w'), ensure_ascii=False)
165 |
--------------------------------------------------------------------------------
/example/train_bag_cnn.py:
--------------------------------------------------------------------------------
1 | # coding:utf-8
2 | import sys, json
3 | import torch
4 | import os
5 | import numpy as np
6 | import opennre
7 | import argparse
8 | import logging
9 | import random
10 |
11 | def set_seed(seed):
12 | random.seed(seed)
13 | np.random.seed(seed)
14 | torch.manual_seed(seed)
15 | torch.cuda.manual_seed_all(seed)
16 |
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('--ckpt', default='',
19 | help='Checkpoint name')
20 | parser.add_argument('--result', default='',
21 | help='Save result name')
22 | parser.add_argument('--only_test', action='store_true',
23 | help='Only run test')
24 |
25 | # Data
26 | parser.add_argument('--metric', default='auc', choices=['micro_f1', 'auc'],
27 | help='Metric for picking up best checkpoint')
28 | parser.add_argument('--dataset', default='none', choices=['none', 'wiki_distant', 'nyt10', 'nyt10m', 'wiki20m'],
29 | help='Dataset. If not none, the following args can be ignored')
30 | parser.add_argument('--train_file', default='', type=str,
31 | help='Training data file')
32 | parser.add_argument('--val_file', default='', type=str,
33 | help='Validation data file')
34 | parser.add_argument('--test_file', default='', type=str,
35 | help='Test data file')
36 | parser.add_argument('--rel2id_file', default='', type=str,
37 | help='Relation to ID file')
38 |
39 | # Bag related
40 | parser.add_argument('--bag_size', type=int, default=0,
41 | help='Fixed bag size. If set to 0, use original bag sizes')
42 |
43 | # Hyper-parameters
44 | parser.add_argument('--batch_size', default=160, type=int,
45 | help='Batch size')
46 | parser.add_argument('--lr', default=0.1, type=float,
47 | help='Learning rate')
48 | parser.add_argument('--optim', default='sgd', type=str,
49 | help='Optimizer')
50 | parser.add_argument('--weight_decay', default=1e-5, type=float,
51 | help='Weight decay')
52 | parser.add_argument('--max_length', default=128, type=int,
53 | help='Maximum sentence length')
54 | parser.add_argument('--max_epoch', default=100, type=int,
55 | help='Max number of training epochs')
56 |
57 | # Others
58 | parser.add_argument('--seed', default=42, type=int,
59 | help='Random seed')
60 |
61 | # Exp
62 | parser.add_argument('--encoder', default='pcnn', choices=['pcnn', 'cnn'])
63 | parser.add_argument('--aggr', default='att', choices=['one', 'att', 'avg'])
64 |
65 | args = parser.parse_args()
66 |
67 | # Set random seed
68 | set_seed(args.seed)
69 |
70 | # Some basic settings
71 | root_path = '.'
72 | sys.path.append(root_path)
73 | if not os.path.exists('ckpt'):
74 | os.mkdir('ckpt')
75 | if len(args.ckpt) == 0:
76 | args.ckpt = '{}_{}'.format(args.dataset, 'pcnn_att')
77 | ckpt = 'ckpt/{}.pth.tar'.format(args.ckpt)
78 |
79 | if args.dataset != 'none':
80 | opennre.download(args.dataset, root_path=root_path)
81 | args.train_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_train.txt'.format(args.dataset))
82 | args.val_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_val.txt'.format(args.dataset))
83 | if not os.path.exists(args.val_file):
84 | logging.info("Cannot find the validation file. Use the test file instead.")
85 | args.val_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_test.txt'.format(args.dataset))
86 | args.test_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_test.txt'.format(args.dataset))
87 | args.rel2id_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_rel2id.json'.format(args.dataset))
88 | else:
89 | if not (os.path.exists(args.train_file) and os.path.exists(args.val_file) and os.path.exists(args.test_file) and os.path.exists(args.rel2id_file)):
90 | raise Exception('--train_file, --val_file, --test_file and --rel2id_file are not specified or files do not exist. Or specify --dataset')
91 |
92 | logging.info('Arguments:')
93 | for arg in vars(args):
94 | logging.info(' {}: {}'.format(arg, getattr(args, arg)))
95 |
96 | rel2id = json.load(open(args.rel2id_file))
97 |
98 | # Download glove
99 | opennre.download('glove', root_path=root_path)
100 | word2id = json.load(open(os.path.join(root_path, 'pretrain/glove/glove.6B.50d_word2id.json')))
101 | word2vec = np.load(os.path.join(root_path, 'pretrain/glove/glove.6B.50d_mat.npy'))
102 |
103 | # Define the sentence encoder
104 | if args.encoder == 'pcnn':
105 | sentence_encoder = opennre.encoder.PCNNEncoder(
106 | token2id=word2id,
107 | max_length=args.max_length,
108 | word_size=50,
109 | position_size=5,
110 | hidden_size=230,
111 | blank_padding=True,
112 | kernel_size=3,
113 | padding_size=1,
114 | word2vec=word2vec,
115 | dropout=0.5
116 | )
117 | elif args.encoder == 'cnn':
118 | sentence_encoder = opennre.encoder.CNNEncoder(
119 | token2id=word2id,
120 | max_length=args.max_length,
121 | word_size=50,
122 | position_size=5,
123 | hidden_size=230,
124 | blank_padding=True,
125 | kernel_size=3,
126 | padding_size=1,
127 | word2vec=word2vec,
128 | dropout=0.5
129 | )
130 | else:
131 | raise NotImplementedError
132 |
133 |
134 | # Define the model
135 | if args.aggr == 'att':
136 | model = opennre.model.BagAttention(sentence_encoder, len(rel2id), rel2id)
137 | elif args.aggr == 'avg':
138 | model = opennre.model.BagAverage(sentence_encoder, len(rel2id), rel2id)
139 | elif args.aggr == 'one':
140 | model = opennre.model.BagOne(sentence_encoder, len(rel2id), rel2id)
141 | else:
142 | raise NotImplementedError
143 |
144 | # Define the whole training framework
145 | framework = opennre.framework.BagRE(
146 | train_path=args.train_file,
147 | val_path=args.val_file,
148 | test_path=args.test_file,
149 | model=model,
150 | ckpt=ckpt,
151 | batch_size=args.batch_size,
152 | max_epoch=args.max_epoch,
153 | lr=args.lr,
154 | weight_decay=args.weight_decay,
155 | opt=args.optim,
156 | bag_size=args.bag_size)
157 |
158 | # Train the model
159 | if not args.only_test:
160 | framework.train_model(args.metric)
161 |
162 | # Test the model
163 | framework.load_state_dict(torch.load(ckpt)['state_dict'])
164 | result = framework.eval_model(framework.test_loader)
165 |
166 | # Print the result
167 | logging.info('Test set results:')
168 | logging.info('AUC: %.5f' % (result['auc']))
169 | logging.info('Maximum micro F1: %.5f' % (result['max_micro_f1']))
170 | logging.info('Maximum macro F1: %.5f' % (result['max_macro_f1']))
171 | logging.info('Micro F1: %.5f' % (result['micro_f1']))
172 | logging.info('Macro F1: %.5f' % (result['macro_f1']))
173 | logging.info('P@100: %.5f' % (result['p@100']))
174 | logging.info('P@200: %.5f' % (result['p@200']))
175 | logging.info('P@300: %.5f' % (result['p@300']))
176 |
177 | # Save precision/recall points
178 | np.save('result/{}_p.npy'.format(args.result), result['np_prec'])
179 | np.save('result/{}_r.npy'.format(args.result), result['np_rec'])
180 | json.dump(result['max_micro_f1_each_relation'], open('result/{}_mmicrof1_rel.json'.format(args.result), 'w'), ensure_ascii=False)
181 |
--------------------------------------------------------------------------------
/example/train_supervised_bert.py:
--------------------------------------------------------------------------------
1 | # coding:utf-8
2 | import torch
3 | import numpy as np
4 | import json
5 | import opennre
6 | from opennre import encoder, model, framework
7 | import sys
8 | import os
9 | import argparse
10 | import logging
11 | import random
12 |
13 | def set_seed(seed):
14 | random.seed(seed)
15 | np.random.seed(seed)
16 | torch.manual_seed(seed)
17 | torch.cuda.manual_seed_all(seed)
18 |
19 |
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument('--pretrain_path', default='bert-base-uncased',
22 |                     help='Pre-trained ckpt path / model name (huggingface)')
23 | parser.add_argument('--ckpt', default='',
24 | help='Checkpoint name')
25 | parser.add_argument('--pooler', default='entity', choices=['cls', 'entity'],
26 | help='Sentence representation pooler')
27 | parser.add_argument('--only_test', action='store_true',
28 | help='Only run test')
29 | parser.add_argument('--mask_entity', action='store_true',
30 | help='Mask entity mentions')
31 |
32 | # Data
33 | parser.add_argument('--metric', default='micro_f1', choices=['micro_f1', 'acc'],
34 | help='Metric for picking up best checkpoint')
35 | parser.add_argument('--dataset', default='none', choices=['none', 'semeval', 'wiki80', 'tacred'],
36 | help='Dataset. If not none, the following args can be ignored')
37 | parser.add_argument('--train_file', default='', type=str,
38 | help='Training data file')
39 | parser.add_argument('--val_file', default='', type=str,
40 | help='Validation data file')
41 | parser.add_argument('--test_file', default='', type=str,
42 | help='Test data file')
43 | parser.add_argument('--rel2id_file', default='', type=str,
44 | help='Relation to ID file')
45 |
46 | # Hyper-parameters
47 | parser.add_argument('--batch_size', default=64, type=int,
48 | help='Batch size')
49 | parser.add_argument('--lr', default=2e-5, type=float,
50 | help='Learning rate')
51 | parser.add_argument('--max_length', default=128, type=int,
52 | help='Maximum sentence length')
53 | parser.add_argument('--max_epoch', default=3, type=int,
54 | help='Max number of training epochs')
55 |
56 | # Seed
57 | parser.add_argument('--seed', default=42, type=int,
58 | help='Seed')
59 |
60 | args = parser.parse_args()
61 |
62 | # Set random seed
63 | set_seed(args.seed)
64 |
65 | # Some basic settings
66 | root_path = '.'
67 | sys.path.append(root_path)
68 | if not os.path.exists('ckpt'):
69 | os.mkdir('ckpt')
70 | if len(args.ckpt) == 0:
71 | args.ckpt = '{}_{}_{}'.format(args.dataset, args.pretrain_path, args.pooler)
72 | ckpt = 'ckpt/{}.pth.tar'.format(args.ckpt)
73 |
74 | if args.dataset != 'none':
75 | opennre.download(args.dataset, root_path=root_path)
76 | args.train_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_train.txt'.format(args.dataset))
77 | args.val_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_val.txt'.format(args.dataset))
78 | args.test_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_test.txt'.format(args.dataset))
79 | if not os.path.exists(args.test_file):
80 |         logging.warning("Test file {} does not exist; using the val file instead.".format(args.test_file))
81 | args.test_file = args.val_file
82 | args.rel2id_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_rel2id.json'.format(args.dataset))
83 | if args.dataset == 'wiki80':
84 | args.metric = 'acc'
85 | else:
86 | args.metric = 'micro_f1'
87 | else:
88 | if not (os.path.exists(args.train_file) and os.path.exists(args.val_file) and os.path.exists(args.test_file) and os.path.exists(args.rel2id_file)):
89 |         raise Exception('Please specify --dataset, or provide existing --train_file, --val_file, --test_file and --rel2id_file.')
90 |
91 | logging.info('Arguments:')
92 | for arg in vars(args):
93 | logging.info(' {}: {}'.format(arg, getattr(args, arg)))
94 |
95 | rel2id = json.load(open(args.rel2id_file))
96 |
97 | # Define the sentence encoder
98 | if args.pooler == 'entity':
99 | sentence_encoder = opennre.encoder.BERTEntityEncoder(
100 | max_length=args.max_length,
101 | pretrain_path=args.pretrain_path,
102 | mask_entity=args.mask_entity
103 | )
104 | elif args.pooler == 'cls':
105 | sentence_encoder = opennre.encoder.BERTEncoder(
106 | max_length=args.max_length,
107 | pretrain_path=args.pretrain_path,
108 | mask_entity=args.mask_entity
109 | )
110 | else:
111 | raise NotImplementedError
112 |
113 | # Define the model
114 | model = opennre.model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id)
115 |
116 | # Define the whole training framework
117 | framework = opennre.framework.SentenceRE(
118 | train_path=args.train_file,
119 | val_path=args.val_file,
120 | test_path=args.test_file,
121 | model=model,
122 | ckpt=ckpt,
123 | batch_size=args.batch_size,
124 | max_epoch=args.max_epoch,
125 | lr=args.lr,
126 | opt='adamw'
127 | )
128 |
129 | # Train the model
130 | if not args.only_test:
131 |     framework.train_model(args.metric)
132 |
133 | # Test
134 | framework.load_state_dict(torch.load(ckpt)['state_dict'])
135 | result = framework.eval_model(framework.test_loader)
136 |
137 | # Print the result
138 | logging.info('Test set results:')
139 | logging.info('Accuracy: {}'.format(result['acc']))
140 | logging.info('Micro precision: {}'.format(result['micro_p']))
141 | logging.info('Micro recall: {}'.format(result['micro_r']))
142 | logging.info('Micro F1: {}'.format(result['micro_f1']))
143 |
--------------------------------------------------------------------------------
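
After a run of the script above, the saved checkpoint can also be reloaded outside the SentenceRE framework for quick single-sentence predictions. A minimal sketch, assuming a wiki80 run (so the checkpoint follows the '{dataset}_{pretrain_path}_{pooler}' naming used above) and assuming opennre.model.SoftmaxNN exposes the infer(item) helper described in the project README; both are assumptions, not guarantees of this file:

    import json
    import torch
    import opennre

    rel2id = json.load(open('benchmark/wiki80/wiki80_rel2id.json'))
    encoder = opennre.encoder.BERTEntityEncoder(max_length=128, pretrain_path='bert-base-uncased')
    model = opennre.model.SoftmaxNN(encoder, len(rel2id), rel2id)
    ckpt = 'ckpt/wiki80_bert-base-uncased_entity.pth.tar'   # produced by the script above
    model.load_state_dict(torch.load(ckpt, map_location='cpu')['state_dict'])
    # infer() takes one instance in the same JSON format as the benchmark files
    print(model.infer({'token': ['Bill', 'Gates', 'founded', 'Microsoft', '.'],
                       'h': {'pos': [0, 2]}, 't': {'pos': [3, 4]}}))
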
/example/train_supervised_cnn.py:
--------------------------------------------------------------------------------
1 | # coding:utf-8
2 | import torch
3 | import numpy as np
4 | import json
5 | import opennre
6 | from opennre import encoder, model, framework
7 | import sys
8 | import os
9 | import argparse
10 | import logging
11 | import random
12 |
13 | def set_seed(seed):
14 | random.seed(seed)
15 | np.random.seed(seed)
16 | torch.manual_seed(seed)
17 | torch.cuda.manual_seed_all(seed)
18 |
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('--ckpt', default='',
21 | help='Checkpoint name')
22 | parser.add_argument('--only_test', action='store_true',
23 | help='Only run test')
24 | parser.add_argument('--encoder', default='pcnn', choices=['pcnn', 'cnn'])
25 |
26 | # Data
27 | parser.add_argument('--metric', default='micro_f1', choices=['micro_f1', 'acc'],
28 | help='Metric for picking up best checkpoint')
29 | parser.add_argument('--dataset', default='none', choices=['none', 'semeval', 'wiki80', 'tacred'],
30 | help='Dataset. If not none, the following args can be ignored')
31 | parser.add_argument('--train_file', default='', type=str,
32 | help='Training data file')
33 | parser.add_argument('--val_file', default='', type=str,
34 | help='Validation data file')
35 | parser.add_argument('--test_file', default='', type=str,
36 | help='Test data file')
37 | parser.add_argument('--rel2id_file', default='', type=str,
38 | help='Relation to ID file')
39 |
40 | # Hyper-parameters
41 | parser.add_argument('--batch_size', default=160, type=int,
42 | help='Batch size')
43 | parser.add_argument('--lr', default=1e-1, type=float,
44 | help='Learning rate')
45 | parser.add_argument('--weight_decay', default=1e-5, type=float,
46 | help='Weight decay')
47 | parser.add_argument('--max_length', default=128, type=int,
48 | help='Maximum sentence length')
49 | parser.add_argument('--max_epoch', default=100, type=int,
50 | help='Max number of training epochs')
51 |
52 | # Others
53 | parser.add_argument('--seed', default=42, type=int,
54 | help='Random seed')
55 |
56 | args = parser.parse_args()
57 |
58 | # Set random seed
59 | set_seed(args.seed)
60 |
61 | # Some basic settings
62 | root_path = '.'
63 | sys.path.append(root_path)
64 | if not os.path.exists('ckpt'):
65 | os.mkdir('ckpt')
66 | if len(args.ckpt) == 0:
67 | args.ckpt = '{}_{}'.format(args.dataset, 'cnn')
68 | ckpt = 'ckpt/{}.pth.tar'.format(args.ckpt)
69 |
70 | if args.dataset != 'none':
71 | opennre.download(args.dataset, root_path=root_path)
72 | args.train_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_train.txt'.format(args.dataset))
73 | args.val_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_val.txt'.format(args.dataset))
74 | args.test_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_test.txt'.format(args.dataset))
75 | args.rel2id_file = os.path.join(root_path, 'benchmark', args.dataset, '{}_rel2id.json'.format(args.dataset))
76 | if args.dataset == 'wiki80':
77 | args.metric = 'acc'
78 | else:
79 | args.metric = 'micro_f1'
80 | else:
81 | if not (os.path.exists(args.train_file) and os.path.exists(args.val_file) and os.path.exists(args.test_file) and os.path.exists(args.rel2id_file)):
82 |         raise Exception('Please specify --dataset, or provide existing --train_file, --val_file, --test_file and --rel2id_file.')
83 |
84 | logging.info('Arguments:')
85 | for arg in vars(args):
86 | logging.info(' {}: {}'.format(arg, getattr(args, arg)))
87 |
88 | rel2id = json.load(open(args.rel2id_file))
89 |
90 | # Download glove
91 | opennre.download('glove', root_path=root_path)
92 | word2id = json.load(open(os.path.join(root_path, 'pretrain/glove/glove.6B.50d_word2id.json')))
93 | word2vec = np.load(os.path.join(root_path, 'pretrain/glove/glove.6B.50d_mat.npy'))
94 |
95 | # Define the sentence encoder
96 | if args.encoder == 'pcnn':
97 | sentence_encoder = opennre.encoder.PCNNEncoder(
98 | token2id=word2id,
99 | max_length=args.max_length,
100 | word_size=50,
101 | position_size=5,
102 | hidden_size=230,
103 | blank_padding=True,
104 | kernel_size=3,
105 | padding_size=1,
106 | word2vec=word2vec,
107 | dropout=0.5
108 | )
109 | elif args.encoder == 'cnn':
110 | sentence_encoder = opennre.encoder.CNNEncoder(
111 | token2id=word2id,
112 | max_length=args.max_length,
113 | word_size=50,
114 | position_size=5,
115 | hidden_size=230,
116 | blank_padding=True,
117 | kernel_size=3,
118 | padding_size=1,
119 | word2vec=word2vec,
120 | dropout=0.5
121 | )
122 | else:
123 | raise NotImplementedError
124 |
125 | # Define the model
126 | model = opennre.model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id)
127 |
128 | # Define the whole training framework
129 | framework = opennre.framework.SentenceRE(
130 | train_path=args.train_file,
131 | val_path=args.val_file,
132 | test_path=args.test_file,
133 | model=model,
134 | ckpt=ckpt,
135 | batch_size=args.batch_size,
136 | max_epoch=args.max_epoch,
137 | lr=args.lr,
138 | weight_decay=args.weight_decay,
139 | opt='sgd'
140 | )
141 |
142 | # Train the model
143 | if not args.only_test:
144 | framework.train_model(args.metric)
145 |
146 | # Test
147 | framework.load_state_dict(torch.load(ckpt)['state_dict'])
148 | result = framework.eval_model(framework.test_loader)
149 |
150 | # Print the result
151 | logging.info('Test set results:')
152 | if args.metric == 'acc':
153 | logging.info('Accuracy: {}'.format(result['acc']))
154 | else:
155 | logging.info('Micro precision: {}'.format(result['micro_p']))
156 | logging.info('Micro recall: {}'.format(result['micro_r']))
157 | logging.info('Micro F1: {}'.format(result['micro_f1']))
158 |
--------------------------------------------------------------------------------
/opennre/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from .pretrain import check_root, get_model, download, download_pretrain
6 | import logging
7 | import os
8 |
9 | logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=os.environ.get("LOGLEVEL", "INFO"))
10 |
11 | def fix_seed(seed=12345):
12 | import torch
13 | import numpy as np
14 | import random
15 | torch.manual_seed(seed) # cpu
16 | torch.cuda.manual_seed(seed) # gpu
17 | np.random.seed(seed) # numpy
18 | random.seed(seed) # random and transforms
19 | torch.backends.cudnn.deterministic=True # cudnn
20 |
--------------------------------------------------------------------------------
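
Two of the helpers exposed here are used throughout the example scripts and can also be called directly. A minimal sketch (the downloads assume network access; files land under ./benchmark and ./pretrain):

    import opennre

    opennre.fix_seed(12345)                      # seed torch, numpy and random in one call
    opennre.download('wiki80', root_path='.')    # benchmark data   -> ./benchmark/wiki80
    opennre.download('glove', root_path='.')     # GloVe embeddings -> ./pretrain/glove
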
/opennre/encoder/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from .cnn_encoder import CNNEncoder
6 | from .pcnn_encoder import PCNNEncoder
7 | from .bert_encoder import BERTEncoder, BERTEntityEncoder
8 |
9 | __all__ = [
10 | 'CNNEncoder',
11 | 'PCNNEncoder',
12 | 'BERTEncoder',
13 | 'BERTEntityEncoder'
14 | ]
--------------------------------------------------------------------------------
/opennre/encoder/base_encoder.py:
--------------------------------------------------------------------------------
1 | import math, logging
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 | from ..tokenization import WordTokenizer
7 |
8 | class BaseEncoder(nn.Module):
9 |
10 | def __init__(self,
11 | token2id,
12 | max_length=128,
13 | hidden_size=230,
14 | word_size=50,
15 | position_size=5,
16 | blank_padding=True,
17 | word2vec=None,
18 | mask_entity=False):
19 | """
20 | Args:
21 | token2id: dictionary of token->idx mapping
22 |             max_length: max length of sentence, used for position embedding
23 | hidden_size: hidden size
24 | word_size: size of word embedding
25 | position_size: size of position embedding
26 | blank_padding: padding for CNN
27 | word2vec: pretrained word2vec numpy
28 | """
29 | # Hyperparameters
30 | super().__init__()
31 |
32 | self.token2id = token2id
33 | self.max_length = max_length
34 | self.num_token = len(token2id)
35 | self.num_position = max_length * 2
36 | self.mask_entity = mask_entity
37 |
38 | if word2vec is None:
39 | self.word_size = word_size
40 | else:
41 | self.word_size = word2vec.shape[-1]
42 |
43 | self.position_size = position_size
44 | self.hidden_size = hidden_size
45 |         self.input_size = self.word_size + self.position_size * 2
46 | self.blank_padding = blank_padding
47 |
48 |         if '[UNK]' not in self.token2id:
49 |             self.token2id['[UNK]'] = len(self.token2id)
50 |             self.num_token += 1
51 |         if '[PAD]' not in self.token2id:
52 | self.token2id['[PAD]'] = len(self.token2id)
53 | self.num_token += 1
54 |
55 | # Word embedding
56 | self.word_embedding = nn.Embedding(self.num_token, self.word_size)
57 | if word2vec is not None:
58 | logging.info("Initializing word embedding with word2vec.")
59 | word2vec = torch.from_numpy(word2vec)
60 | if self.num_token == len(word2vec) + 2:
61 | unk = torch.randn(1, self.word_size) / math.sqrt(self.word_size)
62 | blk = torch.zeros(1, self.word_size)
63 | self.word_embedding.weight.data.copy_(torch.cat([word2vec, unk, blk], 0))
64 | else:
65 | self.word_embedding.weight.data.copy_(word2vec)
66 |
67 | # Position Embedding
68 | self.pos1_embedding = nn.Embedding(2 * max_length, self.position_size, padding_idx=0)
69 | self.pos2_embedding = nn.Embedding(2 * max_length, self.position_size, padding_idx=0)
70 | self.tokenizer = WordTokenizer(vocab=self.token2id, unk_token="[UNK]")
71 |
72 | def forward(self, token, pos1, pos2):
73 | """
74 | Args:
75 | token: (B, L), index of tokens
76 | pos1: (B, L), relative position to head entity
77 | pos2: (B, L), relative position to tail entity
78 | Return:
79 | (B, H), representations for sentences
80 | """
81 | # Check size of tensors
82 | pass
83 |
84 | def tokenize(self, item):
85 | """
86 | Args:
87 | item: input instance, including sentence, entity positions, etc.
88 | Return:
89 | index number of tokens and positions
90 | """
91 | if 'text' in item:
92 | sentence = item['text']
93 | is_token = False
94 | else:
95 | sentence = item['token']
96 | is_token = True
97 | pos_head = item['h']['pos']
98 | pos_tail = item['t']['pos']
99 |
100 | # Sentence -> token
101 | if not is_token:
102 | if pos_head[0] > pos_tail[0]:
103 | pos_min, pos_max = [pos_tail, pos_head]
104 | rev = True
105 | else:
106 | pos_min, pos_max = [pos_head, pos_tail]
107 | rev = False
108 | sent_0 = self.tokenizer.tokenize(sentence[:pos_min[0]])
109 | sent_1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]])
110 | sent_2 = self.tokenizer.tokenize(sentence[pos_max[1]:])
111 | ent_0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]])
112 | ent_1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]])
113 | if self.mask_entity:
114 | ent_0 = ['[UNK]']
115 | ent_1 = ['[UNK]']
116 | tokens = sent_0 + ent_0 + sent_1 + ent_1 + sent_2
117 | if rev:
118 | pos_tail = [len(sent_0), len(sent_0) + len(ent_0)]
119 | pos_head = [len(sent_0) + len(ent_0) + len(sent_1), len(sent_0) + len(ent_0) + len(sent_1) + len(ent_1)]
120 | else:
121 | pos_head = [len(sent_0), len(sent_0) + len(ent_0)]
122 | pos_tail = [len(sent_0) + len(ent_0) + len(sent_1), len(sent_0) + len(ent_0) + len(sent_1) + len(ent_1)]
123 | else:
124 | tokens = sentence
125 |
126 | # Token -> index
127 | if self.blank_padding:
128 | indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens, self.max_length, self.token2id['[PAD]'], self.token2id['[UNK]'])
129 | else:
130 | indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens, unk_id = self.token2id['[UNK]'])
131 |
132 | # Position -> index
133 | pos1 = []
134 | pos2 = []
135 | pos1_in_index = min(pos_head[0], self.max_length)
136 | pos2_in_index = min(pos_tail[0], self.max_length)
137 | for i in range(len(tokens)):
138 | pos1.append(min(i - pos1_in_index + self.max_length, 2 * self.max_length - 1))
139 | pos2.append(min(i - pos2_in_index + self.max_length, 2 * self.max_length - 1))
140 |
141 | if self.blank_padding:
142 | while len(pos1) < self.max_length:
143 | pos1.append(0)
144 | while len(pos2) < self.max_length:
145 | pos2.append(0)
146 | indexed_tokens = indexed_tokens[:self.max_length]
147 | pos1 = pos1[:self.max_length]
148 | pos2 = pos2[:self.max_length]
149 |
150 | indexed_tokens = torch.tensor(indexed_tokens).long().unsqueeze(0) # (1, L)
151 | pos1 = torch.tensor(pos1).long().unsqueeze(0) # (1, L)
152 | pos2 = torch.tensor(pos2).long().unsqueeze(0) # (1, L)
153 |
154 | return indexed_tokens, pos1, pos2
155 |
--------------------------------------------------------------------------------
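
BaseEncoder.forward is left to the subclasses, but its tokenize method already produces the padded index and position tensors that the CNN-style encoders consume. A minimal sketch with a made-up toy vocabulary (the real scripts pass the GloVe word2id/word2vec matrices instead):

    from opennre.encoder.base_encoder import BaseEncoder

    token2id = {'bill': 0, 'gates': 1, 'founded': 2, 'microsoft': 3, '.': 4}
    enc = BaseEncoder(token2id, max_length=16, word_size=50, position_size=5)

    item = {'token': ['bill', 'gates', 'founded', 'microsoft', '.'],
            'h': {'pos': [0, 2]},     # "bill gates"
            't': {'pos': [3, 4]}}     # "microsoft"
    tokens, pos1, pos2 = enc.tokenize(item)
    print(tokens.shape, pos1.shape, pos2.shape)   # each (1, 16): padded to max_length
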
/opennre/encoder/bert_encoder.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | import torch.nn as nn
4 | from transformers import BertModel, BertTokenizer
5 | from .base_encoder import BaseEncoder
6 |
7 | class BERTEncoder(nn.Module):
8 | def __init__(self, max_length, pretrain_path, blank_padding=True, mask_entity=False):
9 | """
10 | Args:
11 | max_length: max length of sentence
12 |             pretrain_path: path or name of the pre-trained BERT model
13 | """
14 | super().__init__()
15 | self.max_length = max_length
16 | self.blank_padding = blank_padding
17 | self.hidden_size = 768
18 | self.mask_entity = mask_entity
19 | logging.info('Loading BERT pre-trained checkpoint.')
20 | self.bert = BertModel.from_pretrained(pretrain_path)
21 | self.tokenizer = BertTokenizer.from_pretrained(pretrain_path)
22 |
23 | def forward(self, token, att_mask, pos1, pos2):
24 | """
25 | Args:
26 | token: (B, L), index of tokens
27 | att_mask: (B, L), attention mask (1 for contents and 0 for padding)
28 | Return:
29 | (B, H), representations for sentences
30 | """
31 | _, x = self.bert(token, attention_mask=att_mask, return_dict=False)
32 |         # return_dict=False is set to adapt to newer versions of transformers
33 | return x
34 |
35 | def tokenize(self, item):
36 | """
37 | Args:
38 | item: data instance containing 'text' / 'token', 'h' and 't'
39 | Return:
40 |             indexed_tokens, att_mask, pos1 and pos2 for the sentence
41 | """
42 | # Sentence -> token
43 | if 'text' in item:
44 | sentence = item['text']
45 | is_token = False
46 | else:
47 | sentence = item['token']
48 | is_token = True
49 | pos_head = item['h']['pos']
50 | pos_tail = item['t']['pos']
51 |
52 | pos_min = pos_head
53 | pos_max = pos_tail
54 | if pos_head[0] > pos_tail[0]:
55 | pos_min = pos_tail
56 | pos_max = pos_head
57 | rev = True
58 | else:
59 | rev = False
60 |
61 | if not is_token:
62 | sent0 = self.tokenizer.tokenize(sentence[:pos_min[0]])
63 | ent0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]])
64 | sent1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]])
65 | ent1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]])
66 | sent2 = self.tokenizer.tokenize(sentence[pos_max[1]:])
67 | else:
68 | sent0 = self.tokenizer.tokenize(' '.join(sentence[:pos_min[0]]))
69 | ent0 = self.tokenizer.tokenize(' '.join(sentence[pos_min[0]:pos_min[1]]))
70 | sent1 = self.tokenizer.tokenize(' '.join(sentence[pos_min[1]:pos_max[0]]))
71 | ent1 = self.tokenizer.tokenize(' '.join(sentence[pos_max[0]:pos_max[1]]))
72 | sent2 = self.tokenizer.tokenize(' '.join(sentence[pos_max[1]:]))
73 |
74 | if self.mask_entity:
75 | ent0 = ['[unused4]'] if not rev else ['[unused5]']
76 | ent1 = ['[unused5]'] if not rev else ['[unused4]']
77 | else:
78 | ent0 = ['[unused0]'] + ent0 + ['[unused1]'] if not rev else ['[unused2]'] + ent0 + ['[unused3]']
79 | ent1 = ['[unused2]'] + ent1 + ['[unused3]'] if not rev else ['[unused0]'] + ent1 + ['[unused1]']
80 |
81 | re_tokens = ['[CLS]'] + sent0 + ent0 + sent1 + ent1 + sent2 + ['[SEP]']
82 |
83 | pos1 = 1 + len(sent0) if not rev else 1 + len(sent0 + ent0 + sent1)
84 | pos2 = 1 + len(sent0 + ent0 + sent1) if not rev else 1 + len(sent0)
85 | pos1 = min(self.max_length - 1, pos1)
86 | pos2 = min(self.max_length - 1, pos2)
87 | indexed_tokens = self.tokenizer.convert_tokens_to_ids(re_tokens)
88 | avai_len = len(indexed_tokens)
89 | pos1 = torch.tensor([[pos1]]).long()
90 | pos2 = torch.tensor([[pos2]]).long()
91 |
92 | # Padding
93 | if self.blank_padding:
94 | while len(indexed_tokens) < self.max_length:
95 | indexed_tokens.append(0) # 0 is id for [PAD]
96 | indexed_tokens = indexed_tokens[:self.max_length]
97 | indexed_tokens = torch.tensor(indexed_tokens).long().unsqueeze(0) # (1, L)
98 |
99 | # Attention mask
100 | att_mask = torch.zeros(indexed_tokens.size()).long() # (1, L)
101 | att_mask[0, :avai_len] = 1
102 |
103 | return indexed_tokens, att_mask, pos1, pos2
104 |
105 |
106 | class BERTEntityEncoder(nn.Module):
107 | def __init__(self, max_length, pretrain_path, blank_padding=True, mask_entity=False):
108 | """
109 | Args:
110 | max_length: max length of sentence
111 |             pretrain_path: path or name of the pre-trained BERT model
112 | """
113 | super().__init__()
114 | self.max_length = max_length
115 | self.blank_padding = blank_padding
116 | self.hidden_size = 768 * 2
117 | self.mask_entity = mask_entity
118 | logging.info('Loading BERT pre-trained checkpoint.')
119 | self.bert = BertModel.from_pretrained(pretrain_path)
120 | self.tokenizer = BertTokenizer.from_pretrained(pretrain_path)
121 | self.linear = nn.Linear(self.hidden_size, self.hidden_size)
122 |
123 | def forward(self, token, att_mask, pos1, pos2):
124 | """
125 | Args:
126 | token: (B, L), index of tokens
127 | att_mask: (B, L), attention mask (1 for contents and 0 for padding)
128 | pos1: (B, 1), position of the head entity starter
129 | pos2: (B, 1), position of the tail entity starter
130 | Return:
131 | (B, 2H), representations for sentences
132 | """
133 | hidden, _ = self.bert(token, attention_mask=att_mask, return_dict=False)
134 | # Get entity start hidden state
135 | onehot_head = torch.zeros(hidden.size()[:2]).float().to(hidden.device) # (B, L)
136 | onehot_tail = torch.zeros(hidden.size()[:2]).float().to(hidden.device) # (B, L)
137 | onehot_head = onehot_head.scatter_(1, pos1, 1)
138 | onehot_tail = onehot_tail.scatter_(1, pos2, 1)
139 | head_hidden = (onehot_head.unsqueeze(2) * hidden).sum(1) # (B, H)
140 | tail_hidden = (onehot_tail.unsqueeze(2) * hidden).sum(1) # (B, H)
141 | x = torch.cat([head_hidden, tail_hidden], 1) # (B, 2H)
142 | x = self.linear(x)
143 | return x
144 |
145 | def tokenize(self, item):
146 | """
147 | Args:
148 | item: data instance containing 'text' / 'token', 'h' and 't'
149 | Return:
150 |             indexed_tokens, att_mask, pos1 and pos2 for the sentence
151 | """
152 | # Sentence -> token
153 | if 'text' in item:
154 | sentence = item['text']
155 | is_token = False
156 | else:
157 | sentence = item['token']
158 | is_token = True
159 | pos_head = item['h']['pos']
160 | pos_tail = item['t']['pos']
161 |
162 | pos_min = pos_head
163 | pos_max = pos_tail
164 | if pos_head[0] > pos_tail[0]:
165 | pos_min = pos_tail
166 | pos_max = pos_head
167 | rev = True
168 | else:
169 | rev = False
170 |
171 | if not is_token:
172 | sent0 = self.tokenizer.tokenize(sentence[:pos_min[0]])
173 | ent0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]])
174 | sent1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]])
175 | ent1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]])
176 | sent2 = self.tokenizer.tokenize(sentence[pos_max[1]:])
177 | else:
178 | sent0 = self.tokenizer.tokenize(' '.join(sentence[:pos_min[0]]))
179 | ent0 = self.tokenizer.tokenize(' '.join(sentence[pos_min[0]:pos_min[1]]))
180 | sent1 = self.tokenizer.tokenize(' '.join(sentence[pos_min[1]:pos_max[0]]))
181 | ent1 = self.tokenizer.tokenize(' '.join(sentence[pos_max[0]:pos_max[1]]))
182 | sent2 = self.tokenizer.tokenize(' '.join(sentence[pos_max[1]:]))
183 |
184 | if self.mask_entity:
185 | ent0 = ['[unused4]'] if not rev else ['[unused5]']
186 | ent1 = ['[unused5]'] if not rev else ['[unused4]']
187 | else:
188 | ent0 = ['[unused0]'] + ent0 + ['[unused1]'] if not rev else ['[unused2]'] + ent0 + ['[unused3]']
189 | ent1 = ['[unused2]'] + ent1 + ['[unused3]'] if not rev else ['[unused0]'] + ent1 + ['[unused1]']
190 |
191 | re_tokens = ['[CLS]'] + sent0 + ent0 + sent1 + ent1 + sent2 + ['[SEP]']
192 | pos1 = 1 + len(sent0) if not rev else 1 + len(sent0 + ent0 + sent1)
193 | pos2 = 1 + len(sent0 + ent0 + sent1) if not rev else 1 + len(sent0)
194 | pos1 = min(self.max_length - 1, pos1)
195 | pos2 = min(self.max_length - 1, pos2)
196 |
197 | indexed_tokens = self.tokenizer.convert_tokens_to_ids(re_tokens)
198 | avai_len = len(indexed_tokens)
199 |
200 | # Position
201 | pos1 = torch.tensor([[pos1]]).long()
202 | pos2 = torch.tensor([[pos2]]).long()
203 |
204 | # Padding
205 | if self.blank_padding:
206 | while len(indexed_tokens) < self.max_length:
207 | indexed_tokens.append(0) # 0 is id for [PAD]
208 | indexed_tokens = indexed_tokens[:self.max_length]
209 | indexed_tokens = torch.tensor(indexed_tokens).long().unsqueeze(0) # (1, L)
210 |
211 | # Attention mask
212 | att_mask = torch.zeros(indexed_tokens.size()).long() # (1, L)
213 | att_mask[0, :avai_len] = 1
214 |
215 | return indexed_tokens, att_mask, pos1, pos2
216 |
--------------------------------------------------------------------------------
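
A minimal sketch of BERTEntityEncoder on a single instance, assuming 'bert-base-uncased' is in the local transformers cache or can be downloaded:

    import torch
    from opennre.encoder import BERTEntityEncoder

    enc = BERTEntityEncoder(max_length=64, pretrain_path='bert-base-uncased')
    item = {'token': ['Bill', 'Gates', 'founded', 'Microsoft', '.'],
            'h': {'pos': [0, 2]}, 't': {'pos': [3, 4]}}
    token, att_mask, pos1, pos2 = enc.tokenize(item)   # (1, 64), (1, 64), (1, 1), (1, 1)
    with torch.no_grad():
        rep = enc(token, att_mask, pos1, pos2)
    print(rep.shape)   # torch.Size([1, 1536]): concatenated head/tail entity-start states
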
/opennre/encoder/cnn_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from ..module.nn import CNN
6 | from ..module.pool import MaxPool
7 | from .base_encoder import BaseEncoder
8 |
9 | class CNNEncoder(BaseEncoder):
10 |
11 | def __init__(self,
12 | token2id,
13 | max_length=128,
14 | hidden_size=230,
15 | word_size=50,
16 | position_size=5,
17 | blank_padding=True,
18 | word2vec=None,
19 | kernel_size=3,
20 | padding_size=1,
21 | dropout=0,
22 | activation_function=F.relu,
23 | mask_entity=False):
24 | """
25 | Args:
26 | token2id: dictionary of token->idx mapping
27 |             max_length: max length of sentence, used for position embedding
28 | hidden_size: hidden size
29 | word_size: size of word embedding
30 | position_size: size of position embedding
31 | blank_padding: padding for CNN
32 | word2vec: pretrained word2vec numpy
33 |             kernel_size: kernel size for CNN
34 |             padding_size: padding size for CNN
35 | """
36 | # Hyperparameters
37 | super(CNNEncoder, self).__init__(token2id, max_length, hidden_size, word_size, position_size, blank_padding, word2vec, mask_entity=mask_entity)
38 | self.drop = nn.Dropout(dropout)
39 | self.kernel_size = kernel_size
40 | self.padding_size = padding_size
41 | self.act = activation_function
42 |
43 | self.conv = nn.Conv1d(self.input_size, self.hidden_size, self.kernel_size, padding=self.padding_size)
44 | self.pool = nn.MaxPool1d(self.max_length)
45 |
46 | def forward(self, token, pos1, pos2):
47 | """
48 | Args:
49 | token: (B, L), index of tokens
50 | pos1: (B, L), relative position to head entity
51 | pos2: (B, L), relative position to tail entity
52 | Return:
53 |             (B, H), representations for sentences
54 | """
55 | # Check size of tensors
56 | if len(token.size()) != 2 or token.size() != pos1.size() or token.size() != pos2.size():
57 |             raise Exception("Size of token, pos1 and pos2 should be (B, L)")
58 | x = torch.cat([self.word_embedding(token),
59 | self.pos1_embedding(pos1),
60 | self.pos2_embedding(pos2)], 2) # (B, L, EMBED)
61 | x = x.transpose(1, 2) # (B, EMBED, L)
62 | x = self.act(self.conv(x)) # (B, H, L)
63 | x = self.pool(x).squeeze(-1)
64 | x = self.drop(x)
65 | return x
66 |
67 | def tokenize(self, item):
68 | return super().tokenize(item)
69 |
--------------------------------------------------------------------------------
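
A shape-check sketch for the encoder above, using random indices instead of a real dataset (the tiny vocabulary is hypothetical; the training scripts supply the GloVe matrices):

    import torch
    from opennre.encoder import CNNEncoder

    token2id = {'[PAD]': 0, '[UNK]': 1, 'alice': 2, 'met': 3, 'bob': 4}
    enc = CNNEncoder(token2id, max_length=32, hidden_size=230, word_size=50, position_size=5)

    B, L = 4, 32
    token = torch.randint(0, len(token2id), (B, L))
    pos1 = torch.randint(0, 2 * 32, (B, L))   # relative positions are shifted into [0, 2*max_length)
    pos2 = torch.randint(0, 2 * 32, (B, L))
    print(enc(token, pos1, pos2).shape)       # torch.Size([4, 230]): one vector per sentence
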
/opennre/encoder/pcnn_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from ..module.nn import CNN
6 | from ..module.pool import MaxPool
7 | from .base_encoder import BaseEncoder
8 |
9 | from nltk import word_tokenize
10 |
11 | class PCNNEncoder(BaseEncoder):
12 |
13 | def __init__(self,
14 | token2id,
15 | max_length=128,
16 | hidden_size=230,
17 | word_size=50,
18 | position_size=5,
19 | blank_padding=True,
20 | word2vec=None,
21 | kernel_size=3,
22 | padding_size=1,
23 | dropout=0.0,
24 | activation_function=F.relu,
25 | mask_entity=False):
26 | """
27 | Args:
28 | token2id: dictionary of token->idx mapping
29 |             max_length: max length of sentence, used for position embedding
30 | hidden_size: hidden size
31 | word_size: size of word embedding
32 | position_size: size of position embedding
33 | blank_padding: padding for CNN
34 | word2vec: pretrained word2vec numpy
35 |             kernel_size: kernel size for CNN
36 |             padding_size: padding size for CNN
37 | """
38 | # hyperparameters
39 | super().__init__(token2id, max_length, hidden_size, word_size, position_size, blank_padding, word2vec, mask_entity=mask_entity)
40 | self.drop = nn.Dropout(dropout)
41 | self.kernel_size = kernel_size
42 | self.padding_size = padding_size
43 | self.act = activation_function
44 |
45 | self.conv = nn.Conv1d(self.input_size, self.hidden_size, self.kernel_size, padding=self.padding_size)
46 | self.pool = nn.MaxPool1d(self.max_length)
47 | self.mask_embedding = nn.Embedding(4, 3)
48 | self.mask_embedding.weight.data.copy_(torch.FloatTensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]))
49 | self.mask_embedding.weight.requires_grad = False
50 | self._minus = -100
51 |
52 | self.hidden_size *= 3
53 |
54 | def forward(self, token, pos1, pos2, mask):
55 | """
56 | Args:
57 | token: (B, L), index of tokens
58 | pos1: (B, L), relative position to head entity
59 |             pos2: (B, L), relative position to tail entity
60 |             mask: (B, L), piecewise segment ids (1/2/3; 0 for padding)
61 |         Return: (B, 3H), representations for sentences
62 | """
63 | # Check size of tensors
64 | if len(token.size()) != 2 or token.size() != pos1.size() or token.size() != pos2.size():
65 |             raise Exception("Size of token, pos1 and pos2 should be (B, L)")
66 | x = torch.cat([self.word_embedding(token),
67 | self.pos1_embedding(pos1),
68 | self.pos2_embedding(pos2)], 2) # (B, L, EMBED)
69 | x = x.transpose(1, 2) # (B, EMBED, L)
70 | x = self.conv(x) # (B, H, L)
71 |
72 | mask = 1 - self.mask_embedding(mask).transpose(1, 2) # (B, L) -> (B, L, 3) -> (B, 3, L)
73 | pool1 = self.pool(self.act(x + self._minus * mask[:, 0:1, :])) # (B, H, 1)
74 | pool2 = self.pool(self.act(x + self._minus * mask[:, 1:2, :]))
75 | pool3 = self.pool(self.act(x + self._minus * mask[:, 2:3, :]))
76 | x = torch.cat([pool1, pool2, pool3], 1) # (B, 3H, 1)
77 | x = x.squeeze(2) # (B, 3H)
78 | x = self.drop(x)
79 |
80 | return x
81 |
82 | def tokenize(self, item):
83 | """
84 | Args:
85 |             item: input instance containing 'text' / 'token', 'h' and 't';
86 |                 item['h']['pos'] and item['t']['pos'] give the [start, end]
87 |                 spans of the head and tail entities, and 'token' means the
88 |                 sentence is already an array of tokens
89 |         Return:
90 |             indexed tokens, positions and the piecewise mask for the sentence
91 | """
92 | if 'text' in item:
93 | sentence = item['text']
94 | is_token = False
95 | else:
96 | sentence = item['token']
97 | is_token = True
98 | pos_head = item['h']['pos']
99 | pos_tail = item['t']['pos']
100 |
101 | # Sentence -> token
102 | if not is_token:
103 | if pos_head[0] > pos_tail[0]:
104 | pos_min, pos_max = [pos_tail, pos_head]
105 | rev = True
106 | else:
107 | pos_min, pos_max = [pos_head, pos_tail]
108 | rev = False
109 | sent_0 = self.tokenizer.tokenize(sentence[:pos_min[0]])
110 | sent_1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]])
111 | sent_2 = self.tokenizer.tokenize(sentence[pos_max[1]:])
112 | ent_0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]])
113 | ent_1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]])
114 | if self.mask_entity:
115 | ent_0 = ['[UNK]']
116 | ent_1 = ['[UNK]']
117 | tokens = sent_0 + ent_0 + sent_1 + ent_1 + sent_2
118 | if rev:
119 | pos_tail = [len(sent_0), len(sent_0) + len(ent_0)]
120 | pos_head = [len(sent_0) + len(ent_0) + len(sent_1), len(sent_0) + len(ent_0) + len(sent_1) + len(ent_1)]
121 | else:
122 | pos_head = [len(sent_0), len(sent_0) + len(ent_0)]
123 | pos_tail = [len(sent_0) + len(ent_0) + len(sent_1), len(sent_0) + len(ent_0) + len(sent_1) + len(ent_1)]
124 | else:
125 | tokens = sentence
126 |
127 | # Token -> index
128 | if self.blank_padding:
129 | indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens, self.max_length, self.token2id['[PAD]'], self.token2id['[UNK]'])
130 | else:
131 | indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens, unk_id = self.token2id['[UNK]'])
132 |
133 | # Position -> index
134 | pos1 = []
135 | pos2 = []
136 | pos1_in_index = min(pos_head[0], self.max_length)
137 | pos2_in_index = min(pos_tail[0], self.max_length)
138 | for i in range(len(tokens)):
139 | pos1.append(min(i - pos1_in_index + self.max_length, 2 * self.max_length - 1))
140 | pos2.append(min(i - pos2_in_index + self.max_length, 2 * self.max_length - 1))
141 |
142 | if self.blank_padding:
143 | while len(pos1) < self.max_length:
144 | pos1.append(0)
145 | while len(pos2) < self.max_length:
146 | pos2.append(0)
147 | indexed_tokens = indexed_tokens[:self.max_length]
148 | pos1 = pos1[:self.max_length]
149 | pos2 = pos2[:self.max_length]
150 |
151 | indexed_tokens = torch.tensor(indexed_tokens).long().unsqueeze(0) # (1, L)
152 | pos1 = torch.tensor(pos1).long().unsqueeze(0) # (1, L)
153 | pos2 = torch.tensor(pos2).long().unsqueeze(0) # (1, L)
154 |
155 | # Mask
156 | mask = []
157 | pos_min = min(pos1_in_index, pos2_in_index)
158 | pos_max = max(pos1_in_index, pos2_in_index)
159 | for i in range(len(tokens)):
160 | if i <= pos_min:
161 | mask.append(1)
162 | elif i <= pos_max:
163 | mask.append(2)
164 | else:
165 | mask.append(3)
166 | # Padding
167 | if self.blank_padding:
168 | while len(mask) < self.max_length:
169 | mask.append(0)
170 | mask = mask[:self.max_length]
171 |
172 | mask = torch.tensor(mask).long().unsqueeze(0) # (1, L)
173 | return indexed_tokens, pos1, pos2, mask
174 |
--------------------------------------------------------------------------------
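
What distinguishes PCNNEncoder from the plain CNN encoder is the three-segment mask built at the end of tokenize: tokens up to the first entity start fall in segment 1, tokens up to the second entity start in segment 2, the rest in segment 3 (0 marks padding); forward then max-pools each segment separately, which is why hidden_size is tripled. A small standalone sketch of that segmentation rule:

    tokens = ['alice', 'visited', 'bob', 'in', 'paris']
    pos_head, pos_tail = [0, 1], [2, 3]                 # "alice" (head) and "bob" (tail)
    pos_min = min(pos_head[0], pos_tail[0])
    pos_max = max(pos_head[0], pos_tail[0])
    mask = [1 if i <= pos_min else (2 if i <= pos_max else 3) for i in range(len(tokens))]
    print(mask)   # [1, 2, 2, 3, 3]
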
/opennre/framework/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from .data_loader import SentenceREDataset, SentenceRELoader, BagREDataset, BagRELoader, MultiLabelSentenceREDataset, MultiLabelSentenceRELoader
6 | from .sentence_re import SentenceRE
7 | from .bag_re import BagRE
8 | from .multi_label_sentence_re import MultiLabelSentenceRE
9 |
10 | __all__ = [
11 | 'SentenceREDataset',
12 | 'SentenceRELoader',
13 | 'SentenceRE',
14 | 'BagRE',
15 | 'BagREDataset',
16 | 'BagRELoader',
17 | 'MultiLabelSentenceREDataset',
18 | 'MultiLabelSentenceRELoader',
19 | 'MultiLabelSentenceRE'
20 | ]
21 |
--------------------------------------------------------------------------------
/opennre/framework/bag_re.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, optim
3 | import json
4 | from .data_loader import SentenceRELoader, BagRELoader
5 | from .utils import AverageMeter
6 | from tqdm import tqdm
7 | import os
8 |
9 | class BagRE(nn.Module):
10 |
11 | def __init__(self,
12 | model,
13 | train_path,
14 | val_path,
15 | test_path,
16 | ckpt,
17 | batch_size=32,
18 | max_epoch=100,
19 | lr=0.1,
20 | weight_decay=1e-5,
21 | opt='sgd',
22 | bag_size=0,
23 | loss_weight=False):
24 |
25 | super().__init__()
26 | self.max_epoch = max_epoch
27 | self.bag_size = bag_size
28 | # Load data
29 | if train_path != None:
30 | self.train_loader = BagRELoader(
31 | train_path,
32 | model.rel2id,
33 | model.sentence_encoder.tokenize,
34 | batch_size,
35 | True,
36 | bag_size=bag_size,
37 | entpair_as_bag=False)
38 |
39 | if val_path != None:
40 | self.val_loader = BagRELoader(
41 | val_path,
42 | model.rel2id,
43 | model.sentence_encoder.tokenize,
44 | batch_size,
45 | False,
46 | bag_size=bag_size,
47 | entpair_as_bag=True)
48 |
49 | if test_path != None:
50 | self.test_loader = BagRELoader(
51 | test_path,
52 | model.rel2id,
53 | model.sentence_encoder.tokenize,
54 | batch_size,
55 | False,
56 | bag_size=bag_size,
57 | entpair_as_bag=True
58 | )
59 | # Model
60 | self.model = nn.DataParallel(model)
61 | # Criterion
62 | if loss_weight:
63 | self.criterion = nn.CrossEntropyLoss(weight=self.train_loader.dataset.weight)
64 | else:
65 | self.criterion = nn.CrossEntropyLoss()
66 | # Params and optimizer
67 | params = self.model.parameters()
68 | self.lr = lr
69 | if opt == 'sgd':
70 | self.optimizer = optim.SGD(params, lr, weight_decay=weight_decay)
71 | elif opt == 'adam':
72 | self.optimizer = optim.Adam(params, lr, weight_decay=weight_decay)
73 | elif opt == 'adamw':
74 | from transformers import AdamW
75 | params = list(self.named_parameters())
76 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
77 | grouped_params = [
78 | {
79 | 'params': [p for n, p in params if not any(nd in n for nd in no_decay)],
80 | 'weight_decay': 0.01,
81 | 'lr': lr,
82 | 'ori_lr': lr
83 | },
84 | {
85 | 'params': [p for n, p in params if any(nd in n for nd in no_decay)],
86 | 'weight_decay': 0.0,
87 | 'lr': lr,
88 | 'ori_lr': lr
89 | }
90 | ]
91 | self.optimizer = AdamW(grouped_params, correct_bias=False)
92 | else:
93 |             raise Exception("Invalid optimizer. Must be 'sgd', 'adam' or 'adamw'.")
94 | # Cuda
95 | if torch.cuda.is_available():
96 | self.cuda()
97 | # Ckpt
98 | self.ckpt = ckpt
99 |
100 | def train_model(self, metric='auc'):
101 | best_metric = 0
102 | for epoch in range(self.max_epoch):
103 | # Train
104 | self.train()
105 | print("=== Epoch %d train ===" % epoch)
106 | avg_loss = AverageMeter()
107 | avg_acc = AverageMeter()
108 | avg_pos_acc = AverageMeter()
109 | t = tqdm(self.train_loader)
110 | for iter, data in enumerate(t):
111 | if torch.cuda.is_available():
112 | for i in range(len(data)):
113 | try:
114 | data[i] = data[i].cuda()
115 | except:
116 | pass
117 | label = data[0]
118 | bag_name = data[1]
119 | scope = data[2]
120 | args = data[3:]
121 | logits = self.model(label, scope, *args, bag_size=self.bag_size)
122 | loss = self.criterion(logits, label)
123 | score, pred = logits.max(-1) # (B)
124 | acc = float((pred == label).long().sum()) / label.size(0)
125 | pos_total = (label != 0).long().sum()
126 | pos_correct = ((pred == label).long() * (label != 0).long()).sum()
127 | if pos_total > 0:
128 | pos_acc = float(pos_correct) / float(pos_total)
129 | else:
130 | pos_acc = 0
131 |
132 | # Log
133 | avg_loss.update(loss.item(), 1)
134 | avg_acc.update(acc, 1)
135 | avg_pos_acc.update(pos_acc, 1)
136 | t.set_postfix(loss=avg_loss.avg, acc=avg_acc.avg, pos_acc=avg_pos_acc.avg)
137 |
138 | # Optimize
139 | loss.backward()
140 | self.optimizer.step()
141 | self.optimizer.zero_grad()
142 |
143 | # Val
144 | print("=== Epoch %d val ===" % epoch)
145 | result = self.eval_model(self.val_loader)
146 | print("AUC: %.4f" % result['auc'])
147 | print("Micro F1: %.4f" % (result['max_micro_f1']))
148 | if result[metric] > best_metric:
149 | print("Best ckpt and saved.")
150 | torch.save({'state_dict': self.model.module.state_dict()}, self.ckpt)
151 | best_metric = result[metric]
152 | print("Best %s on val set: %f" % (metric, best_metric))
153 |
154 | def eval_model(self, eval_loader):
155 | self.model.eval()
156 | with torch.no_grad():
157 | t = tqdm(eval_loader)
158 | pred_result = []
159 | for iter, data in enumerate(t):
160 | if torch.cuda.is_available():
161 | for i in range(len(data)):
162 | try:
163 | data[i] = data[i].cuda()
164 | except:
165 | pass
166 | label = data[0]
167 | bag_name = data[1]
168 | scope = data[2]
169 | args = data[3:]
170 | logits = self.model(None, scope, *args, train=False, bag_size=self.bag_size) # results after softmax
171 | logits = logits.cpu().numpy()
172 | for i in range(len(logits)):
173 | for relid in range(self.model.module.num_class):
174 | if self.model.module.id2rel[relid] != 'NA':
175 | pred_result.append({
176 | 'entpair': bag_name[i][:2],
177 | 'relation': self.model.module.id2rel[relid],
178 | 'score': logits[i][relid]
179 | })
180 | result = eval_loader.dataset.eval(pred_result)
181 | return result
182 |
183 | def load_state_dict(self, state_dict):
184 | self.model.module.load_state_dict(state_dict)
185 |
--------------------------------------------------------------------------------
/opennre/framework/data_loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 | import os, random, json, logging
4 | import numpy as np
5 | import sklearn.metrics
6 |
7 | class SentenceREDataset(data.Dataset):
8 | """
9 | Sentence-level relation extraction dataset
10 | """
11 | def __init__(self, path, rel2id, tokenizer, kwargs):
12 | """
13 | Args:
14 | path: path of the input file
15 | rel2id: dictionary of relation->id mapping
16 | tokenizer: function of tokenizing
17 | """
18 | super().__init__()
19 | self.path = path
20 | self.tokenizer = tokenizer
21 | self.rel2id = rel2id
22 | self.kwargs = kwargs
23 |
24 | # Load the file
25 | f = open(path)
26 | self.data = []
27 | for line in f.readlines():
28 | line = line.rstrip()
29 | if len(line) > 0:
30 |                 self.data.append(json.loads(line))  # each line is one JSON instance
31 | f.close()
32 | logging.info("Loaded sentence RE dataset {} with {} lines and {} relations.".format(path, len(self.data), len(self.rel2id)))
33 |
34 | def __len__(self):
35 | return len(self.data)
36 |
37 | def __getitem__(self, index):
38 | item = self.data[index]
39 | seq = list(self.tokenizer(item, **self.kwargs))
40 |         res = [self.rel2id[item['relation']]] + seq
41 |         return res  # label, seq1, seq2, ...
42 |
43 | def collate_fn(data):
44 | data = list(zip(*data))
45 | labels = data[0]
46 | seqs = data[1:]
47 | batch_labels = torch.tensor(labels).long() # (B)
48 | batch_seqs = []
49 | for seq in seqs:
50 | batch_seqs.append(torch.cat(seq, 0)) # (B, L)
51 | return [batch_labels] + batch_seqs
52 |
53 | def eval(self, pred_result, use_name=False):
54 | """
55 | Args:
56 |             pred_result: a list of predicted labels (ids)
57 | Make sure that the `shuffle` param is set to `False` when getting the loader.
58 | use_name: if True, `pred_result` contains predicted relation names instead of ids
59 | Return:
60 |             {'acc': xx, 'micro_p': xx, 'micro_r': xx, 'micro_f1': xx}
61 | """
62 | correct = 0
63 | total = len(self.data)
64 | correct_positive = 0
65 | pred_positive = 0
66 | gold_positive = 0
67 | neg = -1
68 | for name in ['NA', 'na', 'no_relation', 'Other', 'Others']:
69 | if name in self.rel2id:
70 | if use_name:
71 | neg = name
72 | else:
73 | neg = self.rel2id[name]
74 | break
75 | for i in range(total):
76 | if use_name:
77 | golden = self.data[i]['relation']
78 | else:
79 | golden = self.rel2id[self.data[i]['relation']]
80 | if golden == pred_result[i]:
81 | correct += 1
82 | if golden != neg:
83 | correct_positive += 1
84 | if golden != neg:
85 | gold_positive +=1
86 | if pred_result[i] != neg:
87 | pred_positive += 1
88 | acc = float(correct) / float(total)
89 | try:
90 | micro_p = float(correct_positive) / float(pred_positive)
91 | except:
92 | micro_p = 0
93 | try:
94 | micro_r = float(correct_positive) / float(gold_positive)
95 | except:
96 | micro_r = 0
97 | try:
98 | micro_f1 = 2 * micro_p * micro_r / (micro_p + micro_r)
99 | except:
100 | micro_f1 = 0
101 | result = {'acc': acc, 'micro_p': micro_p, 'micro_r': micro_r, 'micro_f1': micro_f1}
102 | logging.info('Evaluation result: {}.'.format(result))
103 | return result
104 |
105 | def SentenceRELoader(path, rel2id, tokenizer, batch_size,
106 | shuffle, num_workers=8, collate_fn=SentenceREDataset.collate_fn, **kwargs):
107 | dataset = SentenceREDataset(path = path, rel2id = rel2id, tokenizer = tokenizer, kwargs=kwargs)
108 | data_loader = data.DataLoader(dataset=dataset,
109 | batch_size=batch_size,
110 | shuffle=shuffle,
111 | pin_memory=True,
112 | num_workers=num_workers,
113 | collate_fn=collate_fn)
114 | return data_loader
115 |
116 | class BagREDataset(data.Dataset):
117 | """
118 |     Bag-level relation extraction dataset. Note that the no-relation (negative) class should be named 'NA'.
119 | """
120 | def __init__(self, path, rel2id, tokenizer, entpair_as_bag=False, bag_size=0, mode=None):
121 | """
122 | Args:
123 | path: path of the input file
124 | rel2id: dictionary of relation->id mapping
125 | tokenizer: function of tokenizing
126 | entpair_as_bag: if True, bags are constructed based on same
127 | entity pairs instead of same relation facts (ignoring
128 | relation labels)
129 | """
130 | super().__init__()
131 | self.tokenizer = tokenizer
132 | self.rel2id = rel2id
133 | self.entpair_as_bag = entpair_as_bag
134 | self.bag_size = bag_size
135 |
136 | # Load the file
137 | f = open(path)
138 | self.data = []
139 | for line in f:
140 | line = line.rstrip()
141 | if len(line) > 0:
142 |                 self.data.append(json.loads(line))  # each line is one JSON instance
143 | f.close()
144 |
145 | # Construct bag-level dataset (a bag contains instances sharing the same relation fact)
146 | if mode == None:
147 | self.weight = np.ones((len(self.rel2id)), dtype=np.float32)
148 | self.bag_scope = []
149 | self.name2id = {}
150 | self.bag_name = []
151 | self.facts = {}
152 | for idx, item in enumerate(self.data):
153 | # Annotated test set
154 | if 'anno_relation_list' in item:
155 | for r in item['anno_relation_list']:
156 | fact = (item['h']['id'], item['t']['id'], r)
157 | if r != 'NA':
158 | self.facts[fact] = 1
159 | assert entpair_as_bag
160 | name = (item['h']['id'], item['t']['id'])
161 | else:
162 | fact = (item['h']['id'], item['t']['id'], item['relation'])
163 | if item['relation'] != 'NA':
164 | self.facts[fact] = 1
165 | if entpair_as_bag:
166 | name = (item['h']['id'], item['t']['id'])
167 | else:
168 | name = fact
169 | if name not in self.name2id:
170 | self.name2id[name] = len(self.name2id)
171 | self.bag_scope.append([])
172 | self.bag_name.append(name)
173 | self.bag_scope[self.name2id[name]].append(idx)
174 | self.weight[self.rel2id[item['relation']]] += 1.0
175 | self.weight = 1.0 / (self.weight ** 0.05)
176 | self.weight = torch.from_numpy(self.weight)
177 | else:
178 | pass
179 |
180 | def __len__(self):
181 | return len(self.bag_scope)
182 |
183 | def __getitem__(self, index):
184 | bag = self.bag_scope[index]
185 | if self.bag_size > 0:
186 | if self.bag_size <= len(bag):
187 | resize_bag = random.sample(bag, self.bag_size)
188 | else:
189 | resize_bag = bag + list(np.random.choice(bag, self.bag_size - len(bag)))
190 | bag = resize_bag
191 |
192 | seqs = None
193 | rel = self.rel2id[self.data[bag[0]]['relation']]
194 | for sent_id in bag:
195 | item = self.data[sent_id]
196 | seq = list(self.tokenizer(item))
197 | if seqs is None:
198 | seqs = []
199 | for i in range(len(seq)):
200 | seqs.append([])
201 | for i in range(len(seq)):
202 | seqs[i].append(seq[i])
203 | for i in range(len(seqs)):
204 | seqs[i] = torch.cat(seqs[i], 0) # (n, L), n is the size of bag
205 | return [rel, self.bag_name[index], len(bag)] + seqs
206 |
207 | def collate_fn(data):
208 | data = list(zip(*data))
209 | label, bag_name, count = data[:3]
210 | seqs = data[3:]
211 | for i in range(len(seqs)):
212 | seqs[i] = torch.cat(seqs[i], 0) # (sumn, L)
213 | seqs[i] = seqs[i].expand((torch.cuda.device_count() if torch.cuda.device_count() > 0 else 1, ) + seqs[i].size())
214 | scope = [] # (B, 2)
215 | start = 0
216 | for c in count:
217 | scope.append((start, start + c))
218 | start += c
219 | assert(start == seqs[0].size(1))
220 | scope = torch.tensor(scope).long()
221 | label = torch.tensor(label).long() # (B)
222 | return [label, bag_name, scope] + seqs
223 |
224 | def collate_bag_size_fn(data):
225 | data = list(zip(*data))
226 | label, bag_name, count = data[:3]
227 | seqs = data[3:]
228 | for i in range(len(seqs)):
229 | seqs[i] = torch.stack(seqs[i], 0) # (batch, bag, L)
230 | scope = [] # (B, 2)
231 | start = 0
232 | for c in count:
233 | scope.append((start, start + c))
234 | start += c
235 | label = torch.tensor(label).long() # (B)
236 | return [label, bag_name, scope] + seqs
237 |
238 |
239 | def eval(self, pred_result, threshold=0.5):
240 | """
241 | Args:
242 | pred_result: a list with dict {'entpair': (head_id, tail_id), 'relation': rel, 'score': score}.
243 | Note that relation of NA should be excluded.
244 | Return:
245 |             {'np_prec': ndarray, 'np_rec': ndarray, 'auc': xx, 'max_micro_f1': xx, ...}
246 |             np_prec (precision) and np_rec (recall) are in micro style and are
247 |             sorted in the decreasing order of the prediction score.
248 |             max_micro_f1 is the max F1 score over those precision-recall points
249 | """
250 | sorted_pred_result = sorted(pred_result, key=lambda x: x['score'], reverse=True)
251 | prec = []
252 | rec = []
253 | correct = 0
254 | total = len(self.facts)
255 |
256 | entpair = {}
257 |
258 | for i, item in enumerate(sorted_pred_result):
259 | # Save entpair label and result for later calculating F1
260 | idtf = item['entpair'][0] + '#' + item['entpair'][1]
261 | if idtf not in entpair:
262 | entpair[idtf] = {
263 |                     'label': np.zeros((len(self.rel2id)), dtype=int),
264 |                     'pred': np.zeros((len(self.rel2id)), dtype=int),
265 |                     'score': np.zeros((len(self.rel2id)), dtype=float)
266 | }
267 | if (item['entpair'][0], item['entpair'][1], item['relation']) in self.facts:
268 | correct += 1
269 | entpair[idtf]['label'][self.rel2id[item['relation']]] = 1
270 | if item['score'] >= threshold:
271 | entpair[idtf]['pred'][self.rel2id[item['relation']]] = 1
272 | entpair[idtf]['score'][self.rel2id[item['relation']]] = item['score']
273 |
274 | prec.append(float(correct) / float(i + 1))
275 | rec.append(float(correct) / float(total))
276 |
277 | auc = sklearn.metrics.auc(x=rec, y=prec)
278 | np_prec = np.array(prec)
279 | np_rec = np.array(rec)
280 | max_micro_f1 = (2 * np_prec * np_rec / (np_prec + np_rec + 1e-20)).max()
281 | best_threshold = sorted_pred_result[(2 * np_prec * np_rec / (np_prec + np_rec + 1e-20)).argmax()]['score']
282 | mean_prec = np_prec.mean()
283 |
284 | label_vec = []
285 | pred_result_vec = []
286 | score_vec = []
287 | for ep in entpair:
288 | label_vec.append(entpair[ep]['label'])
289 | pred_result_vec.append(entpair[ep]['pred'])
290 | score_vec.append(entpair[ep]['score'])
291 | label_vec = np.stack(label_vec, 0)
292 | pred_result_vec = np.stack(pred_result_vec, 0)
293 | score_vec = np.stack(score_vec, 0)
294 |
295 | micro_p = sklearn.metrics.precision_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='micro')
296 | micro_r = sklearn.metrics.recall_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='micro')
297 | micro_f1 = sklearn.metrics.f1_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='micro')
298 |
299 | macro_p = sklearn.metrics.precision_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro')
300 | macro_r = sklearn.metrics.recall_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro')
301 | macro_f1 = sklearn.metrics.f1_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro')
302 |
303 | pred_result_vec = score_vec >= best_threshold
304 | max_macro_f1 = sklearn.metrics.f1_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro')
305 | max_micro_f1_each_relation = {}
306 | for rel in self.rel2id:
307 | if rel != 'NA':
308 | max_micro_f1_each_relation[rel] = sklearn.metrics.f1_score(label_vec, pred_result_vec, labels=[self.rel2id[rel]], average='micro')
309 |
310 | return {'np_prec': np_prec, 'np_rec': np_rec, 'max_micro_f1': max_micro_f1, 'max_macro_f1': max_macro_f1, 'auc': auc, 'p@100': np_prec[99], 'p@200': np_prec[199], 'p@300': np_prec[299], 'avg_p300': (np_prec[99] + np_prec[199] + np_prec[299]) / 3, 'micro_f1': micro_f1, 'macro_f1': macro_f1, 'max_micro_f1_each_relation': max_micro_f1_each_relation}
311 |
312 | def BagRELoader(path, rel2id, tokenizer, batch_size,
313 | shuffle, entpair_as_bag=False, bag_size=0, num_workers=8,
314 | collate_fn=BagREDataset.collate_fn):
315 | if bag_size == 0:
316 | collate_fn = BagREDataset.collate_fn
317 | else:
318 | collate_fn = BagREDataset.collate_bag_size_fn
319 | dataset = BagREDataset(path, rel2id, tokenizer, entpair_as_bag=entpair_as_bag, bag_size=bag_size)
320 | data_loader = data.DataLoader(dataset=dataset,
321 | batch_size=batch_size,
322 | shuffle=shuffle,
323 | pin_memory=True,
324 | num_workers=num_workers,
325 | collate_fn=collate_fn)
326 | return data_loader
327 |
328 |
329 | class MultiLabelSentenceREDataset(data.Dataset):
330 | """
331 |     Multi-label sentence-level relation extraction dataset
332 | """
333 | def __init__(self, path, rel2id, tokenizer, kwargs):
334 | """
335 | Args:
336 | path: path of the input file
337 | rel2id: dictionary of relation->id mapping
338 | tokenizer: function of tokenizing
339 | """
340 | super().__init__()
341 | self.path = path
342 | self.tokenizer = tokenizer
343 | self.rel2id = rel2id
344 | self.kwargs = kwargs
345 |
346 | # Load the file
347 | f = open(path)
348 | self.data = []
349 | for line in f.readlines():
350 | line = line.rstrip()
351 | if len(line) > 0:
352 |                 self.data.append(json.loads(line))  # each line is one JSON instance
353 | f.close()
354 | logging.info("Loaded sentence RE dataset {} with {} lines and {} relations.".format(path, len(self.data), len(self.rel2id)))
355 |
356 | def __len__(self):
357 | return len(self.data)
358 |
359 | def __getitem__(self, index):
360 | item = self.data[index]
361 | seq = list(self.tokenizer(item, **self.kwargs))
362 |         res = [self.rel2id[item['relation']]] + seq
363 |         return res  # label, seq1, seq2, ...
364 |
365 | def collate_fn(data):
366 | data = list(zip(*data))
367 | labels = data[0]
368 | seqs = data[1:]
369 | batch_labels = torch.tensor(labels).long() # (B)
370 | batch_seqs = []
371 | for seq in seqs:
372 | batch_seqs.append(torch.cat(seq, 0)) # (B, L)
373 | return [batch_labels] + batch_seqs
374 |
375 | def eval(self, pred_score, threshold=0.5, use_name=False):
376 | """
377 | Args:
378 | pred_score: [sent_num, label_num]
379 | use_name: if True, `pred_result` contains predicted relation names instead of ids
380 | Return:
381 |             {'acc': xx, 'micro_f1': xx, 'macro_f1': xx, 'auc': xx, ...}
382 | """
383 | assert len(self.data) == len(pred_score)
384 | pred_score = np.array(pred_score)
385 |
386 | # Calculate AUC
387 | sorted_result = []
388 | total = 0
389 | for sent_id in range(len(self.data)):
390 | for rel in self.rel2id:
391 | if rel not in ['NA', 'na', 'N/A', 'None', 'none', 'n/a', 'no_relation']:
392 | sorted_result.append({'sent_id': sent_id, 'relation': rel, 'score': pred_score[sent_id][self.rel2id[rel]]})
393 | if 'anno_relation_list' in self.data[sent_id]:
394 | if rel in self.data[sent_id]['anno_relation_list']:
395 | total += 1
396 | else:
397 | if rel == self.data[sent_id]['relation']:
398 | total += 1
399 |
400 | sorted_result.sort(key=lambda x: x['score'], reverse=True)
401 | prec = []
402 | rec = []
403 | correct = 0
404 | for i, item in enumerate(sorted_result):
405 | if 'anno_relation_list' in self.data[item['sent_id']]:
406 | if item['relation'] in self.data[item['sent_id']]['anno_relation_list']:
407 | correct += 1
408 | else:
409 | if item['relation'] == self.data[item['sent_id']]['relation']:
410 | correct += 1
411 | prec.append(float(correct) / float(i + 1))
412 | rec.append(float(correct) / float(total))
413 | auc = sklearn.metrics.auc(x=rec, y=prec)
414 | np_prec = np.array(prec)
415 | np_rec = np.array(rec)
416 | max_micro_f1 = (2 * np_prec * np_rec / (np_prec + np_rec + 1e-20)).max()
417 | max_micro_f1_threshold = sorted_result[(2 * np_prec * np_rec / (np_prec + np_rec + 1e-20)).argmax()]['score']
418 | mean_prec = np_prec.mean()
419 |
420 | # Calculate F1
421 |         pred_result_vec = np.zeros((len(self.data), len(self.rel2id)), dtype=int)
422 | pred_result_vec[pred_score >= threshold] = 1
423 | label_vec = []
424 | for item in self.data:
425 | if 'anno_relation_list' in item:
426 |                 label_vec.append(np.array(item['anno_relation_vec'], dtype=int))
427 |             else:
428 |                 one_hot = np.zeros((len(self.rel2id)), dtype=int)
429 | one_hot[self.rel2id[item['relation']]] = 1
430 | label_vec.append(one_hot)
431 | label_vec = np.stack(label_vec, 0)
432 | assert label_vec.shape == pred_result_vec.shape
433 |
434 | micro_p = sklearn.metrics.precision_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='micro')
435 | micro_r = sklearn.metrics.recall_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='micro')
436 | micro_f1 = sklearn.metrics.f1_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='micro')
437 |
438 | macro_p = sklearn.metrics.precision_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro')
439 | macro_r = sklearn.metrics.recall_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro')
440 | macro_f1 = sklearn.metrics.f1_score(label_vec, pred_result_vec, labels=list(range(1, len(self.rel2id))), average='macro')
441 |
442 | acc = (label_vec == pred_result_vec).mean()
443 |
444 | result = {'acc': acc, 'micro_p': micro_p, 'micro_r': micro_r, 'micro_f1': micro_f1, 'macro_p': macro_p, 'macro_r': macro_r, 'macro_f1': macro_f1, 'np_prec': np_prec, 'np_rec': np_rec, 'max_micro_f1': max_micro_f1, 'max_micro_f1_threshold': max_micro_f1_threshold, 'auc': auc, 'p@100': np_prec[99], 'p@200': np_prec[199], 'p@300': np_prec[299]}
445 | logging.info('Evaluation result: {}.'.format(result))
446 | return result
447 |
448 | def MultiLabelSentenceRELoader(path, rel2id, tokenizer, batch_size,
449 |         shuffle, num_workers=8, collate_fn=MultiLabelSentenceREDataset.collate_fn, **kwargs):
450 |     dataset = MultiLabelSentenceREDataset(path=path, rel2id=rel2id, tokenizer=tokenizer, kwargs=kwargs)
451 | data_loader = data.DataLoader(dataset=dataset,
452 | batch_size=batch_size,
453 | shuffle=shuffle,
454 | pin_memory=True,
455 | num_workers=num_workers,
456 | collate_fn=collate_fn)
457 | return data_loader
458 |
459 |
460 |
--------------------------------------------------------------------------------
/opennre/framework/multi_label_sentence_re.py:
--------------------------------------------------------------------------------
1 | import os, logging, json
2 | from tqdm import tqdm
3 | import torch
4 | from torch import nn, optim
5 | from .data_loader import MultiLabelSentenceRELoader
6 | from .utils import AverageMeter
7 | import numpy as np
8 |
9 | class MultiLabelSentenceRE(nn.Module):
10 |
11 | def __init__(self,
12 | model,
13 | train_path,
14 | val_path,
15 | test_path,
16 | ckpt,
17 | batch_size=32,
18 | max_epoch=100,
19 | lr=0.1,
20 | weight_decay=1e-5,
21 | warmup_step=300,
22 | opt='sgd'):
23 |
24 | super().__init__()
25 | self.max_epoch = max_epoch
26 | # Load data
27 | if train_path != None:
28 | self.train_loader = MultiLabelSentenceRELoader(
29 | train_path,
30 | model.rel2id,
31 | model.sentence_encoder.tokenize,
32 | batch_size,
33 | True)
34 |
35 | if val_path != None:
36 | self.val_loader = MultiLabelSentenceRELoader(
37 | val_path,
38 | model.rel2id,
39 | model.sentence_encoder.tokenize,
40 | batch_size,
41 | False)
42 |
43 | if test_path != None:
44 | self.test_loader = MultiLabelSentenceRELoader(
45 | test_path,
46 | model.rel2id,
47 | model.sentence_encoder.tokenize,
48 | batch_size,
49 | False
50 | )
51 | # Model
52 | self.model = model
53 | self.parallel_model = nn.DataParallel(self.model)
54 | # Criterion
55 | self.criterion = nn.BCEWithLogitsLoss()
56 | # Params and optimizer
57 | params = self.parameters()
58 | self.lr = lr
59 | if opt == 'sgd':
60 | self.optimizer = optim.SGD(params, lr, weight_decay=weight_decay)
61 | elif opt == 'adam':
62 | self.optimizer = optim.Adam(params, lr, weight_decay=weight_decay)
63 | elif opt == 'adamw': # Optimizer for BERT
64 | from transformers import AdamW
65 | params = list(self.named_parameters())
66 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
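  |             # Exclude bias and LayerNorm parameters from weight decay, the usual grouping for BERT fine-tuning.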
67 | grouped_params = [
68 | {
69 | 'params': [p for n, p in params if not any(nd in n for nd in no_decay)],
70 | 'weight_decay': 0.01,
71 | 'lr': lr,
72 | 'ori_lr': lr
73 | },
74 | {
75 | 'params': [p for n, p in params if any(nd in n for nd in no_decay)],
76 | 'weight_decay': 0.0,
77 | 'lr': lr,
78 | 'ori_lr': lr
79 | }
80 | ]
81 | self.optimizer = AdamW(grouped_params, correct_bias=False)
82 | else:
83 | raise Exception("Invalid optimizer. Must be 'sgd' or 'adam' or 'adamw'.")
84 | # Warmup
85 | if warmup_step > 0:
86 | from transformers import get_linear_schedule_with_warmup
87 |             training_steps = len(self.train_loader.dataset) // batch_size * self.max_epoch
88 | self.scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=warmup_step, num_training_steps=training_steps)
89 | else:
90 | self.scheduler = None
91 | # Cuda
92 | if torch.cuda.is_available():
93 | self.cuda()
94 | # Ckpt
95 | self.ckpt = ckpt
96 |
97 | def train_model(self, metric='acc'):
98 | best_metric = 0
99 | global_step = 0
100 | for epoch in range(self.max_epoch):
101 | self.train()
102 | logging.info("=== Epoch %d train ===" % epoch)
103 | avg_loss = AverageMeter()
104 | avg_acc = AverageMeter()
105 | t = tqdm(self.train_loader)
106 | for iter, data in enumerate(t):
107 | if torch.cuda.is_available():
108 | for i in range(len(data)):
109 | try:
110 | data[i] = data[i].cuda()
111 | except:
112 | pass
113 | label = data[0]
114 | args = data[1:]
115 | logits = self.parallel_model(*args)
116 |
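  |                 # Turn the gold class index into a one-hot target and drop the NA column (index 0) so the BCE loss only covers real relations.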
117 |                 label_vec = torch.zeros_like(logits)
118 | label_vec[torch.arange(label_vec.size(0)), label] = 1
119 | label_vec = label_vec[:, 1:]
120 | logits = logits[:, 1:]
121 |
122 | loss = self.criterion(logits.reshape(-1), label_vec.reshape(-1))
123 | pred = (torch.sigmoid(logits) >= 0.5).long()
124 | acc = float((pred == label_vec).long().sum()) / (label_vec.size(0) * label_vec.size(1))
125 |
126 | # Log
127 | avg_loss.update(loss.item(), 1)
128 | avg_acc.update(acc, 1)
129 | t.set_postfix(loss=avg_loss.avg, acc=avg_acc.avg)
130 | # Optimize
131 | loss.backward()
132 | self.optimizer.step()
133 | if self.scheduler is not None:
134 | self.scheduler.step()
135 | self.optimizer.zero_grad()
136 | global_step += 1
137 | # Val
138 | logging.info("=== Epoch %d val ===" % epoch)
139 | result = self.eval_model(self.val_loader)
140 | logging.info('Metric {} current / best: {} / {}'.format(metric, result[metric], best_metric))
141 | if result[metric] > best_metric:
142 | logging.info("Best ckpt and saved.")
143 |                     folder_path = os.path.dirname(self.ckpt)
144 |                     if not os.path.exists(folder_path):
145 |                         os.makedirs(folder_path)
146 | torch.save({'state_dict': self.model.state_dict()}, self.ckpt)
147 | best_metric = result[metric]
148 | logging.info("Best %s on val set: %f" % (metric, best_metric))
149 |
150 | def eval_model(self, eval_loader):
151 | self.eval()
152 | pred_score = []
153 | with torch.no_grad():
154 | t = tqdm(eval_loader)
155 | for iter, data in enumerate(t):
156 | if torch.cuda.is_available():
157 | for i in range(len(data)):
158 | try:
159 | data[i] = data[i].cuda()
160 | except:
161 | pass
162 | label = data[0]
163 | args = data[1:]
164 | logits = self.parallel_model(*args)
165 | score = self.parallel_model.module.logit_to_score(logits).cpu().numpy()
166 | # Save result
167 | pred_score.append(score)
168 | # Log
169 | pred_score = np.concatenate(pred_score, 0)
170 | result = eval_loader.dataset.eval(pred_score)
171 | return result
172 |
173 | def load_state_dict(self, state_dict):
174 | self.model.load_state_dict(state_dict)
175 |
176 |
--------------------------------------------------------------------------------
/opennre/framework/sentence_re.py:
--------------------------------------------------------------------------------
1 | import os, logging, json
2 | from tqdm import tqdm
3 | import torch
4 | from torch import nn, optim
5 | from .data_loader import SentenceRELoader
6 | from .utils import AverageMeter
7 |
8 | class SentenceRE(nn.Module):
9 |
10 | def __init__(self,
11 | model,
12 | train_path,
13 | val_path,
14 | test_path,
15 | ckpt,
16 | batch_size=32,
17 | max_epoch=100,
18 | lr=0.1,
19 | weight_decay=1e-5,
20 | warmup_step=300,
21 | opt='sgd'):
22 |
23 | super().__init__()
24 | self.max_epoch = max_epoch
25 | # Load data
26 | if train_path != None:
27 | self.train_loader = SentenceRELoader(
28 | train_path,
29 | model.rel2id,
30 | model.sentence_encoder.tokenize,
31 | batch_size,
32 | True)
33 |
34 | if val_path != None:
35 | self.val_loader = SentenceRELoader(
36 | val_path,
37 | model.rel2id,
38 | model.sentence_encoder.tokenize,
39 | batch_size,
40 | False)
41 |
42 | if test_path != None:
43 | self.test_loader = SentenceRELoader(
44 | test_path,
45 | model.rel2id,
46 | model.sentence_encoder.tokenize,
47 | batch_size,
48 | False
49 | )
50 | # Model
51 | self.model = model
52 | self.parallel_model = nn.DataParallel(self.model)
53 | # Criterion
54 | self.criterion = nn.CrossEntropyLoss()
55 | # Params and optimizer
56 | params = self.parameters()
57 | self.lr = lr
58 | if opt == 'sgd':
59 | self.optimizer = optim.SGD(params, lr, weight_decay=weight_decay)
60 | elif opt == 'adam':
61 | self.optimizer = optim.Adam(params, lr, weight_decay=weight_decay)
62 | elif opt == 'adamw': # Optimizer for BERT
63 | from transformers import AdamW
64 | params = list(self.named_parameters())
65 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
66 | grouped_params = [
67 | {
68 | 'params': [p for n, p in params if not any(nd in n for nd in no_decay)],
69 | 'weight_decay': 0.01,
70 | 'lr': lr,
71 | 'ori_lr': lr
72 | },
73 | {
74 | 'params': [p for n, p in params if any(nd in n for nd in no_decay)],
75 | 'weight_decay': 0.0,
76 | 'lr': lr,
77 | 'ori_lr': lr
78 | }
79 | ]
80 | self.optimizer = AdamW(grouped_params, correct_bias=False)
81 | else:
82 | raise Exception("Invalid optimizer. Must be 'sgd' or 'adam' or 'adamw'.")
83 | # Warmup
84 | if warmup_step > 0:
85 | from transformers import get_linear_schedule_with_warmup
86 |             training_steps = len(self.train_loader.dataset) // batch_size * self.max_epoch
87 | self.scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=warmup_step, num_training_steps=training_steps)
88 | else:
89 | self.scheduler = None
90 | # Cuda
91 | if torch.cuda.is_available():
92 | self.cuda()
93 | # Ckpt
94 | self.ckpt = ckpt
95 |
96 | def train_model(self, metric='acc'):
97 | best_metric = 0
98 | global_step = 0
99 | for epoch in range(self.max_epoch):
100 | self.train()
101 | logging.info("=== Epoch %d train ===" % epoch)
102 | avg_loss = AverageMeter()
103 | avg_acc = AverageMeter()
104 | t = tqdm(self.train_loader)
105 | for iter, data in enumerate(t):
106 | if torch.cuda.is_available():
107 | for i in range(len(data)):
108 | try:
109 | data[i] = data[i].cuda()
110 | except:
111 | pass
112 | label = data[0]
113 | args = data[1:]
114 | logits = self.parallel_model(*args)
115 | loss = self.criterion(logits, label)
116 | score, pred = logits.max(-1) # (B)
117 | acc = float((pred == label).long().sum()) / label.size(0)
118 | # Log
119 | avg_loss.update(loss.item(), 1)
120 | avg_acc.update(acc, 1)
121 | t.set_postfix(loss=avg_loss.avg, acc=avg_acc.avg)
122 | # Optimize
123 | loss.backward()
124 | self.optimizer.step()
125 | if self.scheduler is not None:
126 | self.scheduler.step()
127 | self.optimizer.zero_grad()
128 | global_step += 1
129 | # Val
130 | logging.info("=== Epoch %d val ===" % epoch)
131 | result = self.eval_model(self.val_loader)
132 | logging.info('Metric {} current / best: {} / {}'.format(metric, result[metric], best_metric))
133 | if result[metric] > best_metric:
134 | logging.info("Best ckpt and saved.")
135 |                     folder_path = os.path.dirname(self.ckpt)
136 |                     if not os.path.exists(folder_path):
137 |                         os.makedirs(folder_path)
138 | torch.save({'state_dict': self.model.state_dict()}, self.ckpt)
139 | best_metric = result[metric]
140 | logging.info("Best %s on val set: %f" % (metric, best_metric))
141 |
142 | def eval_model(self, eval_loader):
143 | self.eval()
144 | avg_acc = AverageMeter()
145 | pred_result = []
146 | with torch.no_grad():
147 | t = tqdm(eval_loader)
148 | for iter, data in enumerate(t):
149 | if torch.cuda.is_available():
150 | for i in range(len(data)):
151 | try:
152 | data[i] = data[i].cuda()
153 | except:
154 | pass
155 | label = data[0]
156 | args = data[1:]
157 | logits = self.parallel_model(*args)
158 | score, pred = logits.max(-1) # (B)
159 | # Save result
160 | for i in range(pred.size(0)):
161 | pred_result.append(pred[i].item())
162 | # Log
163 | acc = float((pred == label).long().sum()) / label.size(0)
164 | avg_acc.update(acc, pred.size(0))
165 | t.set_postfix(acc=avg_acc.avg)
166 | result = eval_loader.dataset.eval(pred_result)
167 | return result
168 |
169 | def load_state_dict(self, state_dict):
170 | self.model.load_state_dict(state_dict)
171 |
172 |
--------------------------------------------------------------------------------
/opennre/framework/utils.py:
--------------------------------------------------------------------------------
1 | class AverageMeter(object):
2 | """
3 | Computes and stores the average and current value of metrics.
4 | """
5 |
6 | def __init__(self):
7 | self.reset()
8 |
9 | def reset(self):
10 | self.val = 0
11 | self.avg = 0
12 | self.sum = 0
13 | self.count = 0
14 |
15 | def update(self, val, n=0):
16 | self.val = val
17 | self.sum += val * n
18 | self.count += n
19 | self.avg = self.sum / (.0001 + self.count)
20 |
21 | def __str__(self):
22 | """
23 | String representation for logging
24 | """
25 | # for values that should be recorded exactly e.g. iteration number
26 | if self.count == 0:
27 | return str(self.val)
28 | # for stats
29 | return '%.4f (%.4f)' % (self.val, self.avg)
30 |
--------------------------------------------------------------------------------
/opennre/model/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from .base_model import SentenceRE, BagRE, FewShotRE, NER
6 | from .softmax_nn import SoftmaxNN
7 | from .sigmoid_nn import SigmoidNN
8 | from .bag_attention import BagAttention
9 | from .bag_average import BagAverage
10 | from .bag_one import BagOne
11 |
12 | __all__ = [
13 | 'SentenceRE',
14 | 'BagRE',
15 | 'FewShotRE',
16 | 'NER',
17 |     'SoftmaxNN',
  |     'SigmoidNN',
18 | 'BagAttention',
19 | 'BagAverage',
20 | 'BagOne'
21 | ]
--------------------------------------------------------------------------------
/opennre/model/bag_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, optim
3 | from .base_model import BagRE
4 |
5 | class BagAttention(BagRE):
6 | """
7 | Instance attention for bag-level relation extraction.
8 | """
9 |
10 | def __init__(self, sentence_encoder, num_class, rel2id, use_diag=True):
11 | """
12 | Args:
13 | sentence_encoder: encoder for sentences
14 | num_class: number of classes
15 |             rel2id: dictionary of relation name -> id mapping
16 | """
17 | super().__init__()
18 | self.sentence_encoder = sentence_encoder
19 | self.num_class = num_class
20 | self.fc = nn.Linear(self.sentence_encoder.hidden_size, num_class)
21 | self.softmax = nn.Softmax(-1)
22 | self.rel2id = rel2id
23 | self.id2rel = {}
24 | self.drop = nn.Dropout()
25 | for rel, id in rel2id.items():
26 | self.id2rel[id] = rel
27 | if use_diag:
28 | self.use_diag = True
29 | self.diag = nn.Parameter(torch.ones(self.sentence_encoder.hidden_size))
30 | else:
31 | self.use_diag = False
32 |
33 | def infer(self, bag):
34 | """
35 | Args:
36 | bag: bag of sentences with the same entity pair
37 | [{
38 | 'text' or 'token': ...,
39 | 'h': {'pos': [start, end], ...},
40 | 't': {'pos': [start, end], ...}
41 | }]
42 | Return:
43 | (relation, score)
44 | """
45 | self.eval()
46 | tokens = []
47 | pos1s = []
48 | pos2s = []
49 | masks = []
50 | for item in bag:
51 | token, pos1, pos2, mask = self.sentence_encoder.tokenize(item)
52 | tokens.append(token)
53 | pos1s.append(pos1)
54 | pos2s.append(pos2)
55 | masks.append(mask)
56 | tokens = torch.cat(tokens, 0).unsqueeze(0) # (n, L)
57 | pos1s = torch.cat(pos1s, 0).unsqueeze(0)
58 | pos2s = torch.cat(pos2s, 0).unsqueeze(0)
59 | masks = torch.cat(masks, 0).unsqueeze(0)
60 | scope = torch.tensor([[0, len(bag)]]).long() # (1, 2)
61 | bag_logits = self.forward(None, scope, tokens, pos1s, pos2s, masks, train=False).squeeze(0) # (N) after softmax
62 | score, pred = bag_logits.max(0)
63 | score = score.item()
64 | pred = pred.item()
65 | rel = self.id2rel[pred]
66 | return (rel, score)
67 |
68 | def forward(self, label, scope, token, pos1, pos2, mask=None, train=True, bag_size=0):
69 | """
70 | Args:
71 | label: (B), label of the bag
72 |             scope: (B, 2), begin/end instance index of each bag
73 | token: (nsum, L), index of tokens
74 | pos1: (nsum, L), relative position to head entity
75 | pos2: (nsum, L), relative position to tail entity
76 | mask: (nsum, L), used for piece-wise CNN
77 | Return:
78 | logits, (B, N)
79 |
80 | Dirty hack:
81 | When the encoder is BERT, the input is actually token, att_mask, pos1, pos2, but
82 |             since the arguments are then fed into the BERT encoder in the original order,
83 |             the encoder still works correctly.
84 | """
85 | if bag_size > 0:
86 | token = token.view(-1, token.size(-1))
87 | pos1 = pos1.view(-1, pos1.size(-1))
88 | pos2 = pos2.view(-1, pos2.size(-1))
89 | if mask is not None:
90 | mask = mask.view(-1, mask.size(-1))
91 | else:
92 | begin, end = scope[0][0], scope[-1][1]
93 | token = token[:, begin:end, :].view(-1, token.size(-1))
94 | pos1 = pos1[:, begin:end, :].view(-1, pos1.size(-1))
95 | pos2 = pos2[:, begin:end, :].view(-1, pos2.size(-1))
96 | if mask is not None:
97 | mask = mask[:, begin:end, :].view(-1, mask.size(-1))
98 | scope = torch.sub(scope, torch.zeros_like(scope).fill_(begin))
99 |
100 | # Attention
101 | if train:
102 | if mask is not None:
103 | rep = self.sentence_encoder(token, pos1, pos2, mask) # (nsum, H)
104 | else:
105 | rep = self.sentence_encoder(token, pos1, pos2) # (nsum, H)
106 |
107 | if bag_size == 0:
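  |                 # Selective attention: during training, each instance is scored against its bag's gold-relation row of the fc weight, which serves as the attention query.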
108 | bag_rep = []
109 | query = torch.zeros((rep.size(0))).long()
110 | if torch.cuda.is_available():
111 | query = query.cuda()
112 | for i in range(len(scope)):
113 | query[scope[i][0]:scope[i][1]] = label[i]
114 | att_mat = self.fc.weight[query] # (nsum, H)
115 | if self.use_diag:
116 | att_mat = att_mat * self.diag.unsqueeze(0)
117 | att_score = (rep * att_mat).sum(-1) # (nsum)
118 |
119 | for i in range(len(scope)):
120 | bag_mat = rep[scope[i][0]:scope[i][1]] # (n, H)
121 | softmax_att_score = self.softmax(att_score[scope[i][0]:scope[i][1]]) # (n)
122 | bag_rep.append((softmax_att_score.unsqueeze(-1) * bag_mat).sum(0)) # (n, 1) * (n, H) -> (n, H) -> (H)
123 | bag_rep = torch.stack(bag_rep, 0) # (B, H)
124 | else:
125 | batch_size = label.size(0)
126 | query = label.unsqueeze(1) # (B, 1)
127 | att_mat = self.fc.weight[query] # (B, 1, H)
128 | if self.use_diag:
129 | att_mat = att_mat * self.diag.unsqueeze(0)
130 | rep = rep.view(batch_size, bag_size, -1)
131 | att_score = (rep * att_mat).sum(-1) # (B, bag)
132 | softmax_att_score = self.softmax(att_score) # (B, bag)
133 | bag_rep = (softmax_att_score.unsqueeze(-1) * rep).sum(1) # (B, bag, 1) * (B, bag, H) -> (B, bag, H) -> (B, H)
134 | bag_rep = self.drop(bag_rep)
135 | bag_logits = self.fc(bag_rep) # (B, N)
136 | else:
137 |
138 | if bag_size == 0:
139 | rep = []
140 | bs = 256
141 | total_bs = len(token) // bs + (1 if len(token) % bs != 0 else 0)
142 | for b in range(total_bs):
143 | with torch.no_grad():
144 | left = bs * b
145 | right = min(bs * (b + 1), len(token))
146 | if mask is not None:
147 | rep.append(self.sentence_encoder(token[left:right], pos1[left:right], pos2[left:right], mask[left:right]).detach()) # (nsum, H)
148 | else:
149 | rep.append(self.sentence_encoder(token[left:right], pos1[left:right], pos2[left:right]).detach()) # (nsum, H)
150 | rep = torch.cat(rep, 0)
151 |
152 | bag_logits = []
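  |                 # At test time the gold label is unknown, so attention is computed with every relation vector as the query and each relation keeps its own bag representation (diagonal of the softmaxed logits).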
153 | att_mat = self.fc.weight.transpose(0, 1)
154 | if self.use_diag:
155 | att_mat = att_mat * self.diag.unsqueeze(1)
156 | att_score = torch.matmul(rep, att_mat) # (nsum, H) * (H, N) -> (nsum, N)
157 | for i in range(len(scope)):
158 | bag_mat = rep[scope[i][0]:scope[i][1]] # (n, H)
159 | softmax_att_score = self.softmax(att_score[scope[i][0]:scope[i][1]].transpose(0, 1)) # (N, (softmax)n)
160 | rep_for_each_rel = torch.matmul(softmax_att_score, bag_mat) # (N, n) * (n, H) -> (N, H)
161 | logit_for_each_rel = self.softmax(self.fc(rep_for_each_rel)) # ((each rel)N, (logit)N)
162 | logit_for_each_rel = logit_for_each_rel.diag() # (N)
163 | bag_logits.append(logit_for_each_rel)
164 | bag_logits = torch.stack(bag_logits, 0) # after **softmax**
165 | else:
166 | if mask is not None:
167 | rep = self.sentence_encoder(token, pos1, pos2, mask) # (nsum, H)
168 | else:
169 | rep = self.sentence_encoder(token, pos1, pos2) # (nsum, H)
170 |
171 | batch_size = rep.size(0) // bag_size
172 | att_mat = self.fc.weight.transpose(0, 1)
173 | if self.use_diag:
174 | att_mat = att_mat * self.diag.unsqueeze(1)
175 | att_score = torch.matmul(rep, att_mat) # (nsum, H) * (H, N) -> (nsum, N)
176 | att_score = att_score.view(batch_size, bag_size, -1) # (B, bag, N)
177 | rep = rep.view(batch_size, bag_size, -1) # (B, bag, H)
178 | softmax_att_score = self.softmax(att_score.transpose(1, 2)) # (B, N, (softmax)bag)
179 | rep_for_each_rel = torch.matmul(softmax_att_score, rep) # (B, N, bag) * (B, bag, H) -> (B, N, H)
180 | bag_logits = self.softmax(self.fc(rep_for_each_rel)).diagonal(dim1=1, dim2=2) # (B, (each rel)N)
181 | return bag_logits
182 |
183 |
--------------------------------------------------------------------------------
/opennre/model/bag_average.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, optim
3 | from .base_model import BagRE
4 |
5 | class BagAverage(BagRE):
6 | """
7 | Average policy for bag-level relation extraction.
8 | """
9 |
10 | def __init__(self, sentence_encoder, num_class, rel2id):
11 | """
12 | Args:
13 | sentence_encoder: encoder for sentences
14 | num_class: number of classes
15 |             rel2id: dictionary of relation name -> id mapping
16 | """
17 | super().__init__()
18 | self.sentence_encoder = sentence_encoder
19 | self.num_class = num_class
20 | self.fc = nn.Linear(self.sentence_encoder.hidden_size, num_class)
21 | self.softmax = nn.Softmax(-1)
22 | self.rel2id = rel2id
23 | self.id2rel = {}
24 | self.drop = nn.Dropout()
25 | for rel, id in rel2id.items():
26 | self.id2rel[id] = rel
27 |
28 | def infer(self, bag):
29 | """
30 | Args:
31 | bag: bag of sentences with the same entity pair
32 | [{
33 | 'text' or 'token': ...,
34 | 'h': {'pos': [start, end], ...},
35 | 't': {'pos': [start, end], ...}
36 | }]
37 | Return:
38 | (relation, score)
39 | """
40 | pass
41 |
42 | """
43 | tokens = []
44 | pos1s = []
45 | pos2s = []
46 | masks = []
47 | for item in bag:
48 | if 'text' in item:
49 | token, pos1, pos2, mask = self.tokenizer(item['text'],
50 | item['h']['pos'], item['t']['pos'], is_token=False, padding=True)
51 | else:
52 | token, pos1, pos2, mask = self.tokenizer(item['token'],
53 | item['h']['pos'], item['t']['pos'], is_token=True, padding=True)
54 | tokens.append(token)
55 | pos1s.append(pos1)
56 | pos2s.append(pos2)
57 | masks.append(mask)
58 | tokens = torch.cat(tokens, 0) # (n, L)
59 | pos1s = torch.cat(pos1s, 0)
60 | pos2s = torch.cat(pos2s, 0)
61 | masks = torch.cat(masks, 0)
62 | scope = torch.tensor([[0, len(bag)]]).long() # (1, 2)
63 | bag_logits = self.forward(None, scope, tokens, pos1s, pos2s, masks, train=False).squeeze(0) # (N) after softmax
64 | score, pred = bag_logits.max()
65 | score = score.item()
66 | pred = pred.item()
67 | rel = self.id2rel[pred]
68 | return (rel, score)
69 | """
70 |
71 |     def forward(self, label, scope, token, pos1, pos2, mask=None, train=True, bag_size=0):
72 | """
73 | Args:
74 | label: (B), label of the bag
75 |             scope: (B, 2), begin/end instance index of each bag
76 | token: (nsum, L), index of tokens
77 | pos1: (nsum, L), relative position to head entity
78 | pos2: (nsum, L), relative position to tail entity
79 | mask: (nsum, L), used for piece-wise CNN
80 | Return:
81 | logits, (B, N)
82 | """
83 | if bag_size > 0:
84 | token = token.view(-1, token.size(-1))
85 | pos1 = pos1.view(-1, pos1.size(-1))
86 | pos2 = pos2.view(-1, pos2.size(-1))
87 | if mask is not None:
88 | mask = mask.view(-1, mask.size(-1))
89 | else:
90 | begin, end = scope[0][0], scope[-1][1]
91 | token = token[:, begin:end, :].view(-1, token.size(-1))
92 | pos1 = pos1[:, begin:end, :].view(-1, pos1.size(-1))
93 | pos2 = pos2[:, begin:end, :].view(-1, pos2.size(-1))
94 | if mask is not None:
95 | mask = mask[:, begin:end, :].view(-1, mask.size(-1))
96 | scope = torch.sub(scope, torch.zeros_like(scope).fill_(begin))
97 |
98 | if train or bag_size > 0:
99 | if mask is not None:
100 | rep = self.sentence_encoder(token, pos1, pos2, mask) # (nsum, H)
101 | else:
102 | rep = self.sentence_encoder(token, pos1, pos2) # (nsum, H)
103 | else:
104 | rep = []
105 | bs = 256
106 | total_bs = len(token) // bs + (1 if len(token) % bs != 0 else 0)
107 | for b in range(total_bs):
108 | with torch.no_grad():
109 | left = bs * b
110 | right = min(bs * (b + 1), len(token))
111 | if mask is not None:
112 | rep.append(self.sentence_encoder(token[left:right], pos1[left:right], pos2[left:right], mask[left:right]).detach()) # (nsum, H)
113 | else:
114 | rep.append(self.sentence_encoder(token[left:right], pos1[left:right], pos2[left:right]).detach()) # (nsum, H)
115 | rep = torch.cat(rep, 0)
116 |
117 | # Average
118 | bag_rep = []
119 | if bag_size is None or bag_size == 0:
120 | for i in range(len(scope)):
121 | bag_rep.append(rep[scope[i][0]:scope[i][1]].mean(0))
122 | bag_rep = torch.stack(bag_rep, 0) # (B, H)
123 | else:
124 | batch_size = len(scope)
125 | rep = rep.view(batch_size, bag_size, -1) # (B, bag, H)
126 | bag_rep = rep.mean(1) # (B, H)
127 | bag_rep = self.drop(bag_rep)
128 | bag_logits = self.fc(bag_rep) # (B, N)
129 |
130 | if not train:
131 | bag_logits = torch.softmax(bag_logits, -1)
132 |
133 | return bag_logits
134 |
135 |
--------------------------------------------------------------------------------
/opennre/model/bag_one.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, optim
3 | from .base_model import BagRE
4 |
5 | class BagOne(BagRE):
6 | """
7 | Instance one(max) for bag-level relation extraction.
8 | """
9 |
10 | def __init__(self, sentence_encoder, num_class, rel2id):
11 | """
12 | Args:
13 | sentence_encoder: encoder for sentences
14 | num_class: number of classes
15 |             rel2id: dictionary of relation name -> id mapping
16 | """
17 | super().__init__()
18 | self.sentence_encoder = sentence_encoder
19 | self.num_class = num_class
20 | self.fc = nn.Linear(self.sentence_encoder.hidden_size, num_class)
21 | self.softmax = nn.Softmax(-1)
22 | self.rel2id = rel2id
23 | self.id2rel = {}
24 | self.drop = nn.Dropout()
25 | for rel, id in rel2id.items():
26 | self.id2rel[id] = rel
27 |
28 | def infer(self, bag):
29 | """
30 | Args:
31 | bag: bag of sentences with the same entity pair
32 | [{
33 | 'text' or 'token': ...,
34 | 'h': {'pos': [start, end], ...},
35 | 't': {'pos': [start, end], ...}
36 | }]
37 | Return:
38 | (relation, score)
39 | """
40 | self.eval()
41 | tokens = []
42 | pos1s = []
43 | pos2s = []
44 | masks = []
45 | for item in bag:
46 | token, pos1, pos2, mask = self.sentence_encoder.tokenize(item)
47 | tokens.append(token)
48 | pos1s.append(pos1)
49 | pos2s.append(pos2)
50 | masks.append(mask)
51 | tokens = torch.cat(tokens, 0).unsqueeze(0) # (n, L)
52 | pos1s = torch.cat(pos1s, 0).unsqueeze(0)
53 | pos2s = torch.cat(pos2s, 0).unsqueeze(0)
54 | masks = torch.cat(masks, 0).unsqueeze(0)
55 | scope = torch.tensor([[0, len(bag)]]).long() # (1, 2)
56 | bag_logits = self.forward(None, scope, tokens, pos1s, pos2s, masks, train=False).squeeze(0) # (N) after softmax
57 | score, pred = bag_logits.max(0)
58 | score = score.item()
59 | pred = pred.item()
60 | rel = self.id2rel[pred]
61 | return (rel, score)
62 |
63 | def forward(self, label, scope, token, pos1, pos2, mask=None, train=True, bag_size=0):
64 | """
65 | Args:
66 | label: (B), label of the bag
67 |             scope: (B, 2), begin/end instance index of each bag
68 | token: (nsum, L), index of tokens
69 | pos1: (nsum, L), relative position to head entity
70 | pos2: (nsum, L), relative position to tail entity
71 | mask: (nsum, L), used for piece-wise CNN
72 | Return:
73 | logits, (B, N)
74 | """
75 | # Encode
76 | if bag_size > 0:
77 | token = token.view(-1, token.size(-1))
78 | pos1 = pos1.view(-1, pos1.size(-1))
79 | pos2 = pos2.view(-1, pos2.size(-1))
80 | if mask is not None:
81 | mask = mask.view(-1, mask.size(-1))
82 | else:
83 | begin, end = scope[0][0], scope[-1][1]
84 | token = token[:, begin:end, :].view(-1, token.size(-1))
85 | pos1 = pos1[:, begin:end, :].view(-1, pos1.size(-1))
86 | pos2 = pos2[:, begin:end, :].view(-1, pos2.size(-1))
87 | if mask is not None:
88 | mask = mask[:, begin:end, :].view(-1, mask.size(-1))
89 | scope = torch.sub(scope, torch.zeros_like(scope).fill_(begin))
90 |
91 | if train or bag_size > 0:
92 | if mask is not None:
93 | rep = self.sentence_encoder(token, pos1, pos2, mask) # (nsum, H)
94 | else:
95 | rep = self.sentence_encoder(token, pos1, pos2) # (nsum, H)
96 | else:
97 | rep = []
98 | bs = 256
99 | total_bs = len(token) // bs + (1 if len(token) % bs != 0 else 0)
100 | for b in range(total_bs):
101 | with torch.no_grad():
102 | left = bs * b
103 | right = min(bs * (b + 1), len(token))
104 | if mask is not None:
105 | rep.append(self.sentence_encoder(token[left:right], pos1[left:right], pos2[left:right], mask[left:right]).detach()) # (nsum, H)
106 | else:
107 | rep.append(self.sentence_encoder(token[left:right], pos1[left:right], pos2[left:right]).detach()) # (nsum, H)
108 | rep = torch.cat(rep, 0)
109 |
110 | # Max
111 | if train:
112 | if bag_size == 0:
113 | bag_rep = []
114 | query = torch.zeros((rep.size(0))).long()
115 | if torch.cuda.is_available():
116 | query = query.cuda()
117 | for i in range(len(scope)):
118 | query[scope[i][0]:scope[i][1]] = label[i]
119 |
120 | for i in range(len(scope)): # iterate over bags
121 | bag_mat = rep[scope[i][0]:scope[i][1]] # (n, H)
122 | instance_logit = self.softmax(self.fc(bag_mat)) # (n, N)
123 | # select j* which scores highest on the known label
124 | max_index = instance_logit[:, query[i]].argmax() # (1)
125 | bag_rep.append(bag_mat[max_index]) # (n, H) -> (H)
126 | bag_rep = torch.stack(bag_rep, 0) # (B, H)
127 | bag_rep = self.drop(bag_rep)
128 | bag_logits = self.fc(bag_rep) # (B, N)
129 | else:
130 | batch_size = label.size(0)
131 | query = label # (B)
132 | rep = rep.view(batch_size, bag_size, -1)
133 | instance_logit = self.softmax(self.fc(rep))
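  |                 # For each bag, keep the instance that scores highest on the bag's gold relation (the "at least one" assumption).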
134 | max_index = instance_logit[torch.arange(batch_size), :, query].argmax(-1)
135 | bag_rep = rep[torch.arange(batch_size), max_index]
136 |
137 | bag_rep = self.drop(bag_rep)
138 | bag_logits = self.fc(bag_rep) # (B, N)
139 |
140 | else:
141 | if bag_size == 0:
142 | bag_logits = []
143 | for i in range(len(scope)):
144 | bag_mat = rep[scope[i][0]:scope[i][1]] # (n, H)
145 | instance_logit = self.softmax(self.fc(bag_mat)) # (n, N)
146 | logit_for_each_rel = instance_logit.max(dim=0)[0] # (N)
147 | bag_logits.append(logit_for_each_rel)
148 | bag_logits = torch.stack(bag_logits, 0) # after **softmax**
149 | else:
150 | batch_size = rep.size(0) // bag_size
151 | rep = rep.view(batch_size, bag_size, -1)
152 | bag_logits = self.softmax(self.fc(rep)).max(1)[0]
153 |
154 | return bag_logits
155 |
156 |
--------------------------------------------------------------------------------
/opennre/model/base_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import json
4 |
5 | class SentenceRE(nn.Module):
6 | def __init__(self):
7 | super().__init__()
8 |
9 | def infer(self, item):
10 | """
11 | Args:
12 | item: {'text' or 'token', 'h': {'pos': [start, end]}, 't': ...}
13 | Return:
14 | (Name of the relation of the sentence, score)
15 | """
16 | raise NotImplementedError
17 |
18 |
19 | class BagRE(nn.Module):
20 | def __init__(self):
21 | super().__init__()
22 |
23 | def infer(self, bag):
24 | """
25 | Args:
26 | bag: bag of sentences with the same entity pair
27 | [{
28 | 'text' or 'token': ...,
29 | 'h': {'pos': [start, end], ...},
30 | 't': {'pos': [start, end], ...}
31 | }]
32 | Return:
33 | (relation, score)
34 | """
35 | raise NotImplementedError
36 |
37 | class FewShotRE(nn.Module):
38 | def __init__(self):
39 | super().__init__()
40 |
41 | def infer(self, support, query):
42 | """
43 | Args:
44 | support: supporting set.
45 | [{'text' or 'token': ...,
46 | 'h': {'pos': [start, end], ...},
47 | 't': {'pos': [start, end], ...},
48 | 'relation': ...}]
49 | query: same format as support
50 | Return:
51 | [(relation, score), ...]
52 |
53 |
54 | For few-shot relation extraction, please refer to FewRel
55 | https://github.com/thunlp/FewRel
56 | """
57 | raise NotImplementedError
58 |
59 | class NER(nn.Module):
60 | def __init__(self):
61 | super().__init__()
62 |
63 | def ner(self, sentence, is_token=False):
64 | """
65 | Args:
66 | sentence: string, the input sentence
67 |             is_token: if is_token == True, sentence is a list of tokens instead of a string
68 | Return:
69 | [{name: xx, pos: [start, end]}], a list of named entities
70 | """
71 | raise NotImplementedError
72 |
--------------------------------------------------------------------------------
/opennre/model/sigmoid_nn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, optim
3 | from .base_model import SentenceRE
4 |
5 | class SigmoidNN(SentenceRE):
6 | """
7 | Sigmoid (binary) classifier for sentence-level relation extraction.
8 | """
9 |
10 | def __init__(self, sentence_encoder, num_class, rel2id):
11 | """
12 | Args:
13 | sentence_encoder: encoder for sentences
14 | num_class: number of classes
15 |             rel2id: dictionary of relation name -> id mapping
16 | """
17 | super().__init__()
18 | self.sentence_encoder = sentence_encoder
19 | self.num_class = num_class
20 | self.fc = nn.Linear(self.sentence_encoder.hidden_size, num_class)
21 | self.rel2id = rel2id
22 | self.id2rel = {}
23 | self.drop = nn.Dropout()
24 | for rel, id in rel2id.items():
25 | self.id2rel[id] = rel
26 |
27 | def forward(self, *args):
28 | """
29 | Args:
30 | args: depends on the encoder
31 | Return:
32 | logits, (B, N)
33 | """
34 | rep = self.sentence_encoder(*args) # (B, H)
35 | rep = self.drop(rep)
36 | logits = self.fc(rep) # (B, N)
37 | return logits
38 |
39 | def logit_to_score(self, logits):
40 | return torch.sigmoid(logits)
41 |
--------------------------------------------------------------------------------
/opennre/model/softmax_nn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, optim
3 | from .base_model import SentenceRE
4 |
5 | class SoftmaxNN(SentenceRE):
6 | """
7 | Softmax classifier for sentence-level relation extraction.
8 | """
9 |
10 | def __init__(self, sentence_encoder, num_class, rel2id):
11 | """
12 | Args:
13 | sentence_encoder: encoder for sentences
14 | num_class: number of classes
15 |             rel2id: dictionary of relation name -> id mapping
16 | """
17 | super().__init__()
18 | self.sentence_encoder = sentence_encoder
19 | self.num_class = num_class
20 | self.fc = nn.Linear(self.sentence_encoder.hidden_size, num_class)
21 | self.softmax = nn.Softmax(-1)
22 | self.rel2id = rel2id
23 | self.id2rel = {}
24 | self.drop = nn.Dropout()
25 | for rel, id in rel2id.items():
26 | self.id2rel[id] = rel
27 |
28 | def infer(self, item):
29 | self.eval()
30 | _item = self.sentence_encoder.tokenize(item)
31 | item = []
32 | for x in _item:
33 | item.append(x.to(next(self.parameters()).device))
34 | logits = self.forward(*item)
35 | logits = self.softmax(logits)
36 | score, pred = logits.max(-1)
37 | score = score.item()
38 | pred = pred.item()
39 | return self.id2rel[pred], score
40 |
41 | def forward(self, *args):
42 | """
43 | Args:
44 | args: depends on the encoder
45 | Return:
46 | logits, (B, N)
47 | """
48 | rep = self.sentence_encoder(*args) # (B, H)
49 | rep = self.drop(rep)
50 | logits = self.fc(rep) # (B, N)
51 | return logits
52 |
53 | def logit_to_score(self, logits):
54 | return torch.softmax(logits, -1)
55 |
--------------------------------------------------------------------------------
/opennre/module/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
--------------------------------------------------------------------------------
/opennre/module/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from .cnn import CNN
6 | from .rnn import RNN
7 | from .lstm import LSTM
8 |
9 | __all__ = [
10 | 'CNN',
11 | 'RNN',
12 | 'LSTM',
13 | ]
--------------------------------------------------------------------------------
/opennre/module/nn/cnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class CNN(nn.Module):
6 |
7 | def __init__(self, input_size=50, hidden_size=256, dropout=0, kernel_size=3, padding=1, activation_function=F.relu):
8 | """
9 | Args:
10 |             input_size: dimension of input embedding
11 | kernel_size: kernel_size for CNN
12 | padding: padding for CNN
13 | hidden_size: hidden size
14 | """
15 | super().__init__()
16 | self.conv = nn.Conv1d(input_size, hidden_size, kernel_size, padding=padding)
17 | self.act = activation_function
18 | self.dropout = nn.Dropout(dropout)
19 |
20 | def forward(self, x):
21 | """
22 | Args:
23 | input features: (B, L, I_EMBED)
24 | Return:
25 |             output features: (B, L, H_EMBED)
26 | """
27 | # Check size of tensors
28 | x = x.transpose(1, 2) # (B, I_EMBED, L)
29 | x = self.conv(x) # (B, H_EMBED, L)
30 | x = self.act(x) # (B, H_EMBED, L)
31 | x = self.dropout(x) # (B, H_EMBED, L)
32 | x = x.transpose(1, 2) # (B, L, H_EMBED)
33 | return x
34 |
--------------------------------------------------------------------------------
/opennre/module/nn/lstm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class LSTM(nn.Module):
5 |
6 | def __init__(self, input_size=50, hidden_size=256, dropout=0, bidirectional=False, num_layers=1, activation_function="tanh"):
7 | """
8 | Args:
9 |             input_size: dimension of input embedding
10 | hidden_size: hidden size
11 | dropout: dropout layer on the outputs of each RNN layer except the last layer
12 | bidirectional: if it is a bidirectional RNN
13 | num_layers: number of recurrent layers
14 |             activation_function: activation of a plain RNN (tanh/relu); not used by LSTM
15 | """
16 | super().__init__()
17 | if bidirectional:
18 |             hidden_size //= 2
19 | self.lstm = nn.LSTM(input_size,
20 | hidden_size,
21 | num_layers,
23 | dropout=dropout,
24 | bidirectional=bidirectional)
25 |
26 | def forward(self, x):
27 | """
28 | Args:
29 | input features: (B, L, I_EMBED)
30 | Return:
31 | output features: (B, L, H_EMBED)
32 | """
33 | # Check size of tensors
34 | x = x.transpose(0, 1) # (L, B, I_EMBED)
35 |         x, (h, c) = self.lstm(x) # (L, B, H_EMBED)
36 |         x = x.transpose(0, 1) # (B, L, H_EMBED)
37 | return x
38 |
--------------------------------------------------------------------------------
/opennre/module/nn/rnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class RNN(nn.Module):
5 |
6 | def __init__(self, input_size=50, hidden_size=256, dropout=0, bidirectional=False, num_layers=1, activation_function="tanh"):
7 | """
8 | Args:
9 |             input_size: dimension of input embedding
10 | hidden_size: hidden size
11 | dropout: dropout layer on the outputs of each RNN layer except the last layer
12 | bidirectional: if it is a bidirectional RNN
13 | num_layers: number of recurrent layers
14 | activation_function: the activation function of RNN, tanh/relu
15 | """
16 | super().__init__()
17 | if bidirectional:
18 |             hidden_size //= 2
19 | self.rnn = nn.RNN(input_size,
20 | hidden_size,
21 | num_layers,
22 | nonlinearity=activation_function,
23 | dropout=dropout,
24 | bidirectional=bidirectional)
25 |
26 | def forward(self, x):
27 | """
28 | Args:
29 | input features: (B, L, I_EMBED)
30 | Return:
31 | output features: (B, L, H_EMBED)
32 | """
33 | # Check size of tensors
34 | x = x.transpose(0, 1) # (L, B, I_EMBED)
35 | x, h = self.rnn(x) # (L, B, H_EMBED)
36 |         x = x.transpose(0, 1) # (B, L, H_EMBED)
37 | return x
38 |
--------------------------------------------------------------------------------
/opennre/module/pool/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from .max_pool import MaxPool
6 | from .avg_pool import AvgPool
7 |
8 | __all__ = [
9 | 'MaxPool',
10 | 'AvgPool'
11 | ]
--------------------------------------------------------------------------------
/opennre/module/pool/avg_pool.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | class AvgPool(nn.Module):
7 |
8 | def __init__(self, kernel_size, segment_num=None):
9 | """
10 | Args:
11 |             kernel_size: kernel size of the pooling window
12 |             segment_num: number of segments for piecewise pooling; None for plain pooling
15 | """
16 | super().__init__()
17 | self.segment_num = segment_num
18 |         if self.segment_num is not None:
19 |             self.mask_embedding = nn.Embedding(segment_num + 1, segment_num)
20 |             self.mask_embedding.weight.data.copy_(torch.FloatTensor(np.concatenate([np.zeros((1, segment_num)), np.identity(segment_num)], axis=0)))
21 | self.mask_embedding.weight.requires_grad = False
22 | self.pool = nn.AvgPool1d(kernel_size)
23 |
24 | def forward(self, x, mask=None):
25 | """
26 | Args:
27 | input features: (B, L, I_EMBED)
28 | Return:
29 | output features: (B, H_EMBED)
30 | """
31 | # Check size of tensors
32 |         if mask is None or self.segment_num is None or self.segment_num == 1:
33 | x = x.transpose(1, 2) # (B, L, I_EMBED) -> (B, I_EMBED, L)
34 | x = self.pool(x).squeeze(-1) # (B, I_EMBED, 1) -> (B, I_EMBED)
35 | return x
36 | else:
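  |             # Segment-wise average pooling: the frozen mask embedding turns each position's segment id into a one-hot, positions outside a segment are zeroed, and the per-segment results are concatenated.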
37 |             B, L, I_EMBED = x.size()
38 | mask = self.mask_embedding(mask).transpose(1, 2).unsqueeze(2) # (B, L) -> (B, L, S) -> (B, S, L) -> (B, S, 1, L)
39 | x = x.transpose(1, 2).unsqueeze(1) # (B, L, I_EMBED) -> (B, I_EMBED, L) -> (B, 1, I_EMBED, L)
40 | x = (x * mask).view([-1, I_EMBED, L]) # (B, S, I_EMBED, L) -> (B * S, I_EMBED, L)
41 | x = self.pool(x).squeeze(-1) # (B * S, I_EMBED, 1) -> (B * S, I_EMBED)
42 |             x = x.view([B, -1]) # (B, S * I_EMBED)
43 | return x
--------------------------------------------------------------------------------
/opennre/module/pool/max_pool.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | class MaxPool(nn.Module):
7 |
8 | def __init__(self, kernel_size, segment_num=None):
9 | """
10 | Args:
11 |             kernel_size: kernel size of the pooling window
12 |             segment_num: number of segments for piecewise pooling; None for plain pooling
15 | """
16 | super().__init__()
17 | self.segment_num = segment_num
18 |         if self.segment_num is not None:
19 | self.mask_embedding = nn.Embedding(segment_num + 1, segment_num)
20 | self.mask_embedding.weight.data.copy_(torch.FloatTensor(np.concatenate([np.zeros((1, segment_num)), np.identity(segment_num)], axis=0)))
21 | self.mask_embedding.weight.requires_grad = False
22 | self._minus = -100
23 | self.pool = nn.MaxPool1d(kernel_size)
24 |
25 | def forward(self, x, mask=None):
26 | """
27 | Args:
28 | input features: (B, L, I_EMBED)
29 | Return:
30 | output features: (B, H_EMBED)
31 | """
32 | # Check size of tensors
33 | if mask is None or self.segment_num is None or self.segment_num == 1:
34 | x = x.transpose(1, 2) # (B, L, I_EMBED) -> (B, I_EMBED, L)
35 | x = self.pool(x).squeeze(-1) # (B, I_EMBED, 1) -> (B, I_EMBED)
36 | return x
37 | else:
38 | B, L, I_EMBED = x.size()[:3]
39 | # mask = 1 - self.mask_embedding(mask).transpose(1, 2).unsqueeze(2) # (B, L) -> (B, L, S) -> (B, S, L) -> (B, S, 1, L)
40 | # x = x.transpose(1, 2).unsqueeze(1) # (B, L, I_EMBED) -> (B, I_EMBED, L) -> (B, 1, I_EMBED, L)
41 | # x = (x + self._minus * mask).contiguous().view([-1, I_EMBED, L]) # (B, S, I_EMBED, L) -> (B * S, I_EMBED, L)
42 | # x = self.pool(x).squeeze(-1) # (B * S, I_EMBED, 1) -> (B * S, I_EMBED)
43 | # x = x.view([B, -1]) # (B, S * I_EMBED)
44 | # return x
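  |             # Piecewise max pooling: the frozen mask embedding marks the three segments, and positions outside a segment are pushed down by self._minus so each max pool only sees its own segment.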
45 | mask = 1 - self.mask_embedding(mask).transpose(1, 2)
46 | x = x.transpose(1, 2)
47 | pool1 = self.pool(x + self._minus * mask[:, 0:1, :])
48 | pool2 = self.pool(x + self._minus * mask[:, 1:2, :])
49 | pool3 = self.pool(x + self._minus * mask[:, 2:3, :])
50 |
51 | x = torch.cat([pool1, pool2, pool3], 1)
52 | # x = x.squeeze(-1)
53 | return x
54 |
--------------------------------------------------------------------------------
/opennre/pretrain.py:
--------------------------------------------------------------------------------
1 | from . import encoder
2 | from . import model
3 | from . import framework
4 | import torch
5 | import os
6 | import sys
7 | import json
8 | import numpy as np
9 | import logging
10 |
11 | root_url = "https://thunlp.oss-cn-qingdao.aliyuncs.com/"
12 | default_root_path = os.path.join(os.path.expanduser('~'), '.opennre')
13 |
14 | def check_root(root_path=default_root_path):
15 |     os.makedirs(root_path, exist_ok=True)
16 |     os.makedirs(os.path.join(root_path, 'benchmark'), exist_ok=True)
17 |     os.makedirs(os.path.join(root_path, 'pretrain'), exist_ok=True)
18 |     os.makedirs(os.path.join(root_path, 'pretrain/nre'), exist_ok=True)
20 |
21 | def download_wiki80(root_path=default_root_path):
22 | check_root()
23 | if not os.path.exists(os.path.join(root_path, 'benchmark/wiki80')):
24 | os.mkdir(os.path.join(root_path, 'benchmark/wiki80'))
25 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki80') + ' ' + root_url + 'opennre/benchmark/wiki80/wiki80_rel2id.json')
26 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki80') + ' ' + root_url + 'opennre/benchmark/wiki80/wiki80_train.txt')
27 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki80') + ' ' + root_url + 'opennre/benchmark/wiki80/wiki80_val.txt')
28 |
29 | def download_tacred(root_path=default_root_path):
30 | check_root()
31 | if not os.path.exists(os.path.join(root_path, 'benchmark/tacred')):
32 | os.mkdir(os.path.join(root_path, 'benchmark/tacred'))
33 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/tacred') + ' ' + root_url + 'opennre/benchmark/tacred/tacred_rel2id.json')
34 | logging.info('Due to copyright limits, we only provide rel2id for TACRED. Please download TACRED manually and convert the data to OpenNRE format if needed.')
35 |
36 | def download_nyt10(root_path=default_root_path):
37 | check_root()
38 | if not os.path.exists(os.path.join(root_path, 'benchmark/nyt10')):
39 | os.mkdir(os.path.join(root_path, 'benchmark/nyt10'))
40 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/nyt10') + ' ' + root_url + 'opennre/benchmark/nyt10/nyt10_rel2id.json')
41 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/nyt10') + ' ' + root_url + 'opennre/benchmark/nyt10/nyt10_train.txt')
42 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/nyt10') + ' ' + root_url + 'opennre/benchmark/nyt10/nyt10_test.txt')
43 |
44 | def download_nyt10m(root_path=default_root_path):
45 | check_root()
46 | if not os.path.exists(os.path.join(root_path, 'benchmark/nyt10m')):
47 | os.mkdir(os.path.join(root_path, 'benchmark/nyt10m'))
48 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/nyt10m') + ' ' + root_url + 'opennre/benchmark/nyt10m/nyt10m_rel2id.json')
49 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/nyt10m') + ' ' + root_url + 'opennre/benchmark/nyt10m/nyt10m_train.txt')
50 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/nyt10m') + ' ' + root_url + 'opennre/benchmark/nyt10m/nyt10m_test.txt')
51 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/nyt10m') + ' ' + root_url + 'opennre/benchmark/nyt10m/nyt10m_val.txt')
52 |
53 | def download_wiki20m(root_path=default_root_path):
54 | check_root()
55 | if not os.path.exists(os.path.join(root_path, 'benchmark/wiki20m')):
56 | os.mkdir(os.path.join(root_path, 'benchmark/wiki20m'))
57 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki20m') + ' ' + root_url + 'opennre/benchmark/wiki20m/wiki20m_rel2id.json')
58 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki20m') + ' ' + root_url + 'opennre/benchmark/wiki20m/wiki20m_train.txt')
59 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki20m') + ' ' + root_url + 'opennre/benchmark/wiki20m/wiki20m_test.txt')
60 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki20m') + ' ' + root_url + 'opennre/benchmark/wiki20m/wiki20m_val.txt')
61 |
62 | def download_wiki_distant(root_path=default_root_path):
63 | check_root()
64 | if not os.path.exists(os.path.join(root_path, 'benchmark/wiki_distant')):
65 | os.mkdir(os.path.join(root_path, 'benchmark/wiki_distant'))
66 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki_distant') + ' ' + root_url + 'opennre/benchmark/wiki_distant/wiki_distant_rel2id.json')
67 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki_distant') + ' ' + root_url + 'opennre/benchmark/wiki_distant/wiki_distant_train.txt')
68 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki_distant') + ' ' + root_url + 'opennre/benchmark/wiki_distant/wiki_distant_test.txt')
69 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/wiki_distant') + ' ' + root_url + 'opennre/benchmark/wiki_distant/wiki_distant_val.txt')
70 |
71 | def download_semeval(root_path=default_root_path):
72 | check_root()
73 | if not os.path.exists(os.path.join(root_path, 'benchmark/semeval')):
74 | os.mkdir(os.path.join(root_path, 'benchmark/semeval'))
75 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/semeval') + ' ' + root_url + 'opennre/benchmark/semeval/semeval_rel2id.json')
76 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/semeval') + ' ' + root_url + 'opennre/benchmark/semeval/semeval_train.txt')
77 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/semeval') + ' ' + root_url + 'opennre/benchmark/semeval/semeval_test.txt')
78 | os.system('wget -P ' + os.path.join(root_path, 'benchmark/semeval') + ' ' + root_url + 'opennre/benchmark/semeval/semeval_val.txt')
79 |
80 | def download_glove(root_path=default_root_path):
81 | check_root()
82 | if not os.path.exists(os.path.join(root_path, 'pretrain/glove')):
83 | os.mkdir(os.path.join(root_path, 'pretrain/glove'))
84 | os.system('wget -P ' + os.path.join(root_path, 'pretrain/glove') + ' ' + root_url + 'opennre/pretrain/glove/glove.6B.50d_mat.npy')
85 | os.system('wget -P ' + os.path.join(root_path, 'pretrain/glove') + ' ' + root_url + 'opennre/pretrain/glove/glove.6B.50d_word2id.json')
86 |
87 | def download_bert_base_uncased(root_path=default_root_path):
88 | check_root()
89 | if not os.path.exists(os.path.join(root_path, 'pretrain/bert-base-uncased')):
90 | os.mkdir(os.path.join(root_path, 'pretrain/bert-base-uncased'))
91 | os.system('wget -P ' + os.path.join(root_path, 'pretrain/bert-base-uncased') + ' ' + root_url + 'opennre/pretrain/bert-base-uncased/config.json')
92 | os.system('wget -P ' + os.path.join(root_path, 'pretrain/bert-base-uncased') + ' ' + root_url + 'opennre/pretrain/bert-base-uncased/pytorch_model.bin')
93 | os.system('wget -P ' + os.path.join(root_path, 'pretrain/bert-base-uncased') + ' ' + root_url + 'opennre/pretrain/bert-base-uncased/vocab.txt')
94 |
95 | def download_pretrain(model_name, root_path=default_root_path):
96 | ckpt = os.path.join(root_path, 'pretrain/nre/' + model_name + '.pth.tar')
97 | if not os.path.exists(ckpt):
98 | os.system('wget -P ' + os.path.join(root_path, 'pretrain/nre') + ' ' + root_url + 'opennre/pretrain/nre/' + model_name + '.pth.tar')
99 |
100 | def download(name, root_path=default_root_path):
101 | if not os.path.exists(os.path.join(root_path, 'benchmark')):
102 | os.mkdir(os.path.join(root_path, 'benchmark'))
103 | if not os.path.exists(os.path.join(root_path, 'pretrain')):
104 | os.mkdir(os.path.join(root_path, 'pretrain'))
105 | if name == 'nyt10':
106 | download_nyt10(root_path=root_path)
107 | elif name == 'nyt10m':
108 | download_nyt10m(root_path=root_path)
109 | elif name == 'wiki20m':
110 | download_wiki20m(root_path=root_path)
111 | elif name == 'wiki_distant':
112 | download_wiki_distant(root_path=root_path)
113 | elif name == 'semeval':
114 | download_semeval(root_path=root_path)
115 | elif name == 'wiki80':
116 | download_wiki80(root_path=root_path)
117 | elif name == 'tacred':
118 | download_tacred(root_path=root_path)
119 | elif name == 'glove':
120 | download_glove(root_path=root_path)
121 | elif name == 'bert_base_uncased':
122 | download_bert_base_uncased(root_path=root_path)
123 | else:
124 | raise Exception('Cannot find corresponding data.')
125 |
126 | def get_model(model_name, root_path=default_root_path):
127 | check_root()
128 | ckpt = os.path.join(root_path, 'pretrain/nre/' + model_name + '.pth.tar')
129 | if model_name == 'wiki80_cnn_softmax':
130 | download_pretrain(model_name, root_path=root_path)
131 | download('glove', root_path=root_path)
132 | download('wiki80', root_path=root_path)
133 | wordi2d = json.load(open(os.path.join(root_path, 'pretrain/glove/glove.6B.50d_word2id.json')))
134 | word2vec = np.load(os.path.join(root_path, 'pretrain/glove/glove.6B.50d_mat.npy'))
135 | rel2id = json.load(open(os.path.join(root_path, 'benchmark/wiki80/wiki80_rel2id.json')))
136 | sentence_encoder = encoder.CNNEncoder(token2id=wordi2d,
137 | max_length=40,
138 | word_size=50,
139 | position_size=5,
140 | hidden_size=230,
141 | blank_padding=True,
142 | kernel_size=3,
143 | padding_size=1,
144 | word2vec=word2vec,
145 | dropout=0.5)
146 | m = model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id)
147 | m.load_state_dict(torch.load(ckpt, map_location='cpu')['state_dict'])
148 | return m
149 | elif model_name in ['wiki80_bert_softmax', 'wiki80_bertentity_softmax']:
150 | download_pretrain(model_name, root_path=root_path)
151 | download('bert_base_uncased', root_path=root_path)
152 | download('wiki80', root_path=root_path)
153 | rel2id = json.load(open(os.path.join(root_path, 'benchmark/wiki80/wiki80_rel2id.json')))
154 | if 'entity' in model_name:
155 | sentence_encoder = encoder.BERTEntityEncoder(
156 | max_length=80, pretrain_path=os.path.join(root_path, 'pretrain/bert-base-uncased'))
157 | else:
158 | sentence_encoder = encoder.BERTEncoder(
159 | max_length=80, pretrain_path=os.path.join(root_path, 'pretrain/bert-base-uncased'))
160 | m = model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id)
161 | m.load_state_dict(torch.load(ckpt, map_location='cpu')['state_dict'])
162 | return m
163 | elif model_name in ['tacred_bert_softmax', 'tacred_bertentity_softmax']:
164 | download_pretrain(model_name, root_path=root_path)
165 | download('bert_base_uncased', root_path=root_path)
166 | download('tacred', root_path=root_path)
167 | rel2id = json.load(open(os.path.join(root_path, 'benchmark/tacred/tacred_rel2id.json')))
168 | if 'entity' in model_name:
169 | sentence_encoder = encoder.BERTEntityEncoder(
170 | max_length=80, pretrain_path=os.path.join(root_path, 'pretrain/bert-base-uncased'))
171 | else:
172 | sentence_encoder = encoder.BERTEncoder(
173 | max_length=80, pretrain_path=os.path.join(root_path, 'pretrain/bert-base-uncased'))
174 | m = model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id)
175 | m.load_state_dict(torch.load(ckpt, map_location='cpu')['state_dict'])
176 | return m
177 | else:
178 | raise NotImplementedError
179 |
--------------------------------------------------------------------------------
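A minimal usage sketch for the pretrain helpers above (assuming the package and its dependencies are installed and network access is available; get_model() triggers the download_* calls it needs on first use, and the example sentence is the one used in tests/test_inference.py):

import opennre

# Downloads the checkpoint, GloVe embeddings and the wiki80 benchmark on
# first use, then returns the CNN + softmax sentence-level model.
model = opennre.get_model('wiki80_cnn_softmax')

# Head ('h') and tail ('t') entities are given as character offsets into the text.
result = model.infer({
    'text': 'He was the son of Máel Dúin mac Máele Fithrich, and grandson of the high king Áed Uaridnach (died 612).',
    'h': {'pos': (18, 46)},
    't': {'pos': (78, 91)},
})
print(result)  # e.g. ('father', 0.75...)
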
/opennre/tokenization/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | from .basic_tokenizer import BasicTokenizer
8 | from .word_piece_tokenizer import WordpieceTokenizer
9 | from .word_tokenizer import WordTokenizer
10 | from .bert_tokenizer import BertTokenizer
11 | from .utils import (is_whitespace, is_control, is_punctuation, is_chinese_char, convert_to_unicode, clean_text, split_on_whitespace, split_on_punctuation, tokenize_chinese_chars, strip_accents, printable_text, convert_by_vocab, convert_tokens_to_ids, convert_ids_to_tokens)
12 | __all__ = [
13 | 'BasicTokenizer',
14 | 'WordpieceTokenizer',
15 | 'WordTokenizer',
16 | 'BertTokenizer',
17 | 'is_whitespace',
18 | 'is_control',
19 | 'is_punctuation',
20 | 'is_chinese_char',
21 | 'convert_to_unicode',
22 | 'clean_text',
23 | 'split_on_whitespace',
24 | 'split_on_punctuation',
25 | 'tokenize_chinese_chars',
26 | 'strip_accents',
27 | 'printable_text',
28 | 'convert_by_vocab',
29 | 'convert_tokens_to_ids',
30 | 'convert_ids_to_tokens'
31 | ]
32 |
33 |
34 |
--------------------------------------------------------------------------------
/opennre/tokenization/basic_tokenizer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """BasicTokenizer classes."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | from .utils import (convert_to_unicode,
23 | clean_text,
24 | split_on_whitespace,
25 | split_on_punctuation,
26 | tokenize_chinese_chars,
27 | strip_accents)
28 |
29 | class BasicTokenizer(object):
30 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
31 |
32 | def __init__(self,
33 | do_lower_case=True,
34 | never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
35 | """Constructs a BasicTokenizer.
36 | Args:
37 | do_lower_case: Whether to lower case the input.
38 | never_split: Tokens that are never lower-cased or accent-stripped. """
39 | self.do_lower_case = do_lower_case
40 | self.never_split = never_split
41 |
42 | def tokenize(self, text):
43 | """Tokenizes a piece of text."""
44 | text = convert_to_unicode(text)
45 | text = clean_text(text)
46 | text = tokenize_chinese_chars(text)
47 | # This was added on November 1st, 2018 for the multilingual and Chinese
48 | # models. This is also applied to the English models now, but it doesn't
49 | # matter since the English models were not trained on any Chinese data
50 | # and generally don't have any Chinese data in them (there are Chinese
51 | # characters in the vocabulary only because the English Wikipedia
52 | # contains some Chinese words).
53 | orig_tokens = split_on_whitespace(text)
54 | split_tokens = []
55 | current_positions = []
56 | for token in orig_tokens:
57 | if self.do_lower_case and token not in self.never_split:
58 | token = token.lower()
59 | token = strip_accents(token)
60 | current_positions.append([])
61 | current_positions[-1].append(len(split_tokens))
62 | split_tokens.extend(split_on_punctuation(token))
63 | current_positions[-1].append(len(split_tokens))
64 | return split_tokens, current_positions
65 |
--------------------------------------------------------------------------------
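An illustrative sketch of BasicTokenizer.tokenize (the input string is hypothetical): it lower-cases, strips accents, splits on punctuation, and returns, for every original whitespace-separated token, its [start, end) span in the output token list.

from opennre.tokenization.basic_tokenizer import BasicTokenizer

tokenizer = BasicTokenizer(do_lower_case=True)
tokens, positions = tokenizer.tokenize("Hello, Wörld!")
print(tokens)     # ['hello', ',', 'world', '!']
print(positions)  # [[0, 2], [2, 4]]  ("Hello," covers tokens 0..2, "Wörld!" covers 2..4)
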
/opennre/tokenization/bert_tokenizer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """BertTokenizer classes."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import collections
23 | import unicodedata
24 | import six
25 |
26 | from .basic_tokenizer import BasicTokenizer
27 | from .word_piece_tokenizer import WordpieceTokenizer
28 | from .utils import (convert_to_unicode,
29 | load_vocab,
30 | convert_by_vocab,
31 | convert_tokens_to_ids,
32 | convert_ids_to_tokens)
33 |
34 | class BertTokenizer(object):
35 |
36 | def __init__(self,
37 | vocab = None,
38 | do_lower_case = True,
39 | do_basic_tokenize = True,
40 | never_split=["[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"]):
41 |
42 | self.vocab = load_vocab(vocab)
43 | self.inv_vocab = {v: k for k, v in self.vocab.items()}
44 | self.basic_tokenizer = BasicTokenizer(do_lower_case = do_lower_case, never_split = never_split)
45 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab = self.vocab)
46 | self.do_basic_tokenize = do_basic_tokenize
47 |
48 | def tokenize(self, text):
49 | split_tokens = []
50 | if self.do_basic_tokenize:
51 | tokens, _ = self.basic_tokenizer.tokenize(text)
52 | text = " ".join(tokens)
53 | split_tokens, current_positions = self.wordpiece_tokenizer.tokenize(text)
54 | return split_tokens, current_positions
55 |
56 | def convert_tokens_to_ids(self, tokens):
57 | return convert_by_vocab(self.vocab, tokens)
58 |
59 | def convert_ids_to_tokens(self, ids):
60 | return convert_by_vocab(self.inv_vocab, ids)
61 |
--------------------------------------------------------------------------------
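A sketch of the combined BERT tokenization pipeline, using a tiny in-memory vocabulary purely for illustration (a real setup would pass the path to pretrain/bert-base-uncased/vocab.txt; the dict only works because load_vocab passes non-path arguments through unchanged):

from opennre.tokenization.bert_tokenizer import BertTokenizer

toy_vocab = {"[PAD]": 0, "[UNK]": 1, "hello": 2, ",": 3, "un": 4, "##aff": 5, "##able": 6, "world": 7}
tokenizer = BertTokenizer(vocab=toy_vocab)
tokens, positions = tokenizer.tokenize("Hello, unaffable world")
print(tokens)                                   # ['hello', ',', 'un', '##aff', '##able', 'world']
print(positions)                                # [[0, 1], [1, 2], [2, 5], [5, 6]]
print(tokenizer.convert_tokens_to_ids(tokens))  # [2, 3, 4, 5, 6, 7]
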
/opennre/tokenization/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import collections
6 | import unicodedata
7 | import six
8 |
9 | def is_whitespace(char):
10 | """ Checks whether `chars` is a whitespace character.
11 | \t, \n, and \r are technically contorl characters but we treat them
12 | as whitespace since they are generally considered as such.
13 | """
14 | if char == " " or char == "\t" or char == "\n" or char == "\r":
15 | return True
16 | cat = unicodedata.category(char)
17 | if cat == "Zs":
18 | return True
19 | return False
20 |
21 | def is_control(char):
22 | """ Checks whether `chars` is a control character.
23 | These are technically control characters but we count them as whitespace characters.
24 | """
25 | if char == "\t" or char == "\n" or char == "\r":
26 | return False
27 | cat = unicodedata.category(char)
28 | if cat.startswith("C"):
29 | return True
30 | return False
31 |
32 | def is_punctuation(char):
33 | """ Checks whether `chars` is a punctuation character.
34 | We treat all non-letter/number ASCII as punctuation. Characters such as "^", "$", and "`" are not in the Unicode.
35 | Punctuation class but we treat them as punctuation anyways, for consistency.
36 | """
37 | cp = ord(char)
38 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
39 | return True
40 | cat = unicodedata.category(char)
41 | if cat.startswith("P"):
42 | return True
43 | return False
44 |
45 | def is_chinese_char(cp):
46 | """ Checks whether CP is the codepoint of a CJK character.
47 | This defines a "chinese character" as anything in the CJK Unicode block:
48 | https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
49 | Note that the CJK Unicode block is NOT all Japanese and Korean characters,
50 | despite its name. The modern Korean Hangul alphabet is a different block,
51 | as is Japanese Hiragana and Katakana. Those alphabets are used to write
52 | space-separated words, so they are not treated specially and handled
53 | like all of the other languages.
54 | """
55 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or
56 | (cp >= 0x3400 and cp <= 0x4DBF) or
57 | (cp >= 0x20000 and cp <= 0x2A6DF) or
58 | (cp >= 0x2A700 and cp <= 0x2B73F) or
59 | (cp >= 0x2B740 and cp <= 0x2B81F) or
60 | (cp >= 0x2B820 and cp <= 0x2CEAF) or
61 | (cp >= 0xF900 and cp <= 0xFAFF) or
62 | (cp >= 0x2F800 and cp <= 0x2FA1F)):
63 | return True
64 | return False
65 |
66 | def convert_to_unicode(text):
67 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
68 | if six.PY3:
69 | if isinstance(text, str):
70 | return text
71 | elif isinstance(text, bytes):
72 | return text.decode("utf-8", "ignore")
73 | else:
74 | raise ValueError("Unsupported string type: %s" % (type(text)))
75 | elif six.PY2:
76 | if isinstance(text, str):
77 | return text.decode("utf-8", "ignore")
78 | elif isinstance(text, unicode):
79 | return text
80 | else:
81 | raise ValueError("Unsupported string type: %s" % (type(text)))
82 | else:
83 | raise ValueError("Not running on Python2 or Python 3?")
84 |
85 | def clean_text(text):
86 | output = []
87 | for char in text:
88 | cp = ord(char)
89 | if cp == 0 or cp == 0xfffd or is_control(char):
90 | continue
91 | if is_whitespace(char):
92 | output.append(" ")
93 | else:
94 | output.append(char)
95 | return "".join(output)
96 |
97 | def split_on_whitespace(text):
98 | """ Runs basic whitespace cleaning and splitting on a peice of text.
99 | e.g, 'a b c' -> ['a', 'b', 'c']
100 | """
101 | text = text.strip()
102 | if not text:
103 | return []
104 | return text.split()
105 |
106 | def split_on_punctuation(text):
107 | """Splits punctuation on a piece of text."""
108 | start_new_word = True
109 | output = []
110 | for char in text:
111 | if is_punctuation(char):
112 | output.append([char])
113 | start_new_word = True
114 | else:
115 | if start_new_word:
116 | output.append([])
117 | start_new_word = False
118 | output[-1].append(char)
119 | return ["".join(x) for x in output]
120 |
121 | def tokenize_chinese_chars(text):
122 | """Adds whitespace around any CJK character."""
123 | output = []
124 | for char in text:
125 | cp = ord(char)
126 | if is_chinese_char(cp):
127 | output.append(" ")
128 | output.append(char)
129 | output.append(" ")
130 | else:
131 | output.append(char)
132 | return "".join(output)
133 |
134 | def strip_accents(text):
135 | """Strips accents from a piece of text."""
136 | text = unicodedata.normalize("NFD", text)
137 | output = []
138 | for char in text:
139 | cat = unicodedata.category(char)
140 | if cat == "Mn":
141 | continue
142 | output.append(char)
143 | return "".join(output)
144 |
145 | def load_vocab(vocab_file):
146 | """Loads a vocabulary file into a dictionary."""
147 | if vocab_file is None:
148 | raise Exception("Vocab file not provided")
149 | if isinstance(vocab_file, str) or isinstance(vocab_file, bytes):
150 | vocab = collections.OrderedDict()
151 | index = 0
152 | with open(vocab_file, "r", encoding="utf-8") as reader:
153 | while True:
154 | token = reader.readline()
155 | if not token:
156 | break
157 | token = token.strip()
158 | vocab[token] = index
159 | index += 1
160 | return vocab
161 | else:
162 | return vocab_file
163 |
164 | def printable_text(text):
165 | """ Returns text encoded in a way suitable for print or `tf.logging`.
166 | These functions want `str` for both Python2 and Python3, but in one case
167 | it's a Unicode string and in the other it's a byte string.
168 | """
169 | if six.PY3:
170 | if isinstance(text, str):
171 | return text
172 | elif isinstance(text, bytes):
173 | return text.decode("utf-8", "ignore")
174 | else:
175 | raise ValueError("Unsupported string type: %s" % (type(text)))
176 | elif six.PY2:
177 | if isinstance(text, str):
178 | return text
179 | elif isinstance(text, unicode):
180 | return text.encode("utf-8")
181 | else:
182 | raise ValueError("Unsupported string type: %s" % (type(text)))
183 | else:
184 | raise ValueError("Not running on Python2 or Python 3?")
185 |
186 | def convert_by_vocab(vocab, items, max_seq_length = None, blank_id = 0, unk_id = 1, uncased = True):
187 | """Converts a sequence of [tokens|ids] using the vocab."""
188 | output = []
189 | for item in items:
190 | if uncased:
191 | item = item.lower()
192 | if item in vocab:
193 | output.append(vocab[item])
194 | else:
195 | output.append(unk_id)
196 | if max_seq_length is not None:
197 | if len(output) > max_seq_length:
198 | output = output[:max_seq_length]
199 | else:
200 | while len(output) < max_seq_length:
201 | output.append(blank_id)
202 | return output
203 |
204 | def convert_tokens_to_ids(vocab, tokens, max_seq_length = None, blank_id = 0, unk_id = 1):
205 | return convert_by_vocab(vocab, tokens, max_seq_length, blank_id, unk_id)
206 |
207 | def convert_ids_to_tokens(inv_vocab, ids):
208 | return convert_by_vocab(inv_vocab, ids, uncased=False)
209 |
210 | def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
211 | """Truncates a pair of sequences to a maximum sequence length."""
212 | while True:
213 | total_length = len(tokens_a) + len(tokens_b)
214 | if total_length <= max_num_tokens:
215 | break
216 | trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
217 | assert len(trunc_tokens) >= 1
218 | # We want to sometimes truncate from the front and sometimes from the
219 | # back to add more randomness and avoid biases.
220 | if rng.random() < 0.5:
221 | del trunc_tokens[0]
222 | else:
223 | trunc_tokens.pop()
224 |
225 | def create_int_feature(values):
226 | import tensorflow as tf  # local import: TensorFlow is only needed if TFRecord features are built
227 | return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
228 |
229 | def create_float_feature(values):
230 | import tensorflow as tf  # local import: TensorFlow is only needed if TFRecord features are built
231 | return tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
232 |
233 | def add_token(tokens_a, tokens_b = None):
234 | assert len(tokens_a) >= 1
235 |
236 | tokens = []
237 | segment_ids = []
238 |
239 | tokens.append("[CLS]")
240 | segment_ids.append(0)
241 |
242 | for token in tokens_a:
243 | tokens.append(token)
244 | segment_ids.append(0)
245 |
246 | tokens.append("[SEP]")
247 | segment_ids.append(0)
248 |
249 | if tokens_b is not None:
250 | assert len(tokens_b) >= 1
251 |
252 | for token in tokens_b:
253 | tokens.append(token)
254 | segment_ids.append(1)
255 |
256 | tokens.append("[SEP]")
257 | segment_ids.append(1)
258 |
259 | return tokens, segment_ids
260 |
--------------------------------------------------------------------------------
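Hedged examples for a few of the helpers above (all inputs are hypothetical):

from opennre.tokenization.utils import (convert_by_vocab,
                                        split_on_punctuation,
                                        strip_accents,
                                        tokenize_chinese_chars)

print(split_on_punctuation("state-of-the-art!"))  # ['state', '-', 'of', '-', 'the', '-', 'art', '!']
print(strip_accents("Máel Dúin"))                 # 'Mael Duin'
print(tokenize_chinese_chars("中文text"))          # ' 中  文 text'

# Unknown items map to unk_id; with max_seq_length set, the output is
# truncated or right-padded with blank_id.
vocab = {"[PAD]": 0, "[UNK]": 1, "the": 2, "cat": 3}
print(convert_by_vocab(vocab, ["The", "cat", "sleeps"], max_seq_length=5))  # [2, 3, 1, 0, 0]
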
/opennre/tokenization/word_piece_tokenizer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """WordpieceTokenizer classes."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import unicodedata
23 |
24 | from .utils import (load_vocab,
25 | convert_to_unicode,
26 | clean_text,
27 | split_on_whitespace,
28 | convert_by_vocab,
29 | tokenize_chinese_chars)
30 |
31 | class WordpieceTokenizer(object):
32 | """Runs WordPiece tokenziation."""
33 |
34 | def __init__(self, vocab = None, unk_token="[UNK]", max_input_chars_per_word=200):
35 | self.vocab = load_vocab(vocab)
36 | self.inv_vocab = {v: k for k, v in self.vocab.items()}
37 | self.unk_token = unk_token
38 | self.max_input_chars_per_word = max_input_chars_per_word
39 |
40 | def tokenize(self, text):
41 | """ Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform tokenization
42 | using the given vocabulary.
43 |
44 | For example:
45 | input = "unaffable"
46 | output = ["un", "##aff", "##able"]
47 |
48 | Args:
49 | text: A single token or whitespace separated tokens. This should have already been passed through `BasicTokenizer`.
50 | Returns:
51 | output_tokens: A list of wordpiece tokens.
52 | current_positions: A list of [start, end) positions in output_tokens for each original word in the text.
53 | """
54 | text = convert_to_unicode(text)
55 | text = clean_text(text)
56 | text = tokenize_chinese_chars(text)
57 | output_tokens = []
58 | current_positions = []
59 | token_list = split_on_whitespace(text)
60 | for chars in token_list:
61 | if len(chars) > self.max_input_chars_per_word:
62 | current_positions.append([len(output_tokens), len(output_tokens) + 1])  # keep positions aligned with the input words
63 | output_tokens.append(self.unk_token); continue
64 | is_bad = False
65 | start = 0
66 | sub_tokens = []
67 | while start < len(chars):
68 | end = len(chars)
69 | if start > 0:
70 | substr = "##" + chars[start:end]
71 | else:
72 | substr = chars[start:end]
73 | cur_substr = None
74 | while start < end:
75 | if substr in self.vocab:
76 | cur_substr = substr
77 | break
78 | end -= 1
79 | substr = substr[:-1]
80 | if cur_substr is None:
81 | is_bad = True
82 | break
83 | else:
84 | sub_tokens.append(cur_substr)
85 | start = end
86 | current_positions.append([])
87 | if is_bad:
88 | current_positions[-1].append(len(output_tokens))
89 | output_tokens.append(self.unk_token)
90 | current_positions[-1].append(len(output_tokens))
91 | else:
92 | current_positions[-1].append(len(output_tokens))
93 | output_tokens.extend(sub_tokens)
94 | current_positions[-1].append(len(output_tokens))
95 |
96 | return output_tokens, current_positions
97 |
98 | def convert_tokens_to_ids(self, tokens):
99 | return convert_by_vocab(self.vocab, tokens)
100 |
101 | def convert_ids_to_tokens(self, ids):
102 | return convert_by_vocab(self.inv_vocab, ids)
103 |
--------------------------------------------------------------------------------
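A sketch of the greedy longest-match-first algorithm with a toy vocabulary (hypothetical; load_vocab accepts a dict as-is). Words for which no pieces match collapse to the [UNK] token, and current_positions records the [start, end) span of every original word in the output:

from opennre.tokenization.word_piece_tokenizer import WordpieceTokenizer

toy_vocab = {"[UNK]": 0, "un": 1, "##aff": 2, "##able": 3, "runs": 4}
tokenizer = WordpieceTokenizer(vocab=toy_vocab)
tokens, positions = tokenizer.tokenize("unaffable runs fast")
print(tokens)     # ['un', '##aff', '##able', 'runs', '[UNK]']
print(positions)  # [[0, 3], [3, 4], [4, 5]]
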
/opennre/tokenization/word_tokenizer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """WordpieceTokenizer classes."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import unicodedata
23 |
24 | from .utils import (load_vocab,
25 | convert_to_unicode,
26 | clean_text,
27 | split_on_whitespace,
28 | convert_by_vocab,
29 | tokenize_chinese_chars)
30 |
31 | class WordTokenizer(object):
32 | """Runs WordPiece tokenziation."""
33 |
34 | def __init__(self, vocab = None, unk_token="[UNK]"):
35 | self.vocab = load_vocab(vocab)
36 | self.inv_vocab = {v: k for k, v in self.vocab.items()}
37 | self.unk_token = unk_token
38 |
39 | def tokenize(self, text):
40 | """ Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform tokenization
41 | using the given vocabulary.
42 |
43 | For example:
44 | input = "unaffable"
45 | output = ["un", "##aff", "##able"]
46 |
47 | Args:
48 | text: A single token or whitespace separated tokens. This should have already been passed through `BasicTokenizer`.
49 | Returns:
50 | output_tokens: A list of wordpiece tokens.
51 | current_positions: A list of the current positions for the original words in text .
52 | """
53 | text = convert_to_unicode(text)
54 | text = clean_text(text)
55 | text = tokenize_chinese_chars(text)
56 | # output_tokens = []
57 | token_list = split_on_whitespace(text)
58 | # for chars in token_list:
59 | # # current_positions.append([])
60 | # if chars in self.vocab:
61 | # output_tokens.append(chars)
62 | # else:
63 | # output_tokens.append(self.unk_token)
64 | return token_list
65 |
66 | def convert_tokens_to_ids(self, tokens, max_seq_length = None, blank_id = 0, unk_id = 1, uncased = True):
67 | return convert_by_vocab(self.vocab, tokens, max_seq_length, blank_id, unk_id, uncased=uncased)
68 |
69 | def convert_ids_to_tokens(self, ids):
70 | return convert_by_vocab(self.inv_vocab, ids)
71 |
--------------------------------------------------------------------------------
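A sketch of the whitespace-level WordTokenizer (hypothetical vocabulary): tokenize() only splits on whitespace, while convert_tokens_to_ids goes through convert_by_vocab, so lookups are uncased by default, unknown words map to unk_id, and short sequences are padded with blank_id:

from opennre.tokenization.word_tokenizer import WordTokenizer

vocab = {"[PAD]": 0, "[UNK]": 1, "the": 2, "cat": 3, "sleeps": 4}
tokenizer = WordTokenizer(vocab=vocab)
tokens = tokenizer.tokenize("The cat sleeps deeply")
print(tokens)  # ['The', 'cat', 'sleeps', 'deeply']
print(tokenizer.convert_tokens_to_ids(tokens, max_seq_length=6))  # [2, 3, 4, 1, 0, 0]
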
/pretrain/download_bert.sh:
--------------------------------------------------------------------------------
1 | mkdir bert-base-uncased
2 | wget -P bert-base-uncased https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/pretrain/bert-base-uncased/config.json
3 | wget -P bert-base-uncased https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/pretrain/bert-base-uncased/pytorch_model.bin
4 | wget -P bert-base-uncased https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/pretrain/bert-base-uncased/vocab.txt
5 |
--------------------------------------------------------------------------------
/pretrain/download_glove.sh:
--------------------------------------------------------------------------------
1 | mkdir glove
2 | wget -P glove https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/pretrain/glove/glove.6B.50d_mat.npy
3 | wget -P glove https://thunlp.oss-cn-qingdao.aliyuncs.com/opennre/pretrain/glove/glove.6B.50d_word2id.json
4 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.6.0
2 | transformers==3.4.0
3 | pytest==5.3.2
4 | scikit-learn==0.22.1
5 | scipy==1.4.1
6 | nltk>=3.6.4
7 |
8 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | with open("README.md", "r") as fh:
3 | setuptools.setup(
4 | name='open-nre',
5 | version='0.1.1',
6 | author="Tianyu Gao",
7 | author_email="gaotianyu1350@126.com",
8 | description="An open source toolkit for relation extraction", long_description=fh.read(),
9 | url="https://github.com/thunlp/opennre",
10 | packages=setuptools.find_packages(),
11 | classifiers=[
12 | "Programming Language :: Python :: 3",
13 | "License :: OSI Approved :: MIT License",
14 | "Operating System :: POSIX :: Linux",
15 | ],
16 | install_requires=['nltk', 'transformers', 'scikit-learn'],
17 | setup_requires=['wheel']
18 | )
19 |
--------------------------------------------------------------------------------
/tests/test_inference.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import opennre
3 |
4 | class TestInference(unittest.TestCase):
5 |
6 | def test_wiki80_cnn_softmax(self):
7 | model = opennre.get_model('wiki80_cnn_softmax')
8 | result = model.infer({'text': 'He was the son of Máel Dúin mac Máele Fithrich, and grandson of the high king Áed Uaridnach (died 612).', 'h': {'pos': (18, 46)}, 't': {'pos': (78, 91)}})
9 | print(result)
10 | self.assertEqual(result[0], 'father')
11 | self.assertTrue(abs(result[1] - 0.7500484585762024) < 1e-6)
12 |
13 | if __name__ == '__main__':
14 | unittest.main()
15 |
--------------------------------------------------------------------------------