├── .gitignore ├── LICENSE ├── README.md ├── preprocessor ├── __init__.py └── adapters.py ├── requirements.txt ├── tests ├── __init__.py ├── adapters_test.py ├── dataset_test.py ├── layers_torch_test.py ├── metrics_test.py ├── models_torch_test.py ├── tagging_scheme_test.py └── truncator_test.py └── tplinker ├── __init__.py ├── dataset.py ├── layers_torch.py ├── metrics.py ├── models_torch.py ├── run_tplinker.py ├── tagging_scheme.py └── truncator.py /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | log 3 | *.log 4 | .vscode 5 | .idea 6 | *.pyc 7 | *.iml 8 | __pycache__ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TPLinker
2 | 
3 | A PyTorch implementation of the paper [TPLinker: Single-stage Joint Extraction of Entities and Relations Through Token Pair Linking](https://www.aclweb.org/anthology/2020.coling-main.138.pdf).
4 | 
5 | ## Requirements
6 | 
7 | * pytorch
8 | * pytorch-lightning
9 | * transformers
10 | * torchmetrics
11 | * glove-python-binary
12 | 
13 | ## Training data
14 | 
15 | Download the preprocessed NYT dataset: [NYT](https://huaichen-oss.oss-cn-hangzhou.aliyuncs.com/public/tplinker-bert-nyt.zip?versionId=CAEQDxiBgMCKm83SxhciIDFmNmY1OGZiMzc0YzRhMDY4ODBmZTEyNDhlOTJmYTg3)
16 | 
17 | > After downloading, unzip the archive into the `data/` directory of this project.
18 | 
19 | You can also download it with the following commands:
20 | 
21 | ```bash
22 | mkdir data && cd data
23 | wget -O tplinker-bert-nyt.zip https://huaichen-oss.oss-cn-hangzhou.aliyuncs.com/public/tplinker-bert-nyt.zip?versionId=CAEQDxiBgMCKm83SxhciIDFmNmY1OGZiMzc0YzRhMDY4ODBmZTEyNDhlOTJmYTg3
24 | unzip tplinker-bert-nyt.zip
25 | ```
26 | 
27 | Each line of the training data is a JSON object in the following format:
28 | 
29 | ```json
30 | {"text": "In Queens , North Shore Towers , near the Nassau border , supplanted a golf course , and housing replaced a gravel quarry in Douglaston .", "id": "valid_0", "relation_list": [{"subject": "Douglaston", "object": "Queens", "subj_char_span": [125, 135], "obj_char_span": [3, 9], "predicate": "/location/neighborhood/neighborhood_of", "subj_tok_span": [26, 28], "obj_tok_span": [1, 2]}, {"subject": "Queens", "object": "Douglaston", "subj_char_span": [3, 9], "obj_char_span": [125, 135], "predicate": "/location/location/contains", "subj_tok_span": [1, 2], "obj_tok_span": [26, 28]}], "entity_list": [{"text": "Douglaston", "type": "DEFAULT", "char_span": [125, 135], "tok_span": [26, 28]}, {"text": "Queens", "type": "DEFAULT", "char_span": [3, 9], "tok_span": [1, 2]}, {"text": "Queens", "type": "DEFAULT", "char_span": [3, 9], "tok_span": [1, 2]}, {"text": "Douglaston", "type": "DEFAULT", "char_span": [125, 135], "tok_span": [26, 28]}]}
31 | ```
32 | 
33 | ## Training
34 | 
35 | > Hyperparameters can be edited directly in `tplinker/run_tplinker.py`.
36 | 
37 | ```bash
38 | nohup python -m tplinker.run_tplinker --gpus=0 >> train.log 2>&1 &
39 | ```
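The `*_char_span` and `*_tok_span` fields above are end-exclusive `[start, end)` offsets into `text`, in characters and tokens respectively (for example, "Queens" occupies characters 3 to 8, hence `[3, 9]`). A minimal sketch that sanity-checks the character spans of a downloaded file; the path follows the `data/tplinker/bert/` layout used by the tests in this repository and may need adjusting:

```python
import json

def check_example(line):
    """Verify that every char_span actually points at the quoted surface text."""
    example = json.loads(line)
    text = example['text']
    for rel in example['relation_list']:
        for span_key, text_key in (('subj_char_span', 'subject'), ('obj_char_span', 'object')):
            start, end = rel[span_key]
            assert text[start:end] == rel[text_key], (rel[text_key], text[start:end])
    for ent in example['entity_list']:
        start, end = ent['char_span']
        assert text[start:end] == ent['text'], (ent['text'], text[start:end])

with open('data/tplinker/bert/valid_data.jsonl', encoding='utf-8') as fin:
    for line in fin:
        check_example(line)
```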
--------------------------------------------------------------------------------
/preprocessor/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luozhouyang/TPLinker/3cacf48901f73a4d4e90ed51d8d5bbf8aecb5a02/preprocessor/__init__.py
--------------------------------------------------------------------------------
/preprocessor/adapters.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import json
3 | import re
4 | from collections import namedtuple
5 | 
6 | from transformers import BertTokenizerFast
7 | 
8 | Entity = namedtuple('Entity', ['text', 'span', 'type'])
9 | Relation = namedtuple('Relation', ['subject', 'object', 'subject_span', 'object_span', 'predicate'])
10 | Example = namedtuple('Example', ['id', 'text', 'relations', 'entities'])
11 | 
12 | 
13 | class AbstractDatasetAdapter(abc.ABC):
14 | 
15 |     @abc.abstractmethod
16 |     def adapte(self, input_file, output_file, **kwargs):
17 |         raise NotImplementedError()
18 | 
19 | 
20 | class NYTBertAdapter(AbstractDatasetAdapter):
21 | 
22 |     def __init__(self, pretrained_bert_path, add_special_tokens=False, do_lower_case=False, **kwargs):
23 |         super().__init__()
24 |         self.do_lower_case = do_lower_case
25 |         self.tokenizer = BertTokenizerFast.from_pretrained(
26 |             pretrained_bert_path, add_special_tokens=add_special_tokens, do_lower_case=do_lower_case)
27 | 
28 |     def adapte(self, input_file, output_file, **kwargs):
29 |         with open(output_file, mode='wt', encoding='utf-8') as fout, \
30 |                 open(input_file, mode='rt', encoding='utf-8') as fin:
31 |             count = 0
32 |             for line in fin:
33 |                 data = json.loads(line)
34 |                 example = self._adapte_example(data)
35 |                 if kwargs.get('validate', False):
36 |                     self._validate_example(example)
37 |                 example.pop('offset', None)
38 |                 json.dump(example, fout, ensure_ascii=False)
39 |                 fout.write('\n')
40 |                 count += 1
41 |                 if count == kwargs.get('limit', -1):
42 |                     break
43 | 
44 |     def _adapte_example(self, data):
45 |         text = data['sentText']
46 |         codes = self.tokenizer.encode_plus(text, return_offsets_mapping=True, add_special_tokens=False)
47 |         example = {
48 |             'text': text,
49 |             'tokens': self.tokenizer.convert_ids_to_tokens(codes['input_ids']),
50 |             'ids': codes['input_ids'],
51 |             'offset': codes['offset_mapping']
52 |         }
53 |         self._adapte_entities(data, example)
54 |         # TODO: finish relation adaptation
55 |         # self._adapte_relations(data, example)
56 |         return example
57 | 
58 |     def _adapte_entities(self, data, example):
59 |         text = data['sentText']
60 |         entity_list = []
61 |         for e in data['entityMentions']:
62 |             for m in re.finditer(re.escape(e['text']), text):
63 |                 char_span_start, char_span_end = m.span()[0], m.span()[1]
64 |                 # previous character is a digit: partial match, skip
65 |                 if char_span_start > 0 and re.match(r'\d', text[char_span_start - 1]):
66 |                     continue
67 |                 # next character is a digit: partial match, skip
68 |                 if char_span_end < len(text) and re.match(r'\d', text[char_span_end]):
69 |                     continue
70 |                 # get token span by char span
71 |                 token_span_start, token_span_end = self._parse_token_span(example, char_span_start, char_span_end)
72 |                 if token_span_start is None or token_span_end is None:
73 |                     print('invalid token span for entity: {}, regex match span: {}'.format(e, m.span()))
74 |                     continue
75 |                 entity_list.append({
76 |                     'text': e['text'],
77 |                     'type': e['label'],
78 |                     'token_span': [token_span_start, token_span_end],
79 |                     'char_span': [char_span_start, char_span_end]
80 |                 })
81 |         example.update({
82 |             'entity_list': entity_list
83 |         })
84 | 
85 |     def _adapte_relations(self, data, example):
86 |         entities = {e['text']: e for e in example['entity_list']}
87 |         relations_list = []
88 |         for relation in data['relationMentions']:
89 |             relations_list.append({
90 |                 'subject': None,
91 |                 'object': None,
92 |                 'predicate': None
93 |             })
94 | 
95 |     def _parse_token_span(self, example, start, end):
96 |         token_start, token_end = None, None
97 |         for idx, (token, offset) in enumerate(zip(example['tokens'], example['offset'])):
98 |             if offset[0] == start and end == offset[1]:
99 |                 return idx, idx + 1
100 |             if offset[0] == start:
101 |                 token_start = idx
102 |             if end == offset[1]:
103 |                 token_end = idx
104 |             if token_start is not None and token_end is not None:
105 |                 return token_start, token_end + 1
106 |         return token_start, token_end
107 | 
108 |     def _validate_example(self, example):
109 |         tokens = example['tokens']
110 |         text = example['text']
111 |         for entity in example['entity_list']:
112 |             start, end = entity['token_span']
113 |             print()
114 |             print('tokens subsequence: {}'.format(tokens[start:end]))
115 |             print('entity text: {}'.format(entity['text']))
116 |             char_start, char_end = entity['char_span']
117 |             print('origin text: {}'.format(text[char_start:char_end]))
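The core of `_parse_token_span` above is aligning character offsets with the fast tokenizer's `offset_mapping`. A self-contained sketch of that alignment; the `bert-base-cased` checkpoint is illustrative, any BERT model with a fast tokenizer works:

```python
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
text = 'In Queens , North Shore Towers'
codes = tokenizer.encode_plus(text, return_offsets_mapping=True, add_special_tokens=False)

def char_span_to_token_span(offsets, start, end):
    """Return the end-exclusive token span covering characters [start, end)."""
    token_start = token_end = None
    for idx, (s, e) in enumerate(offsets):
        if s == start:
            token_start = idx
        if e == end:
            token_end = idx
    if token_start is None or token_end is None:
        return None
    return token_start, token_end + 1

# "Queens" occupies characters [3, 9) -> token span (1, 2)
print(char_span_to_token_span(codes['offset_mapping'], 3, 9))
```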
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | glove-python-binary==0.2.0
2 | numpy
3 | pytorch-lightning
4 | torch
5 | torchmetrics
6 | transformers
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luozhouyang/TPLinker/3cacf48901f73a4d4e90ed51d8d5bbf8aecb5a02/tests/__init__.py
--------------------------------------------------------------------------------
/tests/adapters_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | 
3 | from preprocessor.adapters import NYTBertAdapter
4 | 
5 | 
6 | class DatasetAdapterTest(unittest.TestCase):
7 | 
8 |     def test_nyt_dataset_adapter(self):
9 |         adapter = NYTBertAdapter('data/bert-base-cased')
10 |         input_file = '/mnt/nas/zhouyang.lzy/public-datasets/NYT/raw_valid.json'
11 |         output_file = 'data/preprocess/nyt_valid.jsonl'
12 |         adapter.adapte(input_file, output_file, limit=2)
13 | 
14 | 
15 | if __name__ == "__main__":
16 |     unittest.main()
--------------------------------------------------------------------------------
/tests/dataset_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | 
3 | import torch
4 | from tplinker.dataset import TPLinkerBertDataset
5 | 
6 | 
7 | class DatasetTest(unittest.TestCase):
8 | 
9 |     def test_bert_dataset(self):
10 |         ds = TPLinkerBertDataset(
11 |             input_files=['data/tplinker/bert/valid_data.jsonl'],
12 |             pretrained_bert_path='data/bert-base-cased',
13 |             rel2id_path='data/tplinker/bert/rel2id.json',
14 |             max_sequence_length=100)
15 | 
16 |         dl = torch.utils.data.DataLoader(ds, batch_size=2, drop_last=True)
17 |         for idx, d in enumerate(dl):
18 |             if idx == 10:
19 |                 break
20 |             print()
21 |             print(d)
22 | 
23 | 
24 | if __name__ == "__main__":
25 |     unittest.main()
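A sketch of what one batch from `TPLinkerBertDataset` looks like, following the test above. With `max_sequence_length` L = 100 the flattened handshaking axis has length L(L+1)/2 = 5050; the commented shapes are expectations inferred from `tplinker/dataset.py` further down (the tagging encoder is only partially shown in this dump), not guarantees:

```python
import torch
from tplinker.dataset import TPLinkerBertDataset

ds = TPLinkerBertDataset(
    input_files=['data/tplinker/bert/valid_data.jsonl'],
    pretrained_bert_path='data/bert-base-cased',
    rel2id_path='data/tplinker/bert/rel2id.json',
    max_sequence_length=100)
batch = next(iter(torch.utils.data.DataLoader(ds, batch_size=2)))

print(batch['input_ids'].shape)  # (2, 100)
print(batch['h2t'].shape)        # (2, 5050): one tag id per token pair
print(batch['h2h'].shape)        # (2, num_relations, 5050)
print(batch['t2t'].shape)        # (2, num_relations, 5050)
```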
--------------------------------------------------------------------------------
/tests/layers_torch_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | 
3 | import numpy as np
4 | import torch
5 | from tplinker.layers_torch import (ConcatHandshaking, DistanceEmbedding,
6 |                                    GloveEmbedding, TaggingProjector, TPLinker)
7 | 
8 | 
9 | class LayersTest(unittest.TestCase):
10 | 
11 |     def test_glove_embedding(self):
12 |         words = []
13 |         with open('data/bert-base-cased/vocab.txt', mode='rt', encoding='utf-8') as fin:
14 |             for line in fin:
15 |                 words.append(line.rstrip('\n'))
16 |         vocab = {}
17 |         for idx, token in enumerate(words):
18 |             vocab[idx] = token
19 |         ge = GloveEmbedding('data/glove_300_nyt.emb', vocab=vocab, embedding_size=300)
20 |         input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.int64)
21 |         output = ge(input_ids)
22 |         self.assertEqual([1, 8, 300], list(output.size()))
23 | 
24 |     def test_distance_embedding(self):
25 |         de = DistanceEmbedding()
26 |         hidden = np.zeros([1, 6, 8], dtype=np.float32)
27 |         output = de(torch.tensor(hidden))
28 |         print(f"output shape: {output.size()}")
29 |         self.assertEqual([1, 21, 768], list(output.size()))
30 | 
31 |     def test_concat_handshaking(self):
32 |         cs = ConcatHandshaking(768)
33 |         hidden = torch.zeros([1, 6, 768], dtype=torch.float32)
34 |         output = cs(hidden)
35 |         print(f"output shape: {output.size()}")
36 |         self.assertEqual([1, 21, 768], list(output.size()))
37 | 
38 |     def test_tagging_projector(self):
39 |         tp = TaggingProjector(768, 24)
40 |         hidden = torch.zeros([1, 21, 768], dtype=torch.float32)
41 |         output = tp(hidden)
42 |         print(f"output shape: {output.size()}")
43 |         self.assertEqual([1, 24, 21, 3], list(output.size()))
44 | 
45 |     def test_tplinker(self):
46 |         tp = TPLinker(768, 24, add_distance_embedding=True)
47 |         hidden = torch.zeros([1, 6, 768], dtype=torch.float32)
48 |         h2t, h2h, t2t = tp(hidden)
49 |         self.assertEqual([1, 21, 2], list(h2t.size()))
50 |         self.assertEqual([1, 24, 21, 3], list(h2h.size()))
51 |         self.assertEqual([1, 24, 21, 3], list(t2t.size()))
52 | 
53 | 
54 | if __name__ == "__main__":
55 |     unittest.main()
--------------------------------------------------------------------------------
/tests/metrics_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | 
3 | from tplinker.metrics import F1, Precision, Recall, SampleAccuracy
4 | 
5 | 
6 | class MetricsTest(unittest.TestCase):
7 | 
8 |     def test_precision(self):
9 |         p = Precision()
10 |         v = p.compute()
11 |         print(v)
12 | 
13 | 
14 | if __name__ == "__main__":
15 |     unittest.main()
--------------------------------------------------------------------------------
/tests/models_torch_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | 
3 | import torch
4 | from tplinker.models_torch import TPLinkerBert, TPLinkerBiLSTM
5 | from transformers import BertTokenizerFast
6 | 
7 | 
8 | class ModelsTest(unittest.TestCase):
9 | 
10 |     def test_tplinker_bert(self):
11 |         m = TPLinkerBert('data/bert-base-cased', 24, add_distance_embedding=True)
12 |         t = BertTokenizerFast.from_pretrained('data/bert-base-cased', add_special_tokens=False, do_lower_case=False)
13 |         codes = t.encode_plus('I love NLP!', return_offsets_mapping=True, add_special_tokens=False)
14 |         print(codes)
15 |         input_ids, attn_mask, segment_ids = codes['input_ids'], codes['attention_mask'], codes['token_type_ids']
16 |         seq_len = len(input_ids)
17 | 
18 |         input_ids = torch.tensor([input_ids], dtype=torch.long)
19 |         attn_mask = torch.tensor([attn_mask], dtype=torch.long)
20 |         segment_ids = torch.tensor([segment_ids], dtype=torch.long)
21 |         flat_seq_len = seq_len * (seq_len + 1) // 2
22 | 
23 |         h2t, h2h, t2t = m(input_ids, attn_mask, segment_ids)
24 |         self.assertEqual([1, flat_seq_len, 2], list(h2t.size()))
25 |         self.assertEqual([1, 24, flat_seq_len, 3], list(h2h.size()))
26 |         self.assertEqual([1, 24, flat_seq_len, 3], list(t2t.size()))
27 | 
28 |     def test_tplinker_bilstm(self):
29 |         words = []
30 |         with open('data/bert-base-cased/vocab.txt', mode='rt', encoding='utf-8') as fin:
31 |             for line in fin:
32 |                 words.append(line.rstrip('\n'))
33 |         vocab = {}
34 |         for idx, token in enumerate(words):
35 |             vocab[idx] = token
36 |         m = TPLinkerBiLSTM(24, 768, 768,
37 |                            pretrained_embedding_path='data/glove_300_nyt.emb',
38 |                            vocab=vocab,
39 |                            embedding_size=300,
40 |                            add_distance_embedding=True).float()
41 |         input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.long)
42 |         h2t, h2h, t2t = m(input_ids)
43 |         self.assertEqual([1, 8 * 9 // 2, 2], list(h2t.size()))
44 |         self.assertEqual([1, 24, 8 * 9 // 2, 3], list(h2h.size()))
45 |         self.assertEqual([1, 24, 8 * 9 // 2, 3], list(t2t.size()))
46 | 
47 | 
48 | if __name__ == "__main__":
49 |     unittest.main()
--------------------------------------------------------------------------------
/tests/tagging_scheme_test.py:
--------------------------------------------------------------------------------
1 | import json
2 | import unittest
3 | 
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from
tplinker.tagging_scheme import (HandshakingTaggingDecoder, 8 | HandshakingTaggingEncoder, TagMapping) 9 | from tplinker.truncator import BertExampleTruncator 10 | from transformers import BertTokenizerFast 11 | 12 | 13 | class SchemeTest(unittest.TestCase): 14 | 15 | def _build_encoder(self): 16 | with open('data/tplinker/bert/rel2id.json', mode='rt', encoding='utf-8') as fin: 17 | rel2id = json.load(fin) 18 | tm = TagMapping(relation2id=rel2id) 19 | encoder = HandshakingTaggingEncoder(tag_mapping=tm) 20 | return encoder 21 | 22 | def _build_decoder(self): 23 | with open('data/tplinker/bert/rel2id.json', mode='rt', encoding='utf-8') as fin: 24 | rel2id = json.load(fin) 25 | tm = TagMapping(relation2id=rel2id) 26 | decoder = HandshakingTaggingDecoder(tag_mapping=tm) 27 | return decoder 28 | 29 | def _build_truncator(self): 30 | tokenizer = BertTokenizerFast.from_pretrained( 31 | 'data/bert-base-cased', add_special_tokens=False, do_lower_case=False) 32 | truncator = BertExampleTruncator(tokenizer, max_sequence_length=100, window_size=50) 33 | return truncator, tokenizer 34 | 35 | def _read_example(self, limit=1): 36 | examples = [] 37 | with open('data/tplinker/bert/valid_data.jsonl', mode='rt', encoding='utf-8') as fin: 38 | for line in fin: 39 | example = json.loads(line) 40 | examples.append(example) 41 | if len(examples) == limit: 42 | break 43 | return examples 44 | 45 | def test_handshaking_tagging_encoder(self): 46 | encoder = self._build_encoder() 47 | truncator, _ = self._build_truncator() 48 | examples = self._read_example(limit=100) 49 | truncated_examples = [] 50 | for example in examples: 51 | truncated_examples.extend(truncator.truncate(example)) 52 | for example in truncated_examples: 53 | print() 54 | print('example: {}'.format(example)) 55 | h2t, h2h, t2t = encoder.encode(example, max_sequence_length=100) 56 | print(f'h2t shape: {h2t.shape}, h2h shape: {h2h.shape}, t2t shape: {t2t.shape}') 57 | 58 | def test_handshaking_tagging_ecoder_example(self): 59 | example = { 60 | 'text': "-LRB- Dunning -RRB- NEXT WAVE FESTIVAL : NATIONAL BALLET OF CHINA -LRB- Tuesday through Thursday -RRB- The company will perform '' Raise the Red Lantern , '' which tells the story of a young concubine in 1920 's China with a fusion of ballet , modern dance and traditional Chinese dance , set to music performed on Western and Chinese instruments and directed by Zhang Yimou", 61 | 'entity_list': [{'text': 'Zhang Yimou', 'type': 'DEFAULT', 'tok_span': [96, 100], 'char_span': [363, 374]}, {'text': 'China', 'type': 'DEFAULT', 'tok_span': [70, 71], 'char_span': [212, 217]}], 62 | 'relation_list': [{'subj_tok_span': [96, 100], 'obj_tok_span': [70, 71], 'subj_char_span': [363, 374], 'obj_char_span': [212, 217], 'subject': 'Zhang Yimou', 'object': 'China', 'predicate': '/people/person/nationality'}], 'token_offset': 0, 'char_offset': 0, 63 | 'offset_mapping': [[0, 1], [1, 2], [2, 4], [4, 5], [6, 10], [10, 13], [14, 15], [15, 16], [16, 18], [18, 19], [20, 22], [22, 24], [25, 27], [27, 29], [30, 31], [31, 33], [33, 35], [35, 37], [37, 38], [39, 40], [41, 42], [42, 44], [44, 47], [47, 49], [50, 52], [52, 54], [54, 56], [57, 59], [60, 62], [62, 64], [64, 65], [66, 67], [67, 68], [68, 70], [70, 71], [72, 79], [80, 87], [88, 96], [97, 98], [98, 99], [99, 101], [101, 102], [103, 106], [107, 114], [115, 119], [120, 127], [128, 129], [129, 130], [131, 134], [134, 136], [137, 140], [141, 144], [145, 152], [153, 154], [155, 156], [156, 157], [158, 163], [164, 169], [170, 173], [174, 179], [180, 182], [183, 184], [185, 
190], [191, 194], [194, 196], [196, 200], [201, 203], [204, 208], [209, 210], [210, 211], [212, 217], [218, 222], [223, 224], [225, 231], [232, 234], [235, 241], [242, 243], [244, 250], [251, 256], [257, 260], [261, 272], [273, 280], [281, 286], [287, 288], [289, 292], [293, 295], [296, 301], [302, 311], [312, 314], [315, 322], [323, 326], [327, 334], [335, 346], [347, 350], [351, 359], [360, 362], [363, 368], [369, 371], [371, 373], [373, 374]] 64 | } 65 | 66 | encoder = self._build_encoder() 67 | 68 | outputs = encoder.encode(example) 69 | print(np.sum(outputs[0])) 70 | print(np.sum(outputs[1])) 71 | print(np.sum(outputs[2])) 72 | print(outputs) 73 | 74 | def test_handshaking_decoder_example(self): 75 | truncator, _ = self._build_truncator() 76 | truncated_examples = [] 77 | for e in self._read_example(): 78 | truncated_examples.extend(truncator.truncate(e)) 79 | encoder = self._build_encoder() 80 | example = truncated_examples[0] 81 | print(example['relation_list']) 82 | h2t, h2h, t2t = encoder.encode(example, max_sequence_length=100) 83 | decoder = self._build_decoder() 84 | h2t_pred = F.one_hot(torch.tensor(h2t), num_classes=2) 85 | h2h_pred = F.one_hot(torch.tensor(h2h), num_classes=3) 86 | t2t_pred = F.one_hot(torch.tensor(t2t), num_classes=3) 87 | relations = decoder.decode(example, h2t_pred, h2h_pred, t2t_pred, max_sequence_length=100) 88 | print(relations) 89 | 90 | 91 | if __name__ == "__main__": 92 | unittest.main() 93 | -------------------------------------------------------------------------------- /tests/truncator_test.py: -------------------------------------------------------------------------------- 1 | import json 2 | import unittest 3 | 4 | from tplinker.truncator import BertExampleTruncator 5 | from transformers import BertTokenizerFast 6 | 7 | 8 | class TruncatorTest(unittest.TestCase): 9 | 10 | def _read_examples(self, tokenizer, nums=1, min_sequence_length=100, **kwargs): 11 | examples = [] 12 | with open('data/tplinker/bert/valid_data.jsonl', mode='rt', encoding='utf8') as fin: 13 | for line in fin: 14 | e = json.loads(line) 15 | tokens = tokenizer.tokenize(e['text']) 16 | if len(tokens) < min_sequence_length: 17 | continue 18 | examples.append(e) 19 | # if len(examples) == 17: 20 | # tokens = tokenizer.tokenize(e['text']) 21 | # print(tokens) 22 | if len(examples) == nums: 23 | break 24 | return examples 25 | 26 | def _create_truncator(self): 27 | tokenizer = BertTokenizerFast.from_pretrained( 28 | 'data/bert-base-cased', add_special_tokens=False, do_lower_case=False) 29 | truncator = BertExampleTruncator(tokenizer, max_sequence_length=100) 30 | return truncator, tokenizer 31 | 32 | def test_bert_truncator(self): 33 | truncator, tokenizer = self._create_truncator() 34 | examples = self._read_examples(tokenizer, nums=17, min_sequence_length=100) 35 | truncated_examples = truncator.truncate(example=examples[-1]) 36 | print('original example: ', examples[-1]) 37 | for e in truncated_examples: 38 | print() 39 | print(e) 40 | 41 | def test_bert_truncator_example(self): 42 | example = { 43 | 'text': 'Besides Mr. Stanley and Mr. Fugate , they include Richard Andrews , the former homeland security adviser to Gov. Arnold Schwarzenegger of California ; Ellen M. Gordon , former homeland security adviser in Iowa ; Dale W. Shipley of Ohio and Eric Tolbert of North Carolina , two former top FEMA officials who also served as the top emergency managers in their home states ; and Bruce P. 
Baughman , the president of the National Emergency Management Association , as well as the top disaster planning official in Alabama .', 'id': 'valid_1816', 44 | 'relation_list': [{'subject': 'Arnold Schwarzenegger', 'object': 'California', 'subj_char_span': [113, 134], 'obj_char_span': [138, 148], 'predicate': '/people/person/place_lived', 'subj_tok_span': [24, 30], 'obj_tok_span': [31, 32]}, {'subject': 'Arnold Schwarzenegger', 'object': 'California', 'subj_char_span': [113, 134], 'obj_char_span': [138, 148], 'predicate': '/business/person/company', 'subj_tok_span': [24, 30], 'obj_tok_span': [31, 32]}], 45 | 'entity_list': [{'text': 'Arnold Schwarzenegger', 'type': 'DEFAULT', 'char_span': [113, 134], 'tok_span': [24, 30]}, {'text': 'California', 'type': 'DEFAULT', 'char_span': [138, 148], 'tok_span': [31, 32]}, {'text': 'Arnold Schwarzenegger', 'type': 'DEFAULT', 'char_span': [113, 134], 'tok_span': [24, 30]}, {'text': 'California', 'type': 'DEFAULT', 'char_span': [138, 148], 'tok_span': [31, 32]}] 46 | } 47 | truncator, _ = self._create_truncator() 48 | outputs = truncator.truncate(example) 49 | for o in outputs: 50 | print() 51 | print(o) 52 | 53 | 54 | if __name__ == "__main__": 55 | unittest.main() 56 | -------------------------------------------------------------------------------- /tplinker/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luozhouyang/TPLinker/3cacf48901f73a4d4e90ed51d8d5bbf8aecb5a02/tplinker/__init__.py -------------------------------------------------------------------------------- /tplinker/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import torch 5 | from transformers import BertTokenizerFast 6 | 7 | from tplinker.tagging_scheme import HandshakingTaggingEncoder, TagMapping 8 | 9 | from .truncator import BertExampleTruncator 10 | 11 | 12 | class TPLinkerBertDataset(torch.utils.data.Dataset): 13 | 14 | def __init__(self, 15 | input_files, 16 | pretrained_bert_path, 17 | rel2id_path, 18 | max_sequence_length=100, 19 | window_size=50, 20 | **kwargs): 21 | self.tokenizer = BertTokenizerFast.from_pretrained( 22 | pretrained_bert_path, add_special_tokens=False, do_lower_case=False) 23 | self.bert_truncator = BertExampleTruncator( 24 | self.tokenizer, max_sequence_length=max_sequence_length, window_size=window_size) 25 | self.examples = self._read_input_files(input_files) 26 | 27 | with open(rel2id_path, mode='rt', encoding='utf-8') as fin: 28 | rel2id = json.load(fin) 29 | self.tag_mapping = TagMapping(rel2id) 30 | self.encoder = HandshakingTaggingEncoder(self.tag_mapping) 31 | self.max_sequence_length = max_sequence_length 32 | 33 | def _read_input_files(self, input_files): 34 | if isinstance(input_files, str): 35 | input_files = [input_files] 36 | all_examples = [] 37 | for f in input_files: 38 | with open(f, mode='rt', encoding='utf-8') as fin: 39 | for line in fin: 40 | example = json.loads(line) 41 | examples = self.bert_truncator.truncate(example) 42 | if examples: 43 | all_examples.extend(examples) 44 | return all_examples 45 | 46 | def __len__(self): 47 | return len(self.examples) 48 | 49 | def __getitem__(self, index): 50 | example = self.examples[index] 51 | codes = self.tokenizer.encode_plus( 52 | example['text'], 53 | return_offsets_mapping=True, 54 | add_special_tokens=False, 55 | max_length=self.max_sequence_length, 56 | padding='max_length') 57 | 58 | h2t, h2h, t2t = 
self.encoder.encode(example, max_sequence_length=self.max_sequence_length) 59 | 60 | item = { 61 | 'example': json.dumps(example, ensure_ascii=False), # raw contents used to compute metrics 62 | 'input_ids': torch.tensor(codes['input_ids']), 63 | 'attention_mask': torch.tensor(codes['attention_mask']), 64 | 'token_type_ids': torch.tensor(codes['token_type_ids']), 65 | 'h2t': torch.tensor(h2t), 66 | 'h2h': torch.tensor(h2h), 67 | 't2t': torch.tensor(t2t), 68 | } 69 | return item 70 | -------------------------------------------------------------------------------- /tplinker/layers_torch.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from glove import Glove 8 | 9 | 10 | class GloveEmbedding(nn.Module): 11 | 12 | def __init__(self, pretrained_embedding_path, vocab, embedding_size, freeze=False, **kwargs): 13 | super().__init__() 14 | self.vocab = vocab 15 | self.vocab_size = len(self.vocab) 16 | self.embedding_size = embedding_size 17 | self.glove = Glove.load(pretrained_embedding_path) 18 | self.embedding = nn.Embedding.from_pretrained(self._build_embedding_matrix(), freeze=freeze) 19 | 20 | def _build_embedding_matrix(self): 21 | matrix = np.random.normal(-1, 1, size=(self.vocab_size, self.embedding_size)) 22 | count = 0 23 | for idx, token in self.vocab.items(): 24 | if token in self.glove.dictionary: 25 | matrix[idx] = self.glove.word_vectors[self.glove.dictionary[token]] 26 | count += 1 27 | logging.info(f'Load {count} tokens from pretrained embedding table.') 28 | matrix = torch.tensor(matrix) 29 | return matrix 30 | 31 | def forward(self, input_ids, **kwargs): 32 | return self.embedding(input_ids) 33 | 34 | 35 | class DistanceEmbedding(nn.Module): 36 | 37 | def __init__(self, max_positions=512, embedding_size=768, **kwargs): 38 | super().__init__() 39 | self.max_positions = max_positions 40 | self.embedding_size = embedding_size 41 | self.dist_embedding = self._init_embedding_table() 42 | self.register_parameter('distance_embedding', self.dist_embedding) 43 | 44 | def _init_embedding_table(self): 45 | matrix = np.zeros([self.max_positions, self.embedding_size]) 46 | for d in range(self.max_positions): 47 | for i in range(self.embedding_size): 48 | if i % 2 == 0: 49 | matrix[d][i] = math.sin(d / 10000**(i / self.embedding_size)) 50 | else: 51 | matrix[d][i] = math.cos(d / 10000**((i - 1) / self.embedding_size)) 52 | embedding_table = nn.Parameter(data=torch.tensor(matrix, requires_grad=False), requires_grad=False) 53 | return embedding_table 54 | 55 | def forward(self, inputs, **kwargs): 56 | """Distance embedding. 
57 | 58 | Args: 59 | inputs: Tensor, shape (batch_size, seq_len, hidden_size) 60 | 61 | Returns: 62 | embedding: Tensor, shape (batch_size, 1+2+...+seq_len, embedding_size) 63 | """ 64 | batch_size, seq_len = inputs.size()[0], inputs.size()[1] 65 | segs = [] 66 | for index in range(seq_len, 0, -1): 67 | segs.append(self.dist_embedding[:index, :]) 68 | segs = torch.cat(segs, dim=0) 69 | embedding = segs[None, :, :].repeat(batch_size, 1, 1) 70 | return embedding 71 | 72 | 73 | class TaggingProjector(nn.Module): 74 | 75 | def __init__(self, hidden_size, num_relations, name='proj', **kwargs): 76 | super().__init__() 77 | self.name = name 78 | self.fc_layers = [nn.Linear(hidden_size, 3) for _ in range(num_relations)] 79 | for index, fc in enumerate(self.fc_layers): 80 | self.register_parameter('{}_weights_{}'.format(self.name, index), fc.weight) 81 | self.register_parameter('{}_bias_{}'.format(self.name, index), fc.bias) 82 | 83 | def forward(self, hidden, **kwargs): 84 | """Project hiddens to tags for each relation. 85 | 86 | Args: 87 | hidden: Tensor, shape (batch_size, 1+2+...+seq_len, hidden_size) 88 | 89 | Returns: 90 | outputs: Tensor, shape (batch_size, num_relations, 1+2+...+seq_len, num_tags=3) 91 | """ 92 | outputs = [] 93 | for fc in self.fc_layers: 94 | outputs.append(fc(hidden)) 95 | outputs = torch.stack(outputs, dim=1) 96 | outputs = torch.softmax(outputs, dim=-1) 97 | return outputs 98 | 99 | 100 | class ConcatHandshaking(nn.Module): 101 | 102 | def __init__(self, hidden_size, **kwargs): 103 | super().__init__() 104 | self.fc = nn.Linear(hidden_size * 2, hidden_size) 105 | 106 | def forward(self, hidden, **kwargs): 107 | """Handshaking. 108 | 109 | Args: 110 | hidden: Tensor, shape (batch_size, seq_len, hidden_size) 111 | 112 | Returns: 113 | handshaking_hiddens: Tensor, shape (batch_size, 1+2+...+seq_len, hidden_size) 114 | """ 115 | seq_len = hidden.size()[1] 116 | handshaking_hiddens = [] 117 | for i in range(seq_len): 118 | _h = hidden[:, i, :] 119 | repeat_hidden = _h[:, None, :].repeat(1, seq_len - i, 1) 120 | visibl_hidden = hidden[:, i:, :] 121 | shaking_hidden = torch.cat([repeat_hidden, visibl_hidden], dim=-1) 122 | shaking_hidden = self.fc(shaking_hidden) 123 | shaking_hidden = torch.tanh(shaking_hidden) 124 | handshaking_hiddens.append(shaking_hidden) 125 | handshaking_hiddens = torch.cat(handshaking_hiddens, dim=1) 126 | return handshaking_hiddens 127 | 128 | 129 | class TPLinker(nn.Module): 130 | 131 | def __init__(self, hidden_size, num_relations, max_positions=512, add_distance_embedding=False, **kwargs): 132 | super().__init__() 133 | self.handshaking = ConcatHandshaking(hidden_size) 134 | self.h2t_proj = nn.Linear(hidden_size, 2) 135 | self.h2h_proj = TaggingProjector(hidden_size, num_relations, name='h2hproj') 136 | self.t2t_proj = TaggingProjector(hidden_size, num_relations, name='t2tproj') 137 | self.add_distance_embedding = add_distance_embedding 138 | if self.add_distance_embedding: 139 | self.distance_embedding = DistanceEmbedding(max_positions, embedding_size=hidden_size) 140 | 141 | def forward(self, hidden, **kwargs): 142 | """TPLinker model forward pass. 
143 | 
144 |         Args:
145 |             hidden: Tensor, output of BERT or BiLSTM, shape (batch_size, seq_len, hidden_size)
146 | 
147 |         Returns:
148 |             h2t_hidden: Tensor, shape (batch_size, 1+2+...+seq_len, 2),
149 |                 logits for entity recognition
150 |             h2h_hidden: Tensor, shape (batch_size, num_relations, 1+2+...+seq_len, 3),
151 |                 logits for relation recognition
152 |             t2t_hidden: Tensor, shape (batch_size, num_relations, 1+2+...+seq_len, 3),
153 |                 logits for relation recognition
154 |         """
155 |         handshaking_hidden = self.handshaking(hidden)
156 |         h2t_hidden, rel_hidden = handshaking_hidden, handshaking_hidden
157 |         if self.add_distance_embedding:
158 |             # add out-of-place: an in-place `+=` would mutate the shared handshaking
159 |             # tensor twice, doubling the distance term
160 |             dist_embedding = self.distance_embedding(hidden)
161 |             h2t_hidden = handshaking_hidden + dist_embedding
162 |             rel_hidden = handshaking_hidden + dist_embedding
163 |         h2t_hidden = self.h2t_proj(h2t_hidden)
164 |         h2h_hidden = self.h2h_proj(rel_hidden)
165 |         t2t_hidden = self.t2t_proj(rel_hidden)
166 |         return h2t_hidden, h2h_hidden, t2t_hidden
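`ConcatHandshaking` above emits the token pairs (i, j) with j >= i in row-major order over the upper triangle, which is where the recurring `1+2+...+seq_len` axis comes from. A small sketch of the index arithmetic for locating a pair on that flattened axis:

```python
def handshaking_index(i, j, seq_len):
    """Offset of token pair (i, j), j >= i, on the flattened axis of length seq_len*(seq_len+1)//2."""
    # row k contributes seq_len - k entries, for k = 0 .. i-1
    prefix = sum(seq_len - k for k in range(i))
    return prefix + (j - i)

# seq_len = 6 gives 6*7//2 = 21 pairs, matching the [1, 21, ...] shapes in the tests
assert handshaking_index(0, 0, 6) == 0
assert handshaking_index(0, 5, 6) == 5
assert handshaking_index(1, 1, 6) == 6
assert handshaking_index(5, 5, 6) == 20
```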
--------------------------------------------------------------------------------
/tplinker/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchmetrics import Metric
3 | 
4 | 
5 | class SampleAccuracy(Metric):
6 | 
7 |     def __init__(self):
8 |         super().__init__()
9 |         self.add_state('correct', default=torch.tensor(0), dist_reduce_fx='sum')
10 |         self.add_state('total', default=torch.tensor(0), dist_reduce_fx='sum')
11 | 
12 |     def update(self, preds, target):
13 |         # shape: (batch_size, num_relations, 1+2+...+seq_len)
14 |         preds_id = torch.argmax(preds, dim=-1)
15 |         # shape: (batch_size, num_relations * (1+2+...+seq_len))
16 |         preds_id = preds_id.view(preds_id.size()[0], -1)
17 |         # shape: (batch_size, num_relations * (1+2+...+seq_len))
18 |         target = target.view(target.size()[0], -1)
19 |         # num of correct tags
20 |         correct_tags = torch.sum(torch.eq(target, preds_id), dim=1)
21 |         # num of correct samples
22 |         correct_samples = torch.sum(
23 |             torch.eq(correct_tags, torch.ones_like(correct_tags) * target.size()[-1]))
24 | 
25 |         self.correct += correct_samples
26 |         self.total += target.size()[0]
27 | 
28 |     def compute(self):
29 |         return self.correct / self.total
30 | 
31 | 
32 | class _PRF(Metric):
33 |     """Precision, Recall and F1 metric."""
34 | 
35 |     def __init__(self, pattern='only_head_text', epsilon=1e-12):
36 |         super().__init__()
37 |         self.pattern = pattern
38 |         self.epsilon = epsilon
39 | 
40 |         self.add_state('correct', default=torch.tensor(0), dist_reduce_fx='sum')
41 |         self.add_state('goldnum', default=torch.tensor(0), dist_reduce_fx='sum')
42 |         self.add_state('prednum', default=torch.tensor(0), dist_reduce_fx='sum')
43 | 
44 |     def update(self, pred_relations, gold_relations):
45 |         for pred, gold in zip(pred_relations, gold_relations):
46 |             pred_set, gold_set = self._parse_relations_set(pred, gold)
47 |             for rel in pred_set:
48 |                 if rel in gold_set:
49 |                     self.correct += 1
50 |             self.prednum += len(pred_set)
51 |             self.goldnum += len(gold_set)
52 |         # print('metric states: correct={}, prednum={}, goldnum={}'.format(self.correct, self.prednum, self.goldnum))
53 | 
54 |     def _parse_relations_set(self, pred_relations, gold_relations):
55 |         if self.pattern == 'whole_span':
56 |             gold_set = set(['{}-{}-{}-{}-{}'.format(
57 |                 rel['subj_tok_span'][0], rel['subj_tok_span'][1], rel['predicate'], rel['obj_tok_span'][0], rel['obj_tok_span'][1]
58 |             ) for rel in gold_relations])
59 |             pred_set = set(['{}-{}-{}-{}-{}'.format(
60 |                 rel['subj_tok_span'][0], rel['subj_tok_span'][1], rel['predicate'], rel['obj_tok_span'][0], rel['obj_tok_span'][1]
61 |             ) for rel in pred_relations])
62 |             return pred_set, gold_set
63 |         if self.pattern == 'whole_text':
64 |             gold_set = set([
65 |                 '{}-{}-{}'.format(rel['subject'], rel['predicate'], rel['object']) for rel in gold_relations
66 |             ])
67 |             pred_set = set([
68 |                 '{}-{}-{}'.format(rel['subject'], rel['predicate'], rel['object']) for rel in pred_relations
69 |             ])
70 |             return pred_set, gold_set
71 |         if self.pattern == 'only_head_index':
72 |             gold_set = set([
73 |                 '{}-{}-{}'.format(rel['subj_tok_span'][0], rel['predicate'], rel['obj_tok_span'][0]) for rel in gold_relations
74 |             ])
75 |             pred_set = set([
76 |                 '{}-{}-{}'.format(rel['subj_tok_span'][0], rel['predicate'], rel['obj_tok_span'][0]) for rel in pred_relations
77 |             ])
78 |             return pred_set, gold_set
79 |         gold_set = set([
80 |             '{}-{}-{}'.format(rel['subject'].split(' ')[0], rel['predicate'], rel['object'].split(' ')[0]) for rel in gold_relations
81 |         ])
82 |         pred_set = set([
83 |             '{}-{}-{}'.format(rel['subject'].split(' ')[0], rel['predicate'], rel['object'].split(' ')[0]) for rel in pred_relations
84 |         ])
85 |         return pred_set, gold_set
86 | 
87 | 
88 | class Precision(_PRF):
89 | 
90 |     def compute(self):
91 |         return self.correct / (self.prednum + self.epsilon)
92 | 
93 | 
94 | class Recall(_PRF):
95 | 
96 |     def compute(self):
97 |         return self.correct / (self.goldnum + self.epsilon)
98 | 
99 | 
100 | class F1(_PRF):
101 | 
102 |     def compute(self):
103 |         precision = self.correct / (self.prednum + self.epsilon)
104 |         recall = self.correct / (self.goldnum + self.epsilon)
105 |         f1 = 2.0 * precision * recall / (precision + recall + self.epsilon)
106 |         return f1
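A worked example of what `Precision`, `Recall` and `F1` above compute under `pattern='whole_text'`, where each relation is reduced to a `subject-predicate-object` string before set comparison (the relation values are illustrative):

```python
gold = [{'subject': 'Queens', 'predicate': '/location/location/contains', 'object': 'Douglaston'}]
pred = [{'subject': 'Queens', 'predicate': '/location/location/contains', 'object': 'Douglaston'},
        {'subject': 'Douglaston', 'predicate': '/location/location/contains', 'object': 'Queens'}]

gold_set = {'{}-{}-{}'.format(r['subject'], r['predicate'], r['object']) for r in gold}
pred_set = {'{}-{}-{}'.format(r['subject'], r['predicate'], r['object']) for r in pred}

correct = len(pred_set & gold_set)                  # 1
precision = correct / len(pred_set)                 # 0.5
recall = correct / len(gold_set)                    # 1.0
f1 = 2 * precision * recall / (precision + recall)  # 0.666...
print(precision, recall, f1)
```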
--------------------------------------------------------------------------------
/tplinker/models_torch.py:
--------------------------------------------------------------------------------
1 | import logging
2 | 
3 | import torch.nn as nn
4 | from transformers import BertModel
5 | 
6 | from .layers_torch import GloveEmbedding, TPLinker
7 | 
8 | 
9 | class TPLinkerBert(nn.Module):
10 | 
11 |     def __init__(self, bert_model_path, num_relations, add_distance_embedding=False, **kwargs):
12 |         super().__init__()
13 |         self.bert = BertModel.from_pretrained(bert_model_path)
14 |         self.tplinker = TPLinker(
15 |             hidden_size=self.bert.config.hidden_size,
16 |             num_relations=num_relations,
17 |             add_distance_embedding=add_distance_embedding,
18 |             max_positions=512)
19 | 
20 |     def forward(self, input_ids, attn_mask, token_type_ids, **kwargs):
21 |         sequence_output = self.bert(input_ids, attn_mask, token_type_ids)[0]
22 |         h2t_outputs, h2h_outputs, t2t_outputs = self.tplinker(sequence_output)
23 |         return h2t_outputs, h2h_outputs, t2t_outputs
24 | 
25 | 
26 | class TPLinkerBiLSTM(nn.Module):
27 | 
28 |     def __init__(self,
29 |                  num_relations,
30 |                  encoder_hidden_size,
31 |                  decoder_hidden_size,
32 |                  embedding_dropout_rate=0.1,
33 |                  lstm_dropout_rate=0.1,
34 |                  add_distance_embedding=False,
35 |                  max_positions=512,
36 |                  **kwargs):
37 |         super().__init__()
38 |         self.embedding = self._build_embedding(**kwargs)
39 |         self.embedding_dropout = nn.Dropout(embedding_dropout_rate)
40 |         self.encoder = nn.LSTM(
41 |             kwargs['embedding_size'],
42 |             encoder_hidden_size // 2,
43 |             num_layers=1,
44 |             bidirectional=True,
45 |             batch_first=True)
46 |         self.decoder = nn.LSTM(
47 |             encoder_hidden_size,
48 |             decoder_hidden_size // 2,
49 |             num_layers=1,
50 |             bidirectional=True,
51 |             batch_first=True)
52 |         self.lstm_dropout = nn.Dropout(lstm_dropout_rate)
53 |         self.tplinker = TPLinker(
54 |             hidden_size=decoder_hidden_size,
55 |             num_relations=num_relations,
56 |             max_positions=max_positions,
57 |             add_distance_embedding=add_distance_embedding,
58 |             **kwargs)
59 | 
60 |     def _build_embedding(self, **kwargs):
61 |         pretrained_embedding_path = kwargs.get('pretrained_embedding_path', None)
62 |         vocab = kwargs.get('vocab', None)
63 |         embedding_size = kwargs.get('embedding_size', None)
64 |         assert embedding_size, "embedding_size must be provided."
65 |         if pretrained_embedding_path and vocab and embedding_size:
66 |             logging.info('Load pretrained embedding...')
67 |             embedding = GloveEmbedding(pretrained_embedding_path, vocab=vocab, embedding_size=embedding_size)
68 |             return embedding
69 |         vocab_size = kwargs.get('vocab_size', len(vocab) if vocab else None)
70 |         if vocab_size and embedding_size:
71 |             logging.info('Build embedding matrix...')
72 |             embedding = nn.Embedding(vocab_size, embedding_size)
73 |             return embedding
74 |         raise ValueError('Not enough params to build embedding layer.')
75 | 
76 |     def forward(self, input_ids, **kwargs):
77 |         embedding = self.embedding(input_ids)
78 |         embedding = self.embedding_dropout(embedding)
79 |         encoder_outputs, _ = self.encoder(embedding)
80 |         encoder_outputs = self.lstm_dropout(encoder_outputs)
81 |         decoder_outputs, _ = self.decoder(encoder_outputs)
82 |         decoder_outputs = self.lstm_dropout(decoder_outputs)
83 |         h2t_outputs, h2h_outputs, t2t_outputs = self.tplinker(decoder_outputs)
84 |         return h2t_outputs, h2h_outputs, t2t_outputs
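A usage sketch for the randomly-initialized embedding path of `TPLinkerBiLSTM._build_embedding` above: when no pretrained GloVe file is given, `vocab_size` plus `embedding_size` is enough. All hyperparameter values here are illustrative:

```python
import torch
from tplinker.models_torch import TPLinkerBiLSTM

model = TPLinkerBiLSTM(
    num_relations=24,
    encoder_hidden_size=256,
    decoder_hidden_size=256,
    vocab_size=30522,  # e.g. the BERT vocab size; any positive int works
    embedding_size=300)

input_ids = torch.randint(0, 30522, (2, 16), dtype=torch.long)
h2t, h2h, t2t = model(input_ids)
print(h2t.shape)  # (2, 16*17//2, 2) = (2, 136, 2)
print(h2h.shape)  # (2, 24, 136, 3)
```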
--------------------------------------------------------------------------------
/tplinker/run_tplinker.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | 
5 | import pytorch_lightning as pl
6 | import torch
7 | import torch.nn.functional as F
8 | from pytorch_lightning.callbacks import (Callback, EarlyStopping,
9 |                                          ModelCheckpoint)
10 | 
11 | from .dataset import TPLinkerBertDataset
12 | from .metrics import F1, Precision, Recall, SampleAccuracy
13 | from .models_torch import TPLinkerBert
14 | from .tagging_scheme import HandshakingTaggingDecoder, TagMapping
15 | 
16 | 
17 | def _compute_loss(y_pred, y_true):
18 |     y_pred = y_pred.view(-1, y_pred.size()[-1])
19 |     y_true = y_true.view(-1)
20 |     return F.cross_entropy(y_pred, y_true)
21 | 
22 | 
23 | class ONNXModelExport(Callback):
24 | 
25 |     def __init__(self, export_dir, model_name='model'):
26 |         super().__init__()
27 |         self.export_dir = export_dir
28 |         # exist_ok handles the already-exists case without a separate check
29 |         os.makedirs(self.export_dir, exist_ok=True)
30 |         self.model_name = model_name
31 | 
32 |     def on_train_epoch_end(self, trainer, pl_module: pl.LightningModule, outputs):
33 |         filename = os.path.join(
34 |             self.export_dir,
35 |             '{}-epoch-{}.onnx'.format(self.model_name, trainer.current_epoch))
36 | 
37 |         device = pl_module.device
38 |         input_sample = (
39 |             torch.ones((1, 100)).long().to(device),
40 |             torch.ones((1, 100)).long().to(device),
41 |             torch.zeros((1, 100)).long().to(device))
42 |         # pl_module.to_onnx(filename, input_sample, export_params=True)
43 |         torch.onnx.export(
44 |             pl_module, input_sample, filename,
45 |             opset_version=10,
46 |             operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
47 | 
48 | 
49 | class TPLinkerLightning(pl.LightningModule):
50 | 
51 |     def __init__(self, model, rel2id_path, max_sequence_length=100, **kwargs):
52 |         super().__init__(**kwargs)
53 |         self.model = model
54 |         self.h2t_acc = SampleAccuracy()
55 |         self.h2h_acc = SampleAccuracy()
56 |         self.t2t_acc = SampleAccuracy()
57 |         self.train_precision = Precision()
58 |         self.train_recall = Recall()
59 |         self.train_f1 = F1()
60 | 
61 |         with open(rel2id_path, mode='rt', encoding='utf-8') as fin:
62 |             rel2id = json.load(fin)
63 |         tag_mapping = TagMapping(rel2id)
64 |         self.decoder = HandshakingTaggingDecoder(tag_mapping)
65 |         self.max_sequence_length = max_sequence_length
66 | 
67 |     def forward(self, input_ids, attention_mask, token_type_ids, **kwargs):
68 |         h2t, h2h, t2t = self.model(input_ids, attention_mask, token_type_ids)
69 |         return h2t, h2h, t2t
70 | 
71 |     def train_dataloader(self):
72 |         train_dataset = TPLinkerBertDataset(
73 |             input_files=['data/tplinker/bert/train_data.jsonl'],
74 |             pretrained_bert_path='data/bert-base-cased',
75 |             rel2id_path='data/tplinker/bert/rel2id.json')
76 |         train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
77 |         return train_dataloader
78 | 
79 |     def val_dataloader(self):
80 |         valid_dataset = TPLinkerBertDataset(
81 |             input_files=['data/tplinker/bert/valid_data.jsonl'],
82 |             pretrained_bert_path='data/bert-base-cased',
83 |             rel2id_path='data/tplinker/bert/rel2id.json')
84 |         valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=64, shuffle=False)
85 |         return valid_dataloader
86 | 
87 |     def training_step(self, batch, index, **kwargs):
88 |         input_ids, attn_mask, type_ids = batch['input_ids'], batch['attention_mask'], batch['token_type_ids']
89 |         h2t_pred, h2h_pred, t2t_pred = self.model(input_ids, attn_mask, type_ids)
90 |         h2t_loss = _compute_loss(h2t_pred, batch['h2t'])
91 |         h2h_loss = _compute_loss(h2h_pred, batch['h2h'])
92 |         t2t_loss = _compute_loss(t2t_pred, batch['t2t'])
93 |         total_loss = 0.333 * h2t_loss + 0.333 * h2h_loss + 0.333 * t2t_loss
94 |         logs = {'h2t_loss': h2t_loss, 'h2h_loss': h2h_loss, 't2t_loss': t2t_loss}
95 |         logs.update({
96 |             'h2t_acc': self.h2t_acc(h2t_pred, batch['h2t']),
97 |             'h2h_acc': self.h2h_acc(h2h_pred, batch['h2h']),
98 |             't2t_acc': self.t2t_acc(t2t_pred, batch['t2t']),
99 |         })
100 | 
101 |         # decoding predictions takes a long time in early epochs, so skip it to reduce training time
102 |         if self.trainer.current_epoch > 0:
103 |             examples = [json.loads(e) for e in batch['example']]
104 |             pred_relations = self.decoder.batch_decode(
105 |                 examples,
106 |                 h2t_pred, h2h_pred, t2t_pred,
107 |                 max_sequence_length=self.max_sequence_length)
108 |             # print('num of pred relations: {}'.format([len(x) for x in pred_relations]))
109 |             gold_relations = [e['relation_list'] for e in examples]
110 |             # TODO: fix metrics
111 |             logs.update({
112 |                 'precision': self.train_precision(pred_relations, gold_relations),
113 |                 'recall': self.train_recall(pred_relations, gold_relations),
114 |                 'f1': self.train_f1(pred_relations, gold_relations),
115 |             })
116 |         self.log_dict(logs, prog_bar=True, on_step=True, on_epoch=False)
117 |         return total_loss
118 | 
119 |     def training_epoch_end(self, outputs):
120 |         print('correct: {}, prednum: {}, goldnum: {}'.format(
121 |             self.train_precision.correct,
122 |             self.train_precision.prednum,
123 |             self.train_precision.goldnum))
124 | 
125 |     def validation_step(self, batch, index, **kwargs):
126 |         input_ids, attn_mask, type_ids = batch['input_ids'], batch['attention_mask'], batch['token_type_ids']
127 |         h2t_pred, h2h_pred, t2t_pred = self.model(input_ids, attn_mask, type_ids)
128 |         h2t_loss = 
22 |
23 | class ONNXModelExport(Callback):
24 |
25 |     def __init__(self, export_dir, model_name='model'):
26 |         super().__init__()
27 |         self.export_dir = export_dir
28 |         if not os.path.isdir(self.export_dir):
29 |             os.makedirs(self.export_dir)
30 |         self.model_name = model_name
31 |
32 |     def on_train_epoch_end(self, trainer, pl_module: pl.LightningModule, outputs):
33 |         filename = os.path.join(
34 |             self.export_dir,
35 |             '{}-epoch-{}.onnx'.format(self.model_name, trainer.current_epoch))
36 |
37 |         device = pl_module.device
38 |         input_sample = (
39 |             torch.ones((1, 100)).long().to(device),
40 |             torch.ones((1, 100)).long().to(device),
41 |             torch.zeros((1, 100)).long().to(device))
42 |         # pl_module.to_onnx(filename, input_sample, export_params=True)
43 |         torch.onnx.export(
44 |             pl_module, input_sample, filename,
45 |             opset_version=10,
46 |             operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
47 |
48 |
49 | class TPLinkerLightning(pl.LightningModule):
50 |
51 |     def __init__(self, model, rel2id_path, max_sequence_length=100, **kwargs):
52 |         super().__init__(**kwargs)
53 |         self.model = model
54 |         self.h2t_acc = SampleAccuracy()
55 |         self.h2h_acc = SampleAccuracy()
56 |         self.t2t_acc = SampleAccuracy()
57 |         self.train_precision = Precision()
58 |         self.train_recall = Recall()
59 |         self.train_f1 = F1()
60 |
61 |         with open(rel2id_path, mode='rt', encoding='utf-8') as fin:
62 |             rel2id = json.load(fin)
63 |         tag_mapping = TagMapping(rel2id)
64 |         self.decoder = HandshakingTaggingDecoder(tag_mapping)
65 |         self.max_sequence_length = max_sequence_length
66 |
67 |     def forward(self, input_ids, attention_mask, token_type_ids, **kwargs):
68 |         h2t, h2h, t2t = self.model(input_ids, attention_mask, token_type_ids)
69 |         return h2t, h2h, t2t
70 |
71 |     def train_dataloader(self):
72 |         train_dataset = TPLinkerBertDataset(
73 |             input_files=['data/tplinker/bert/train_data.jsonl'],
74 |             pretrained_bert_path='data/bert-base-cased',
75 |             rel2id_path='data/tplinker/bert/rel2id.json')
76 |         train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
77 |         return train_dataloader
78 |
79 |     def val_dataloader(self):
80 |         valid_dataset = TPLinkerBertDataset(
81 |             input_files=['data/tplinker/bert/valid_data.jsonl'],
82 |             pretrained_bert_path='data/bert-base-cased',
83 |             rel2id_path='data/tplinker/bert/rel2id.json')
84 |         valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=64, shuffle=False)
85 |         return valid_dataloader
86 |
87 |     def training_step(self, batch, index, **kwargs):
88 |         input_ids, attn_mask, type_ids = batch['input_ids'], batch['attention_mask'], batch['token_type_ids']
89 |         h2t_pred, h2h_pred, t2t_pred = self.model(input_ids, attn_mask, type_ids)
90 |         h2t_loss = _compute_loss(h2t_pred, batch['h2t'])
91 |         h2h_loss = _compute_loss(h2h_pred, batch['h2h'])
92 |         t2t_loss = _compute_loss(t2t_pred, batch['t2t'])
93 |         total_loss = 0.333 * h2t_loss + 0.333 * h2h_loss + 0.333 * t2t_loss
94 |         logs = {'h2t_loss': h2t_loss, 'h2h_loss': h2h_loss, 't2t_loss': t2t_loss}
95 |         logs.update({
96 |             'h2t_acc': self.h2t_acc(h2t_pred, batch['h2t']),
97 |             'h2h_acc': self.h2h_acc(h2h_pred, batch['h2h']),
98 |             't2t_acc': self.t2t_acc(t2t_pred, batch['t2t']),
99 |         })
100 |
101 |         # decoding predictions takes a long time in early epochs; skip it to speed up training
102 |         if self.trainer.current_epoch > 0:
103 |             examples = [json.loads(e) for e in batch['example']]
104 |             pred_relations = self.decoder.batch_decode(
105 |                 examples,
106 |                 h2t_pred, h2h_pred, t2t_pred,
107 |                 max_sequence_length=self.max_sequence_length)
108 |             # print('num of pred relations: {}'.format([len(x) for x in pred_relations]))
109 |             gold_relations = [e['relation_list'] for e in examples]
110 |             # TODO: Fix metrics
111 |             logs.update({
112 |                 'precision': self.train_precision(pred_relations, gold_relations),
113 |                 'recall': self.train_recall(pred_relations, gold_relations),
114 |                 'f1': self.train_f1(pred_relations, gold_relations),
115 |             })
116 |         self.log_dict(logs, prog_bar=True, on_step=True, on_epoch=False)
117 |         return total_loss
118 |
119 |     def training_epoch_end(self, outputs):
120 |         print('correct: {}, prednum: {}, goldnum: {}'.format(
121 |             self.train_precision.correct,
122 |             self.train_precision.prednum,
123 |             self.train_precision.goldnum))
124 |
125 |     def validation_step(self, batch, index, **kwargs):
126 |         input_ids, attn_mask, type_ids = batch['input_ids'], batch['attention_mask'], batch['token_type_ids']
127 |         h2t_pred, h2h_pred, t2t_pred = self.model(input_ids, attn_mask, type_ids)
128 |         h2t_loss = _compute_loss(h2t_pred, batch['h2t'])
129 |         h2h_loss = _compute_loss(h2h_pred, batch['h2h'])
130 |         t2t_loss = _compute_loss(t2t_pred, batch['t2t'])
131 |         total_loss = 0.333 * h2t_loss + 0.333 * h2h_loss + 0.333 * t2t_loss
132 |         logs = {'val_h2t_loss': h2t_loss, 'val_h2h_loss': h2h_loss, 'val_t2t_loss': t2t_loss, 'val_loss': total_loss}
133 |         self.log_dict(logs, prog_bar=False, on_step=True, on_epoch=True)
134 |
135 |     def configure_optimizers(self):
136 |         opt = torch.optim.Adam(self.parameters(), lr=3e-5)
137 |         # schedule = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=...)
138 |         return opt
139 |
140 |
141 | def create_trainer(model_path='model/', gpus=0, **kwargs):
142 |     trainer = pl.Trainer(
143 |         gpus=gpus,
144 |         default_root_dir=model_path,
145 |         callbacks=[
146 |             ModelCheckpoint(
147 |                 dirpath=os.path.join(model_path, 'ckpt'),
148 |                 filename='tplinker-bert-{epoch}-{step}-{val_loss:.2f}',
149 |                 monitor='val_loss',
150 |                 save_top_k=kwargs.get('save_top_k', 5),
151 |                 mode='min'
152 |             ),
153 |             EarlyStopping(monitor='val_loss'),
154 |             ONNXModelExport(export_dir=os.path.join(model_path, 'onnx'), model_name='tplinker-bert')
155 |         ],
156 |         max_epochs=kwargs.get('max_epochs', 10),
157 |     )
158 |     return trainer
159 |
160 |
161 | def create_bert_model(pretrained_bert_path, num_relations, add_distance_embedding=False):
162 |     model = TPLinkerBert(
163 |         bert_model_path=pretrained_bert_path,
164 |         num_relations=num_relations,
165 |         add_distance_embedding=add_distance_embedding)
166 |     return model
167 |
168 |
169 | if __name__ == "__main__":
170 |     parser = argparse.ArgumentParser()
171 |     parser.add_argument('--gpus', type=int, default=None)
172 |     parser.add_argument('--pretrained_bert_path', default='data/bert-base-cased')
173 |     parser.add_argument('--num_relations', type=int, default=24)
174 |     parser.add_argument('--add_distance_embedding', action='store_true')
175 |     parser.add_argument('--model_path', default='data/model/tplinker-bert/v0')
176 |     parser.add_argument('--save_top_k', type=int, default=5)
177 |     parser.add_argument('--max_epochs', type=int, default=10)
178 |     parser.add_argument('--max_sequence_length', type=int, default=100)
179 |
180 |     args, _ = parser.parse_known_args()
181 |
182 |     module = TPLinkerLightning(
183 |         model=create_bert_model(
184 |             pretrained_bert_path=args.pretrained_bert_path,
185 |             num_relations=args.num_relations,
186 |             add_distance_embedding=args.add_distance_embedding),
187 |         rel2id_path='data/tplinker/bert/rel2id.json',
188 |         max_sequence_length=args.max_sequence_length)
189 |     trainer = create_trainer(
190 |         model_path=args.model_path,
191 |         gpus=args.gpus,
192 |         save_top_k=args.save_top_k,
193 |         max_epochs=args.max_epochs,
194 |     )
195 |     trainer.fit(module)
196 |
--------------------------------------------------------------------------------
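A programmatic equivalent of the __main__ block, a sketch using only the helpers defined above; the paths are the script's own argparse defaults and assume that data layout exists locally.

from tplinker.run_tplinker import TPLinkerLightning, create_bert_model, create_trainer

# Assumes data/bert-base-cased and data/tplinker/bert/* exist, as in the defaults above.
model = create_bert_model('data/bert-base-cased', num_relations=24)
module = TPLinkerLightning(model, rel2id_path='data/tplinker/bert/rel2id.json',
                           max_sequence_length=100)
trainer = create_trainer(model_path='data/model/tplinker-bert/v0', gpus=None)
trainer.fit(module)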
"REL-SH2OH": 1, # subject head to object head 27 | "REL-OH2SH": 2, # object head to subject head 28 | } 29 | self.head2head_id2tag = {v: k for k, v in self.head2head_tag2id.items()} 30 | 31 | # tag2id mapping: entity1 tail to entity2 tail 32 | self.tail2tail_tag2id = { 33 | "O": 0, 34 | "REL-ST2OT": 1, # subject tail to object tail 35 | "REL-OT2ST": 2, # object tail to subject tail 36 | } 37 | self.tail2tail_id2tag = {v: k for k, v in self.tail2tail_tag2id.items()} 38 | 39 | def relation_id(self, key): 40 | return self.relation2id.get(key, None) 41 | 42 | def relation_tag(self, _id): 43 | return self.id2relation.get(_id, None) 44 | 45 | def h2t_id(self, key): 46 | return self.head2tail_tag2id.get(key, None) 47 | 48 | def h2t_tag(self, _id): 49 | return self.head2tail_id2tag.get(_id, 'O') 50 | 51 | def h2h_id(self, key): 52 | return self.head2head_tag2id.get(key, None) 53 | 54 | def h2h_tag(self, _id): 55 | return self.head2head_id2tag.get(_id, 'O') 56 | 57 | def t2t_id(self, key): 58 | return self.tail2tail_tag2id.get(key, None) 59 | 60 | def t2t_tag(self, _id): 61 | return self.tail2tail_id2tag.get(_id, 'O') 62 | 63 | 64 | # p -> point to head of entity, q -> point to tail of entity, tagid -> id of h2t tag 65 | Head2TailItem = namedtuple('Head2TailItem', ['p', 'q', 'tagid']) 66 | # p -> point to head of entity1, q -> point to head of entity2 67 | # relid -> relation_id between entity1 and entity2, tagid -> id of h2h tag 68 | Head2HeadItem = namedtuple('Head2HeadItem', ['relid', 'p', 'q', 'tagid']) 69 | # p -> point to tail of entity1, q -> point to tail of entity2 70 | # relid -> relation_id between entity1 and entity2, tagid -> id of h2h tag 71 | Tail2TailItem = namedtuple('Tail2TailItem', ['relid', 'p', 'q', 'tagid']) 72 | 73 | 74 | class HandshakingTaggingEncoder: 75 | 76 | def __init__(self, tag_mapping: TagMapping, **kwargs): 77 | super().__init__() 78 | self.tag_mapping = tag_mapping 79 | 80 | def encode(self, example, max_sequence_length=100, **kwargs): 81 | h2t, h2h, t2t = self.batch_encode([example], max_sequence_length=max_sequence_length, **kwargs) 82 | return h2t[0], h2h[0], t2t[0] 83 | 84 | def batch_encode(self, examples, max_sequence_length=100, **kwargs): 85 | index_matrix = self._build_index_matrix(max_sequence_length=max_sequence_length, **kwargs) 86 | flatten_length = max_sequence_length * (max_sequence_length + 1) // 2 87 | batch_h2t_spots, batch_h2h_spots, batch_t2t_spots = [], [], [] 88 | for example in examples: 89 | h2t_spots, h2h_spots, t2t_spots = self._collect_spots(example, **kwargs) 90 | batch_h2t_spots.append(h2t_spots) 91 | batch_h2h_spots.append(h2h_spots) 92 | batch_t2t_spots.append(t2t_spots) 93 | h2t_tagging = self._encode_head2tail(batch_h2t_spots, index_matrix=index_matrix, sequence_length=flatten_length) 94 | h2h_tagging = self._encode_head2head(batch_h2h_spots, index_matrix=index_matrix, sequence_length=flatten_length) 95 | t2t_tagging = self._encode_tail2tail(batch_t2t_spots, index_matrix=index_matrix, sequence_length=flatten_length) 96 | return h2t_tagging, h2h_tagging, t2t_tagging 97 | 98 | def _collect_spots(self, example, **kwargs): 99 | h2t_spots, h2h_spots, t2t_spots = [], [], [] 100 | # TODO: 考虑不在relation_list中的entity 101 | for relation in example['relation_list']: 102 | subject_span = relation['subj_tok_span'] 103 | object_span = relation['obj_tok_span'] 104 | 105 | # add head-to-tail spot 106 | h2t_i0 = Head2TailItem(p=subject_span[0], q=subject_span[1] - 1, tagid=self.tag_mapping.h2t_id('ENT-H2T')) 107 | h2t_spots.append(h2t_i0) 108 | 
63 |
64 | # p -> point to head of entity, q -> point to tail of entity, tagid -> id of h2t tag
65 | Head2TailItem = namedtuple('Head2TailItem', ['p', 'q', 'tagid'])
66 | # p -> point to head of entity1, q -> point to head of entity2
67 | # relid -> relation_id between entity1 and entity2, tagid -> id of h2h tag
68 | Head2HeadItem = namedtuple('Head2HeadItem', ['relid', 'p', 'q', 'tagid'])
69 | # p -> point to tail of entity1, q -> point to tail of entity2
70 | # relid -> relation_id between entity1 and entity2, tagid -> id of t2t tag
71 | Tail2TailItem = namedtuple('Tail2TailItem', ['relid', 'p', 'q', 'tagid'])
72 |
73 |
74 | class HandshakingTaggingEncoder:
75 |
76 |     def __init__(self, tag_mapping: TagMapping, **kwargs):
77 |         super().__init__()
78 |         self.tag_mapping = tag_mapping
79 |
80 |     def encode(self, example, max_sequence_length=100, **kwargs):
81 |         h2t, h2h, t2t = self.batch_encode([example], max_sequence_length=max_sequence_length, **kwargs)
82 |         return h2t[0], h2h[0], t2t[0]
83 |
84 |     def batch_encode(self, examples, max_sequence_length=100, **kwargs):
85 |         index_matrix = self._build_index_matrix(max_sequence_length=max_sequence_length, **kwargs)
86 |         flatten_length = max_sequence_length * (max_sequence_length + 1) // 2
87 |         batch_h2t_spots, batch_h2h_spots, batch_t2t_spots = [], [], []
88 |         for example in examples:
89 |             h2t_spots, h2h_spots, t2t_spots = self._collect_spots(example, **kwargs)
90 |             batch_h2t_spots.append(h2t_spots)
91 |             batch_h2h_spots.append(h2h_spots)
92 |             batch_t2t_spots.append(t2t_spots)
93 |         h2t_tagging = self._encode_head2tail(batch_h2t_spots, index_matrix=index_matrix, sequence_length=flatten_length)
94 |         h2h_tagging = self._encode_head2head(batch_h2h_spots, index_matrix=index_matrix, sequence_length=flatten_length)
95 |         t2t_tagging = self._encode_tail2tail(batch_t2t_spots, index_matrix=index_matrix, sequence_length=flatten_length)
96 |         return h2t_tagging, h2h_tagging, t2t_tagging
97 |
98 |     def _collect_spots(self, example, **kwargs):
99 |         h2t_spots, h2h_spots, t2t_spots = [], [], []
100 |         # TODO: handle entities that do not appear in relation_list
101 |         for relation in example['relation_list']:
102 |             subject_span = relation['subj_tok_span']
103 |             object_span = relation['obj_tok_span']
104 |
105 |             # add head-to-tail spot
106 |             h2t_i0 = Head2TailItem(p=subject_span[0], q=subject_span[1] - 1, tagid=self.tag_mapping.h2t_id('ENT-H2T'))
107 |             h2t_spots.append(h2t_i0)
108 |             h2t_i1 = Head2TailItem(p=object_span[0], q=object_span[1] - 1, tagid=self.tag_mapping.h2t_id('ENT-H2T'))
109 |             h2t_spots.append(h2t_i1)
110 |             # convert relation to id
111 |             relid = self.tag_mapping.relation_id(relation['predicate'])
112 |             # add head-to-head spot
113 |             p = subject_span[0] if subject_span[0] <= object_span[0] else object_span[0]
114 |             q = object_span[0] if subject_span[0] <= object_span[0] else subject_span[0]
115 |             k = 'REL-SH2OH' if subject_span[0] <= object_span[0] else 'REL-OH2SH'
116 |             h2h_item = Head2HeadItem(relid=relid, p=p, q=q, tagid=self.tag_mapping.h2h_id(k))
117 |             h2h_spots.append(h2h_item)
118 |
119 |             # add tail-to-tail spot
120 |             p = subject_span[1] - 1 if subject_span[1] <= object_span[1] else object_span[1] - 1
121 |             q = object_span[1] - 1 if subject_span[1] <= object_span[1] else subject_span[1] - 1
122 |             k = 'REL-ST2OT' if subject_span[1] <= object_span[1] else 'REL-OT2ST'
123 |             t2t_item = Tail2TailItem(relid=relid, p=p, q=q, tagid=self.tag_mapping.t2t_id(k))
124 |             t2t_spots.append(t2t_item)
125 |
126 |         return h2t_spots, h2h_spots, t2t_spots
127 |
128 |     def _encode_head2tail(self, batch_h2t_spots, index_matrix, sequence_length, **kwargs):
129 |         batch_tagging_sequence = np.zeros([len(batch_h2t_spots), sequence_length], dtype=np.int64)
130 |         for batch_id, h2t_spots in enumerate(batch_h2t_spots):
131 |             for item in h2t_spots:
132 |                 index = index_matrix[item.p][item.q]
133 |                 batch_tagging_sequence[batch_id][index] = item.tagid
134 |         return batch_tagging_sequence
135 |
136 |     def _encode_head2head(self, batch_h2h_spots, index_matrix, sequence_length, **kwargs):
137 |         num_relations = len(self.tag_mapping.relation2id)
138 |         # shape (num_relations, sequence_length)
139 |         batch_tagging_sequence = np.zeros([len(batch_h2h_spots), num_relations, sequence_length], dtype=np.int64)
140 |         for batch_id, h2h_spots in enumerate(batch_h2h_spots):
141 |             for item in h2h_spots:
142 |                 index = index_matrix[item.p][item.q]
143 |                 batch_tagging_sequence[batch_id][item.relid][index] = item.tagid
144 |         return batch_tagging_sequence
145 |
146 |     def _encode_tail2tail(self, batch_t2t_spots, index_matrix, sequence_length, **kwargs):
147 |         num_relations = len(self.tag_mapping.relation2id)
148 |         # shape (num_relations, sequence_length)
149 |         batch_tagging_sequence = np.zeros([len(batch_t2t_spots), num_relations, sequence_length], dtype=np.int64)
150 |         for batch_id, t2t_spots in enumerate(batch_t2t_spots):
151 |             for item in t2t_spots:
152 |                 index = index_matrix[item.p][item.q]
153 |                 batch_tagging_sequence[batch_id][item.relid][index] = item.tagid
154 |         return batch_tagging_sequence
155 |
156 |     def _build_index_matrix(self, max_sequence_length=100, **kwargs):
157 |         # e.g. [(0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2)]
158 |         pairs = [(i, j) for i in range(max_sequence_length) for j in range(i, max_sequence_length)]
159 |         # shape: (max_sequence_length, max_sequence_length)
160 |         matrix = [[0 for _ in range(max_sequence_length)] for _ in range(max_sequence_length)]
161 |         for index, values in enumerate(pairs):
162 |             matrix[values[0]][values[1]] = index
163 |         return matrix
164 |
165 |
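_build_index_matrix maps a token pair (p, q) with p <= q to one slot of the flattened handshaking sequence, laid out row by row over the upper triangle. A worked example with max_sequence_length=3:

max_len = 3
pairs = [(i, j) for i in range(max_len) for j in range(i, max_len)]
assert pairs == [(0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2)]
assert len(pairs) == max_len * (max_len + 1) // 2   # the flatten_length in batch_encode
# A spot with head p=1 and tail q=2 therefore lands at flat index 4
# in every h2t/h2h/t2t tagging sequence.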
166 | class HandshakingTaggingDecoder:
167 |
168 |     def __init__(self, tag_mapping: TagMapping, **kwargs):
169 |         super().__init__()
170 |         self.tag_mapping = tag_mapping
171 |
172 |     def decode(self, example, h2t_pred, h2h_pred, t2t_pred, max_sequence_length=100, **kwargs):
173 |         # e.g. [(0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2)]
174 |         index_matrix = [(i, j) for i in range(max_sequence_length) for j in range(i, max_sequence_length)]
175 |         # decode predictions
176 |         h2t_spots = self._decode_head2tail(h2t_pred, index_matrix)
177 |         h2h_spots = self._decode_head2head(h2h_pred, index_matrix)
178 |         t2t_spots = self._decode_tail2tail(t2t_pred, index_matrix)
179 |
180 |         entities_head_map = self._parse_entities(h2t_spots, example)
181 |         relation_tails = self._parse_tails(t2t_spots)
182 |         relations = self._parse_relations(
183 |             h2h_spots, entities_head_map, relation_tails,
184 |             token_offset=example['token_offset'], char_offset=example['char_offset'])
185 |         return relations
186 |
187 |     def batch_decode(self, examples, batch_h2t_pred, batch_h2h_pred, batch_t2t_pred, max_sequence_length=100, **kwargs):
188 |         batch_relations = []
189 |         for example, h2t_pred, h2h_pred, t2t_pred in zip(examples, batch_h2t_pred, batch_h2h_pred, batch_t2t_pred):
190 |             relations = self.decode(
191 |                 example, h2t_pred, h2h_pred, t2t_pred,
192 |                 max_sequence_length=max_sequence_length, **kwargs)
193 |             batch_relations.append(relations)
194 |         return batch_relations
195 |
196 |     def _decode_head2tail(self, h2t_pred, index_matrix):
197 |         """Decode head2tail tagging.
198 |
199 |         Args:
200 |             h2t_pred: Tensor, shape (1+2+...+seq_len, 2)
201 |             index_matrix: List of indexes
202 |
203 |         Returns:
204 |             items: List of Head2TailItem
205 |         """
206 |         items = []
207 |         # shape: (1+2+...+seq_len)
208 |         h2t_pred = torch.argmax(h2t_pred, dim=-1)
209 |         for index in torch.nonzero(h2t_pred):
210 |             flat_index = index[0].item()
211 |             matrix_ind = index_matrix[flat_index]
212 |             item = Head2TailItem(p=matrix_ind[0], q=matrix_ind[1], tagid=h2t_pred[flat_index].item())
213 |             items.append(item)
214 |         return items
215 |
216 |     def _parse_entities(self, h2t_items, example):
217 |         entities_head_map = {}
218 |         for item in h2t_items:
219 |             if item.tagid != self.tag_mapping.h2t_id('ENT-H2T'):
220 |                 continue
221 |             char_offset_list = example['offset_mapping'][item.p: item.q + 1]
222 |             char_span = [char_offset_list[0][0], char_offset_list[-1][1]]
223 |             entity_txt = example['text'][char_span[0]: char_span[1]]
224 |             head = item.p
225 |             if head not in entities_head_map:
226 |                 entities_head_map[head] = []
227 |             entities_head_map[head].append({
228 |                 'text': entity_txt,
229 |                 'tok_span': [item.p, item.q],
230 |                 'char_span': char_span,
231 |             })
232 |         return entities_head_map
233 |
234 |     def _decode_head2head(self, h2h_pred, index_matrix):
235 |         """Decode head2head predictions.
236 |
237 |         Args:
238 |             h2h_pred: Tensor, shape (num_relations, 1+2+...+seq_len, 3)
239 |             index_matrix: List of indexes
240 |
241 |         Returns:
242 |             items: List of Head2HeadItem
243 |         """
244 |         items = []
245 |         # shape: (num_relations, 1+2+...+seq_len)
246 |         h2h_pred = torch.argmax(h2h_pred, dim=-1)
247 |         for index in torch.nonzero(h2h_pred):
248 |             relation_id, flat_index = index[0].item(), index[1].item()
249 |             matrix_index = index_matrix[flat_index]
250 |             item = Head2HeadItem(
251 |                 relid=relation_id, p=matrix_index[0], q=matrix_index[1],
252 |                 tagid=h2h_pred[relation_id][flat_index].item())
253 |             items.append(item)
254 |         return items
255 |
256 |     def _parse_relations(self, h2h_items, entities_head_map, relation_tails, token_offset=0, char_offset=0, **kwargs):
257 |         relations = []
258 |         for item in h2h_items:
259 |             subj_head, obj_head = None, None
260 |             if item.tagid == self.tag_mapping.h2h_id('REL-SH2OH'):
261 |                 subj_head, obj_head = item.p, item.q
262 |             elif item.tagid == self.tag_mapping.h2h_id('REL-OH2SH'):
263 |                 subj_head, obj_head = item.q, item.p
264 |             if subj_head is None or obj_head is None:  # 0 is a valid head index, so test against None
265 |                 continue
266 |             if subj_head not in entities_head_map or obj_head not in entities_head_map:
267 |                 continue
268 |
269 |             subj_list = entities_head_map[subj_head]
270 |             obj_list = entities_head_map[obj_head]
271 |             for subj in subj_list:
272 |                 for obj in obj_list:
273 |                     tail = '{}-{}-{}'.format(item.relid, subj['tok_span'][1], obj['tok_span'][1])
274 |                     if tail not in relation_tails:
275 |                         continue
276 |                     relations.append({
277 |                         'subject': subj['text'],
278 |                         'object': obj['text'],
279 |                         'subj_tok_span': [subj['tok_span'][0] + token_offset, subj['tok_span'][1] + token_offset + 1],
280 |                         'subj_char_span': [subj['char_span'][0] + char_offset, subj['char_span'][1] + char_offset + 1],
281 |                         'obj_tok_span': [obj['tok_span'][0] + token_offset, obj['tok_span'][1] + token_offset + 1],
282 |                         'obj_char_span': [obj['char_span'][0] + char_offset, obj['char_span'][1] + char_offset + 1],
283 |                         'predicate': self.tag_mapping.relation_tag(item.relid),
284 |                     })
285 |         return relations
286 |
287 |     def _decode_tail2tail(self, t2t_pred, index_matrix):
288 |         """Decode tail2tail predictions.
289 |
290 |         Args:
291 |             t2t_pred: Tensor, shape (num_relations, 1+2+...+seq_len, 3)
292 |             index_matrix: List of indexes
293 |
294 |         Returns:
295 |             items: List of Tail2TailItem
296 |         """
297 |         items = []
298 |         t2t_pred = torch.argmax(t2t_pred, dim=-1)
299 |         for index in torch.nonzero(t2t_pred):
300 |             relation_id, flat_index = index[0].item(), index[1].item()
301 |             matrix_index = index_matrix[flat_index]
302 |             item = Tail2TailItem(
303 |                 relid=relation_id, p=matrix_index[0], q=matrix_index[1],
304 |                 tagid=t2t_pred[relation_id][flat_index].item())
305 |             items.append(item)
306 |         return items
307 |
308 |     def _parse_tails(self, t2t_items):
309 |         tails = set()
310 |         for item in t2t_items:
311 |             if item.tagid == self.tag_mapping.t2t_id('REL-ST2OT'):
312 |                 tails.add('{}-{}-{}'.format(item.relid, item.p, item.q))
313 |             elif item.tagid == self.tag_mapping.t2t_id('REL-OT2ST'):
314 |                 tails.add('{}-{}-{}'.format(item.relid, item.q, item.p))
315 |         return tails
316 |
--------------------------------------------------------------------------------
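An encode/decode round trip over a toy example. The gold tag sequences are turned into one-hot "logits" so the decoder's argmax recovers exactly the encoded tags; the example dict is illustrative, with field names taken from _collect_spots() and decode() above.

import torch
import torch.nn.functional as F

from tplinker.tagging_scheme import (HandshakingTaggingDecoder,
                                     HandshakingTaggingEncoder, TagMapping)

tm = TagMapping({'founded': 0})
example = {
    'text': 'Jobs founded Apple',
    'relation_list': [{'predicate': 'founded',
                       'subj_tok_span': [0, 1], 'obj_tok_span': [2, 3]}],
    'offset_mapping': [[0, 4], [5, 12], [13, 18]],  # char span of each token
    'token_offset': 0,
    'char_offset': 0,
}
h2t, h2h, t2t = HandshakingTaggingEncoder(tm).encode(example, max_sequence_length=3)
relations = HandshakingTaggingDecoder(tm).decode(
    example,
    F.one_hot(torch.tensor(h2t), 2),   # (6, 2)
    F.one_hot(torch.tensor(h2h), 3),   # (1, 6, 3)
    F.one_hot(torch.tensor(t2t), 3),   # (1, 6, 3)
    max_sequence_length=3)
# -> [{'subject': 'Jobs', 'object': 'Apple', 'predicate': 'founded', ...}]

This round trip also exercises a subject head at token 0, which is why _parse_relations must compare heads against None rather than relying on truthiness.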
/tplinker/truncator.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | from transformers import BertTokenizerFast
4 |
5 |
6 | class AbstractExampleTruncator(abc.ABC):
7 |     """Truncate examples whose text is too long. *THIS MAY BREAK RELATIONS!*"""
8 |
9 |     @abc.abstractmethod
10 |     def truncate(self, example, **kwargs):
11 |         raise NotImplementedError()
12 |
13 |     def _adjust_entity_list(self, example, start, end, token_offset, char_offset, **kwargs):
14 |         entity_list = []
15 |         for entity in example['entity_list']:
16 |             # print(entity)
17 |             token_span = entity.get('tok_span', None) or entity.get('token_span', None)
18 |             char_span = entity['char_span']
19 |             if not token_span:
20 |                 continue
21 |             if start > token_span[0] or end < token_span[1]:
22 |                 continue
23 |             entity_list.append({
24 |                 'text': entity['text'],
25 |                 'type': entity['type'],
26 |                 'tok_span': [token_span[0] - token_offset, token_span[1] - token_offset],
27 |                 'char_span': [char_span[0] - char_offset, char_span[1] - char_offset]
28 |             })
29 |         return entity_list
30 |
31 |     def _adjust_relation_list(self, example, start, end, token_offset, char_offset, **kwargs):
32 |         relation_list = []
33 |         for relation in example['relation_list']:
34 |             subj_token_span, obj_token_span = relation['subj_tok_span'], relation['obj_tok_span']
35 |             subj_char_span, obj_char_span = relation['subj_char_span'], relation['obj_char_span']
36 |             if start <= subj_token_span[0] and subj_token_span[1] <= end and start <= obj_token_span[0] and obj_token_span[1] <= end:
37 |                 relation_list.append({
38 |                     'subj_tok_span': [subj_token_span[0] - token_offset, subj_token_span[1] - token_offset],
39 |                     'obj_tok_span': [obj_token_span[0] - token_offset, obj_token_span[1] - token_offset],
40 |                     'subj_char_span': [subj_char_span[0] - char_offset, subj_char_span[1] - char_offset],
41 |                     'obj_char_span': [obj_char_span[0] - char_offset, obj_char_span[1] - char_offset],
42 |                     'subject': relation['subject'],
43 |                     'object': relation['object'],
44 |                     'predicate': relation['predicate'],
45 |                 })
46 |         return relation_list
47 |
48 |     def _adjust_offset_mapping(self, offset_mapping, char_offset, max_sequence_length=100, **kwargs):
49 |         offsets = []
50 |         for start, end in offset_mapping:
51 |             offsets.append([start - char_offset, end - char_offset])
52 |         # pad to max_sequence_length to avoid a DataLoader runtime error
53 |         while len(offsets) < max_sequence_length:
54 |             offsets.append([0, 0])
55 |         return offsets
56 |
57 |
58 | class BertExampleTruncator(AbstractExampleTruncator):
59 |
60 |     def __init__(self, tokenizer: BertTokenizerFast, max_sequence_length=100, window_size=50, **kwargs):
61 |         super().__init__()
62 |         self.max_sequence_length = max_sequence_length
63 |         self.window_size = window_size
64 |         self.tokenizer = tokenizer
65 |
66 |     def truncate(self, example, **kwargs):
67 |         all_examples = []
68 |
69 |         text = example['text']
70 |         codes = self.tokenizer.encode_plus(text, return_offsets_mapping=True, add_special_tokens=False)
71 |         if len(codes['input_ids']) < self.max_sequence_length:
72 |             all_examples.append({
73 |                 'text': text,
74 |                 'entity_list': example['entity_list'],
75 |                 'relation_list': example['relation_list'],
76 |                 'offset_mapping': self._adjust_offset_mapping(codes['offset_mapping'], 0),
77 |                 'token_offset': 0,
78 |                 'char_offset': 0,
79 |             })
80 |             return all_examples
81 |
82 |         tokens = self.tokenizer.convert_ids_to_tokens(codes['input_ids'])
83 |         offset = codes['offset_mapping']
84 |
85 |         for start in range(0, len(tokens), self.window_size):
86 |             # do not truncate in the middle of a word piece
87 |             while str(tokens[start]).startswith('##'):
88 |                 start -= 1
89 |             end = min(start + self.max_sequence_length, len(tokens))
90 |             range_offset_mapping = offset[start: end]
91 |             char_span = [range_offset_mapping[0][0], range_offset_mapping[-1][1]]
92 |             text_subs = text[char_span[0]:char_span[1]]
93 |
94 |             token_offset = start
95 |             char_offset = char_span[0]
96 |
97 |             truncated_example = {
98 |                 'text': text_subs,
99 |                 'entity_list': self._adjust_entity_list(example, start, end, token_offset, char_offset),
100 |                 'relation_list': self._adjust_relation_list(example, start, end, token_offset, char_offset),
101 |                 'token_offset': token_offset,
102 |                 'char_offset': char_offset,
103 |                 'offset_mapping': self._adjust_offset_mapping(range_offset_mapping, char_offset)
104 |             }
105 |             all_examples.append(truncated_example)
106 |
107 |             if end >= len(tokens):  # min() caps end at len(tokens), so a strict '>' would never stop the loop
108 |                 break
109 |
110 |         return all_examples
111 |
--------------------------------------------------------------------------------
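Finally, a sketch of the sliding-window truncation; the tokenizer path mirrors run_tplinker.py and is assumed to point at a local BERT checkpoint, and the input dict is a placeholder in the field layout the adjusters above expect.

from transformers import BertTokenizerFast

from tplinker.truncator import BertExampleTruncator

tokenizer = BertTokenizerFast.from_pretrained('data/bert-base-cased')  # assumed local path
truncator = BertExampleTruncator(tokenizer, max_sequence_length=100, window_size=50)
pieces = truncator.truncate({
    'text': 'a very long document ...',  # placeholder; any text over 100 word pieces gets windowed
    'entity_list': [],                   # entities with 'tok_span'/'char_span' fields
    'relation_list': [],                 # relations with '*_tok_span'/'*_char_span' fields
})
# Each piece records token_offset/char_offset so spans decoded on a window can be
# mapped back to the original text (see HandshakingTaggingDecoder._parse_relations).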