├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── img
│   └── Table-Result.png
├── requirements.txt
├── scripts
│   ├── bert_contrastive.py
│   ├── data_preprocess.py
│   ├── generator.py
│   ├── manager.py
│   ├── oie
│   │   ├── benchmark.py
│   │   ├── evaluate_oie.py
│   │   ├── matcher.py
│   │   └── oie_readers
│   │       ├── extraction.py
│   │       └── goldReader.py
│   ├── post_processing.py
│   ├── post_processing.sh
│   ├── processing.sh
│   ├── ranking.py
│   └── rc
│       ├── dataset_preparation.py
│       ├── eval_FewRel.sh
│       ├── eval_TACRED.sh
│       ├── evaluation.py
│       ├── fewrel_aliases_lemmatized.json
│       ├── fewrel_aliases_unlemmatized.json
│       ├── post_process.py
│       ├── prep_FewRel.sh
│       ├── prep_TACRED.sh
│       ├── string_matcher.py
│       ├── tacred_aliases_lemmatized.json
│       └── tacred_aliases_unlemmatized.json
├── setup.py
├── src
│   └── deepex
│       ├── __init__.py
│       ├── args.py
│       ├── data
│       │   ├── __init__.py
│       │   ├── collator.py
│       │   ├── generator_utils.py
│       │   ├── np.py
│       │   ├── rc.py
│       │   ├── re_data.py
│       │   └── text_handler.py
│       ├── model
│       │   ├── __init__.py
│       │   ├── distillation.py
│       │   ├── eval.py
│       │   └── kgm.py
│       └── utils.py
└── tasks
    ├── FewRel.sh
    ├── NYT.sh
    ├── OIE_2016.sh
    ├── PENN.sh
    ├── TACRED.sh
    ├── WEB.sh
    └── configs
        ├── FewRel.json
        ├── NYT.json
        ├── OIE_2016.json
        ├── PENN.json
        ├── TACRED.json
        └── WEB.json

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__

--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
[submodule "supervised-oie"]
	path = supervised-oie
	url = git@github.com:gabrielStanovsky/supervised-oie.git

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Zero-Shot Information Extraction as a Unified Text-to-Triple Translation 2 | 3 | Source code repo for paper [Zero-Shot Information Extraction as a Unified Text-to-Triple Translation](https://arxiv.org/pdf/2109.11171.pdf), EMNLP 2021. 
## Installation

```bash
git clone --recursive git@github.com:cgraywang/deepex.git
cd ./deepex
conda create --name deepex python=3.7 -y
conda activate deepex
pip install -r requirements.txt
pip install -e .
```

Requires PyTorch 1.5.1 or above with CUDA support; PyTorch 1.7.1 with CUDA 10.1 has been tested. Please refer to https://pytorch.org/get-started/locally/ for instructions on installing PyTorch.

## Dataset Preparation

### Relation Classification

#### FewRel

You can add the `--prepare-rc-dataset` argument when running the scripts in [this section](#scripts-for-reproducing-results), which lets the script automatically handle the preparation of the FewRel dataset.

Alternatively, you can manually download and prepare the FewRel dataset using the following script:

```bash
bash scripts/rc/prep_FewRel.sh
```
The processed data will be stored at `data/FewRel/data.jsonl`.

#### TACRED

TACRED is distributed by the LDC; please first download the TACRED dataset from [this link](https://catalog.ldc.upenn.edu/LDC2018T24). The downloaded file should be named `tacred_LDC2018T24.tgz`.

**After downloading and correctly naming the TACRED `.tgz` data file**, you can add the `--prepare-rc-dataset` argument when running the scripts in [this section](#scripts-for-reproducing-results), which lets the script automatically handle the preparation of the TACRED dataset.

Alternatively, you can manually prepare the TACRED dataset using the following script:

```bash
bash scripts/rc/prep_TACRED.sh
```
The processed data will be stored at `data/TACRED/data.jsonl`.

## Scripts for Reproducing Results

This section contains the scripts for running the tasks with the default settings (e.g., the `bert-large-cased` model on 8 CUDA devices with a per-device batch size of 4).

To modify the settings, please check out [this section](#arguments).

### Open Information Extraction
```bash
bash tasks/OIE_2016.sh
```
```bash
bash tasks/PENN.sh
```
```bash
bash tasks/WEB.sh
```
```bash
bash tasks/NYT.sh
```

### Relation Classification
```bash
bash tasks/FewRel.sh
```
```bash
bash tasks/TACRED.sh
```

## Arguments
General script:
```bash
python scripts/manager.py --task=<task_name>
```

The default setting is:
```bash
python scripts/manager.py --task=<task_name> --model="bert-large-cased" --beam-size=6 \
    --max-distance=2048 --batch-size-per-device=4 --stage=0 \
    --cuda=0,1,2,3,4,5,6,7
```

All tasks are already implemented as the above `.sh` files in `tasks/`, using the default arguments.

The following are the most important command-line arguments for the `scripts/manager.py` script:
- `--task`: The task to be run; supported tasks are `OIE_2016`, `WEB`, `NYT`, `PENN`, `FewRel` and `TACRED`.
- `--model`: The pre-trained model used to generate the attention matrices on which beam search is performed; supported models are `bert-base-cased` and `bert-large-cased`.
- `--beam-size`: The beam size during beam search.
- `--max-distance`: The maximum distance allowed between entities during beam search.
- `--batch-size-per-device`: The batch size on a single device.
- `--stage`: Run the task starting from an intermediate stage:
  - `--stage=0`: data preparation and beam search
  - `--stage=1`: post-processing
  - `--stage=2`: ranking
  - `--stage=3`: evaluation
- `--prepare-rc-dataset`: If set, automatically run the relation classification dataset preparation scripts. Note that this argument should be turned on only for the relation classification tasks (i.e., `FewRel` and `TACRED`).
- `--cuda`: Specify the CUDA GPU devices to use.

Run `python scripts/manager.py -h` for the full list; an example invocation with non-default settings is shown below.
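For instance, the following hypothetical invocation (the argument values are illustrative, not settings from the paper) reruns the `FewRel` task with `bert-base-cased` on two GPUs, resuming from the ranking stage:

```bash
# Illustrative only: resume FewRel from stage 2 (ranking) onwards, using
# bert-base-cased on GPUs 0 and 1 with a per-device batch size of 8.
# Stages 0 and 1 must have completed in a previous run of the same task.
python scripts/manager.py --task=FewRel --model="bert-base-cased" \
    --beam-size=6 --max-distance=2048 --batch-size-per-device=8 \
    --stage=2 --cuda=0,1
```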
## Results

![](img/Table-Result.png)

**NOTE**

We obtain the same or improved results compared to those reported in the paper. We will release the code and datasets for the factual probe soon!

## Related Work

We implement an extended version of the beam search algorithm proposed in [Language Models are Open Knowledge Graphs](https://arxiv.org/pdf/2010.11967.pdf) in [src/deepex/model/kgm.py](https://github.com/cgraywang/deepex/blob/main/src/deepex/model/kgm.py).

## Citation

```bibtex
@inproceedings{wang-etal-2021-deepex,
    title = "Zero-Shot Information Extraction as a Unified Text-to-Triple Translation",
    author = "Chenguang Wang and Xiao Liu and Zui Chen and Haoyun Hong and Jie Tang and Dawn Song",
    booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing",
    year = "2021",
    publisher = "Association for Computational Linguistics"
}

@article{wang-etal-2020-language,
    title = "Language Models are Open Knowledge Graphs",
    author = "Chenguang Wang and Xiao Liu and Dawn Song",
    journal = "arXiv preprint arXiv:2010.11967",
    year = "2020"
}
```

--------------------------------------------------------------------------------
/img/Table-Result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-research-lab/deepex/a4a4cf60c96e1bfe3ddc8007498bf5ed783af730/img/Table-Result.png

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
dataclasses
faiss-cpu
filelock
flashtext
jsonlines
nltk==3.2.5
numpy
requests
scipy
spacy==3.0.2
https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz#egg=en_core_web_sm
tqdm
transformers==4.10.3
wget
# supervised-oie-dependency below
certifi
chardet
docopt
idna
regex!=2019.12.17
scikit-learn==0.22.0

--------------------------------------------------------------------------------
/scripts/bert_contrastive.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
import torch.utils.data as Data
from transformers import BertForSequenceClassification, BertConfig, BertTokenizer, BertModel
import faiss
import tqdm

class BERT(nn.Module):
    def __init__(self, model, device='cuda'):
        super(BERT, self).__init__()
        self.device = device
        self.tokenizer = BertTokenizer.from_pretrained(model)
        self.model = BertModel.from_pretrained(model).to(self.device)
        self.criterion = 
nn.CrossEntropyLoss() 18 | 19 | def __get_input_tensors(self, sentences): 20 | 21 | sentences = sentences.split('[SEP]') 22 | 23 | if len(sentences) > 2: 24 | print(sentences) 25 | raise ValueError("BERT accepts maximum two sentences in input for each data point") 26 | 27 | first_tokenized_sentence = self.tokenizer.tokenize(sentences[0]) 28 | first_segment_id = np.zeros(len(first_tokenized_sentence), dtype=int).tolist() 29 | 30 | first_tokenized_sentence.append("[SEP]") 31 | first_segment_id.append(0) 32 | 33 | if len(sentences)>1 : 34 | second_tokenized_sentece = self.tokenizer.tokenize(sentences[1]) 35 | second_segment_id = np.full(len(second_tokenized_sentece),1, dtype=int).tolist() 36 | 37 | second_tokenized_sentece.append("[SEP]") 38 | second_segment_id.append(1) 39 | 40 | tokenized_text = first_tokenized_sentence + second_tokenized_sentece 41 | segments_ids = first_segment_id + second_segment_id 42 | else: 43 | tokenized_text = first_tokenized_sentence 44 | segments_ids = first_segment_id 45 | 46 | tokenized_text.insert(0,"[CLS]") 47 | segments_ids.insert(0,0) 48 | 49 | masked_indices = [] 50 | for i in range(len(tokenized_text)): 51 | token = tokenized_text[i] 52 | if token == "[MASK]": 53 | masked_indices.append(i) 54 | 55 | indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text) 56 | 57 | tokens_tensor = torch.tensor([indexed_tokens]) 58 | segments_tensors = torch.tensor([segments_ids]) 59 | 60 | return tokens_tensor, segments_tensors, masked_indices, tokenized_text 61 | 62 | 63 | def __get_input_tensors_batch(self, sentences_list): 64 | tokens_tensors_list = [] 65 | segments_tensors_list = [] 66 | masked_indices_list = [] 67 | tokenized_text_list = [] 68 | max_tokens = 0 69 | for sentences in sentences_list: 70 | tokens_tensor, segments_tensor, masked_indices, tokenized_text = self.__get_input_tensors(sentences) 71 | tokens_tensors_list.append(tokens_tensor.to(self.device)) 72 | segments_tensors_list.append(segments_tensor.to(self.device)) 73 | masked_indices_list.append(masked_indices) 74 | tokenized_text_list.append(tokenized_text) 75 | if (tokens_tensor.shape[1] > max_tokens): 76 | max_tokens = tokens_tensor.shape[1] 77 | final_tokens_tensor = None 78 | final_segments_tensor = None 79 | final_attention_mask = None 80 | for tokens_tensor, segments_tensor in zip(tokens_tensors_list, segments_tensors_list): 81 | dim_tensor = tokens_tensor.shape[1] 82 | pad_lenght = max_tokens - dim_tensor 83 | attention_tensor = torch.full([1,dim_tensor], 1, dtype= torch.long).to(self.device) 84 | if pad_lenght>0: 85 | pad_1 = torch.full([1,pad_lenght], 0, dtype= torch.long).to(self.device) 86 | pad_2 = torch.full([1,pad_lenght], 0, dtype= torch.long).to(self.device) 87 | attention_pad = torch.full([1,pad_lenght], 0, dtype= torch.long).to(self.device) 88 | tokens_tensor = torch.cat((tokens_tensor,pad_1), dim=1) 89 | segments_tensor = torch.cat((segments_tensor,pad_2), dim=1) 90 | attention_tensor = torch.cat((attention_tensor,attention_pad), dim=1) 91 | if final_tokens_tensor is None: 92 | final_tokens_tensor = tokens_tensor.to(self.device) 93 | final_segments_tensor = segments_tensor.to(self.device) 94 | final_attention_mask = attention_tensor.to(self.device) 95 | else: 96 | final_tokens_tensor = torch.cat((final_tokens_tensor,tokens_tensor), dim=0).to(self.device) 97 | final_segments_tensor = torch.cat((final_segments_tensor,segments_tensor), dim=0).to(self.device) 98 | final_attention_mask = torch.cat((final_attention_mask,attention_tensor), dim=0).to(self.device) 99 | return 
final_tokens_tensor, final_segments_tensor, final_attention_mask, masked_indices_list, tokenized_text_list 100 | 101 | def forward(self, text_triple): 102 | 103 | tokens_tensor, segments_tensor, attention_mask_tensor, masked_indices_list, tokenized_text_list = self.__get_input_tensors_batch(text_triple) 104 | 105 | outputs = self.model( 106 | input_ids=tokens_tensor, 107 | token_type_ids=segments_tensor, 108 | attention_mask=attention_mask_tensor, 109 | output_hidden_states=False 110 | ) 111 | 112 | embeddings = outputs[0] 113 | 114 | embeddings = embeddings.transpose(1,2) 115 | 116 | token_type = segments_tensor 117 | not_padding = (tokens_tensor > 0).int() 118 | 119 | token_type = token_type.float() 120 | not_padding = not_padding.float() 121 | 122 | token_type = token_type.reshape(token_type.shape[0], 1, token_type.shape[1]) 123 | not_padding = not_padding.reshape(not_padding.shape[0], 1, not_padding.shape[1]) 124 | 125 | triple_output = token_type.mul(embeddings).sum(dim=2) 126 | text_output = (1-token_type).mul(not_padding.mul(embeddings)).sum(dim=2) 127 | 128 | text_output = torch.nn.functional.normalize(text_output) 129 | triple_output = torch.nn.functional.normalize(triple_output) 130 | 131 | return text_output, triple_output 132 | 133 | def Reranking(data, MODEL_FOLDER="Magolor/deepex-ranking-model", batch_size=32, device='cuda'): 134 | model = BERT(MODEL_FOLDER).to(device) 135 | model.eval() 136 | with torch.no_grad(): 137 | for (docid,triples) in tqdm.tqdm(list(data.items())): 138 | rerank_triples = []; batched_triples = [] 139 | for i,triple in enumerate(sorted(triples,key=lambda x:x['sentence']),1): 140 | sentence = " ".join(triple['sentence'][13:].split(" ")[:100]) 141 | text_triple = sentence+"[SEP]"+str((triple['subject'],triple['relation'],triple['object'])) 142 | batched_triples.append(text_triple) 143 | if len(batched_triples)==batch_size or i==len(triples): 144 | text_vector, triple_vector = model(batched_triples) 145 | text_vector, triple_vector = text_vector.detach().cpu().numpy(), triple_vector.detach().cpu().numpy() 146 | for j in range(len(batched_triples)): 147 | triple = triples[i-len(batched_triples)+j] 148 | triple['contrastive_dis'] = float(np.linalg.norm(text_vector[j]-triple_vector[j])) 149 | rerank_triples.append(triple) 150 | batched_triples = [] 151 | data[docid] = sorted(rerank_triples,key=lambda x:x['contrastive_dis']) -------------------------------------------------------------------------------- /scripts/data_preprocess.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | 3 | def PreprocessData(META_TASK, RAW_PATH, DATA_PATH): 4 | if META_TASK in ["OIE_2016"]: 5 | Create(DATA_PATH) 6 | data = [] 7 | if os.path.exists(RAW_PATH+"test.txt"): 8 | with open(RAW_PATH+"test.txt","r") as f: 9 | for j,line in enumerate(f): 10 | data.append( 11 | { 12 | "id": str(j+1), 13 | "title": str(j+1), 14 | "text": line[:-1].replace('(',' ').replace(')',' '), 15 | } 16 | ) 17 | SaveJSON(data,DATA_PATH+f"P0.jsonl",jsonl=True) 18 | elif META_TASK in ["FewRel","TACRED"]: 19 | Create(DATA_PATH) 20 | data = LoadJSON(f"data/{META_TASK}/data.jsonl",jsonl=True) 21 | SaveJSON(data,DATA_PATH+"P0.jsonl",jsonl=True) 22 | else: 23 | Create(DATA_PATH); data = [] 24 | with open(RAW_PATH+"{0}.raw".format(META_TASK.lower()),"r") as f: 25 | for j,line in enumerate(f): 26 | data.append( 27 | { 28 | "id": str(j+1), 29 | "title": str(j+1), 30 | "text": line[:-1].replace('(',' ').replace(')',' '), 31 | } 32 | ) 33 | 
SaveJSON(data,DATA_PATH+"P0.jsonl",jsonl=True) 34 | -------------------------------------------------------------------------------- /scripts/generator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import os 4 | from dataclasses import dataclass, field 5 | from typing import Optional 6 | from typing import Callable, Dict, List, Optional, Tuple 7 | import time 8 | import numpy as np 9 | import json 10 | import string 11 | import re 12 | import pickle 13 | 14 | import spacy 15 | from spacy.lang.en import English 16 | from torch.multiprocessing import set_start_method 17 | 18 | from transformers import ( 19 | CONFIG_MAPPING, 20 | MODEL_WITH_LM_HEAD_MAPPING, 21 | AutoConfig, 22 | AutoModelForMaskedLM, 23 | AutoTokenizer, 24 | HfArgumentParser, 25 | PreTrainedTokenizer, 26 | Trainer, 27 | set_seed, 28 | TrainingArguments, 29 | GPT2TokenizerFast, 30 | GPT2Tokenizer 31 | ) 32 | 33 | from transformers.training_args import is_torch_tpu_available 34 | 35 | if is_torch_tpu_available(): 36 | import torch_xla.core.xla_model as xm 37 | import torch_xla.debug.metrics as met 38 | import torch_xla.distributed.parallel_loader as pl 39 | 40 | from deepex.model import predict_and_save_results 41 | from deepex.data import REDataset, default_data_collator, NPMentionGenerator, RCMentionGenerator 42 | from deepex.args import ModelArguments, DataTrainingArguments 43 | from deepex.utils import * 44 | 45 | logger = logging.getLogger(__name__) 46 | logger.setLevel(logging.INFO) 47 | logger.addHandler(logging.StreamHandler()) 48 | 49 | def main(): 50 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 51 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 52 | if not os.path.exists(os.path.split(training_args.output_dir)[0]): 53 | try: 54 | os.mkdir(os.path.split(training_args.output_dir)[0]) 55 | except: 56 | pass 57 | if not os.path.exists(training_args.output_dir): 58 | try: 59 | os.mkdir(training_args.output_dir) 60 | except: 61 | pass 62 | logger.addHandler(logging.FileHandler(os.path.join(training_args.output_dir, 'run.log'))) 63 | 64 | if ( 65 | os.path.exists(training_args.output_dir) 66 | and os.listdir(training_args.output_dir) 67 | and training_args.do_train 68 | and not training_args.overwrite_output_dir 69 | ): 70 | raise ValueError( 71 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." 
72 | ) 73 | 74 | logging.basicConfig( 75 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 76 | datefmt="%m/%d/%Y %H:%M:%S", 77 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 78 | ) 79 | logger.warning( 80 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 81 | training_args.local_rank, 82 | training_args.device, 83 | training_args.n_gpu, 84 | bool(training_args.local_rank != -1), 85 | training_args.fp16, 86 | ) 87 | 88 | set_seed(training_args.seed) 89 | 90 | 91 | if model_args.model_name_or_path: 92 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir) 93 | config.output_attentions = True 94 | else: 95 | config = CONFIG_MAPPING[model_args.model_type]() 96 | logger.warning("You are instantiating a new config instance from scratch.") 97 | 98 | if model_args.tokenizer_name: 99 | if 'gpt2' not in model_args.tokenizer_name: 100 | tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=True) 101 | tokenizer1 = AutoTokenizer.from_pretrained(model_args.tokenizer_name, cache_dir=model_args.cache_dir, 102 | use_fast=True) 103 | else: 104 | tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, cache_dir=model_args.cache_dir, 105 | use_fast=True) 106 | tokenizer1 = AutoTokenizer.from_pretrained(model_args.tokenizer_name, cache_dir=model_args.cache_dir, 107 | use_fast=False) 108 | elif model_args.model_name_or_path: 109 | if 'gpt2' not in model_args.model_name_or_path: 110 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=True) 111 | tokenizer1 = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir, 112 | use_fast=True) 113 | else: 114 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir, 115 | use_fast=True) 116 | tokenizer1 = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir, 117 | use_fast=False) 118 | else: 119 | raise ValueError( 120 | "You are instantiating a new tokenizer from scratch. 
This is not supported, but you can do it from another script, save it," 121 | "and load it from here, using --tokenizer_name" 122 | ) 123 | if isinstance(tokenizer, GPT2TokenizerFast) or isinstance(tokenizer, GPT2Tokenizer): 124 | tokenizer.pad_token = tokenizer.eos_token 125 | tokenizer1.pad_token = tokenizer1.eos_token 126 | if model_args.model_name_or_path: 127 | model = AutoModelForMaskedLM.from_pretrained( 128 | model_args.model_name_or_path 129 | if model_args.local_model_name_or_path is None or model_args.local_model_name_or_path == 'None' 130 | else model_args.local_model_name_or_path, 131 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 132 | config=config, 133 | cache_dir=model_args.cache_dir, 134 | ) 135 | else: 136 | logger.info("Training new model from scratch") 137 | model = AutoModelForMaskedLM.from_config(config) 138 | 139 | model.resize_token_embeddings(len(tokenizer)) 140 | 141 | start_time = time.time() 142 | start_aug_time = time.time() 143 | if data_args.data_aug == 'np': 144 | mention_generator = NPMentionGenerator() 145 | elif data_args.data_aug == 'rc': 146 | mention_generator = RCMentionGenerator(dataset=data_args.data_dir.split('/')[-2]) 147 | else: 148 | raise NotImplementedError 149 | logger.info('time spent on loading data augmentation: {}s'.format(time.time() - start_aug_time)) 150 | 151 | 152 | for f in tqdm(sorted(os.listdir(data_args.data_dir)), desc='Generate dataset and results'): 153 | if not f.endswith('.jsonl'): 154 | continue 155 | index = int(f.split('.jsonl')[0].split('P')[1]) 156 | redataset_processor = REDataset(data_args.data_dir, index, tokenizer, mention_generator, data_args.max_length) 157 | for i, eval_dataset in enumerate( 158 | tqdm(redataset_processor.generate_batched_datasets(), 159 | desc='Generate batch dataset and results')): 160 | res_dir = os.path.join(training_args.output_dir, 161 | "{}_{}_{}_{}_{}_{}".format(index, tokenizer.__class__.__name__, 162 | mention_generator.__class__.__name__, data_args.max_length, 163 | i, training_args.local_rank)) 164 | if os.path.exists(os.path.join(res_dir, "search_res.json")): 165 | logger.info('skip for {}'.format(res_dir)) 166 | continue 167 | start_generation_time = time.time() 168 | trainer = Trainer( 169 | model=model, 170 | args=training_args, 171 | eval_dataset=eval_dataset, 172 | data_collator=default_data_collator 173 | ) 174 | eval_dataloader = trainer.get_eval_dataloader(eval_dataset) 175 | _, res = predict_and_save_results(eval_dataloader, 176 | description="Generate_triplets", 177 | trainer=trainer, 178 | model_args=model_args, 179 | tokenizer=tokenizer1) 180 | logger.info('total producing triplets time: {}s'.format(time.time() - start_generation_time)) 181 | 182 | start_merge_time = time.time() 183 | if not os.path.exists(res_dir): 184 | try: 185 | os.mkdir(res_dir) 186 | except: 187 | pass 188 | _, _, _, search_res = res 189 | json.dump(search_res, open(os.path.join(res_dir, "search_res.json"), 'w')) 190 | logger.info('total dump triplets time: {}s'.format(time.time() - start_merge_time)) 191 | logger.info('total time: {}s'.format(time.time() - start_time)) 192 | 193 | 194 | def _mp_fn(index): 195 | main() 196 | 197 | 198 | if __name__ == "__main__": 199 | main() 200 | -------------------------------------------------------------------------------- /scripts/manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | import subprocess 5 | from deepex.utils import * 6 | from requests 
import get 7 | 8 | def SysCall(command): 9 | subprocess.Popen( 10 | command, 11 | shell=True 12 | ).wait() 13 | 14 | def PreprocessData(META_TASK, RAW_PATH, DATA_PATH): 15 | if META_TASK in ["OIE_2016"]: 16 | Create(DATA_PATH) 17 | for i, t in enumerate(['test', 'dev']): 18 | data = []; file_path = RAW_PATH+f"{t}.txt" 19 | if os.path.exists(file_path): 20 | with open(RAW_PATH+f"{t}.txt","r") as f: 21 | for j,line in enumerate(f): 22 | data.append( 23 | { 24 | "id": str(j+1), 25 | "title": str(j+1), 26 | "text": line[:-1].replace('(',' ').replace(')',' '), 27 | } 28 | ) 29 | elif t=='test': 30 | raise Exception(f"Test data files not found at '{file_path}'!") 31 | SaveJSON(data,DATA_PATH+f"P{i}.jsonl",jsonl=True) 32 | elif META_TASK in ["FewRel","TACRED"]: 33 | Create(DATA_PATH) 34 | data = LoadJSON(f"data/{META_TASK}/data.jsonl",jsonl=True) 35 | SaveJSON(data,DATA_PATH+"P0.jsonl",jsonl=True) 36 | else: 37 | Create(DATA_PATH); data = [] 38 | with open(RAW_PATH+"{0}.raw".format(META_TASK.lower()),"r") as f: 39 | for j,line in enumerate(f): 40 | data.append( 41 | { 42 | "id": str(j+1), 43 | "title": str(j+1), 44 | "text": line[:-1].replace('(',' ').replace(')',' '), 45 | } 46 | ) 47 | SaveJSON(data,DATA_PATH+"P0.jsonl",jsonl=True) 48 | 49 | if __name__=="__main__": 50 | parser = argparse.ArgumentParser(formatter_class = argparse.RawTextHelpFormatter) 51 | parser.add_argument("-t", "--task", dest="task", type=str, default='OIE_2016', 52 | choices=[ 53 | 'OIE_2016', 54 | 'WEB', 55 | 'NYT', 56 | 'PENN', 57 | 'FewRel', 58 | 'TACRED', 59 | ], 60 | help = "The task to be run" 61 | ) 62 | parser.add_argument("-m", "--model", dest="model", type=str, default='bert-large-cased', 63 | choices=[ 64 | 'bert-base-cased', 65 | 'bert-large-cased', 66 | ], 67 | help = "The pre-trained model type to be used for generating attention matrices to perform beam search on" 68 | ) 69 | parser.add_argument("-q", "--beam-size", dest="beam_size", type=int, default=6, 70 | help = "The beam size during beam search" 71 | ) 72 | parser.add_argument("-k", "--max-distance", dest="max_distance", type=int, default=2048, 73 | help = "The maximum distance allowed between entities during beam search" 74 | ) 75 | parser.add_argument("-b", "--batch-size-per-device", dest="batch_size_per_device", type=int, default=4, 76 | help = "The batch size on a single device", 77 | ) 78 | parser.add_argument("-s", "--stage", dest="stage", type=int, default=0, 79 | help = "Run task starting from an intermediate stage:\n0). data preparation and beam-search\n1). post processing\n2). ranking\n3). evaluation", 80 | ) 81 | parser.add_argument("-d", "--debug", dest="debug", action="store_const", const=True, default=False, 82 | help="If true, only the specified stage will be run, otherwise successive stages will be run." 83 | ) 84 | parser.add_argument("-c", "--clean-history", dest="clean_history", action="store_const", const=True, default=False, 85 | help="If true, clean history runs." 86 | ) 87 | parser.add_argument("-p", "--prepare-rc-dataset", action="store_const", const=True, default=False, 88 | help="If true, automatically run the rc dataset preparation scripts." 89 | ) 90 | parser.add_argument("--cuda", dest="cuda", type=str, default="0,1,2,3,4,5,6,7", 91 | help="Specify CUDA gpu devices." 
    )
    args = parser.parse_args(); CUR_DIR = os.getcwd()+"/"
    config = LoadJSON("tasks/configs/"+args.task+".json")
    args.batch_size = args.batch_size_per_device*len(args.cuda.split(','))
    args.task_abbr = config['task_abbr']
    args.task_meta = config['task_meta']
    args.data_dir  = config['data_dir' ]
    args.proc_dir = "output/data/"+args.task_meta+"/"
    args.outp_dir = "output/output/"+args.task_meta+"/"
    args.clss_dir = "output/classified/"+args.task_meta+"/"
    args.beam_mode = 'IE' if config['task_abbr']=='oie' else 'RC'
    args.ner_mode = 'np' if config['task_abbr']=='oie' else 'rc'
    args.part = '0'

    if args.clean_history and args.stage==0 and os.path.exists("output/"):
        shutil.rmtree("output/"); shutil.rmtree("runs/")
    Create("runs/")
    Create("result/")
    Create("output/")
    Create("output/data/")
    Create("output/output/")
    Create("output/classified/")
    Create(args.proc_dir)
    Create(args.outp_dir)
    Create(args.clss_dir)

    if args.stage<=0 and (args.stage==0 or not args.debug):
        if args.prepare_rc_dataset:
            assert (args.task in ['FewRel','TACRED']), ("Only tasks 'FewRel' and 'TACRED' support the `--prepare-rc-dataset` argument!")
            if args.task=='TACRED':
                assert (os.path.exists("./tacred_LDC2018T24.tgz")), ("Please first download the TACRED dataset according to README.md! The downloaded file should be named `tacred_LDC2018T24.tgz`.")
            SysCall(
                "bash scripts/rc/prep_{}.sh".format(args.task)
            )
        PreprocessData(args.task_meta,args.data_dir,args.proc_dir)
        if args.part is not None:
            reserves = set(["P"+i+".jsonl" for i in args.part.split(',')])
            for FILE in os.listdir(args.proc_dir):
                if FILE not in reserves:
                    os.remove(args.proc_dir+FILE)
        SysCall(
            ("bash scripts/processing.sh %s 1266 {1} None {2} {3} {4} {8} fast_unsupervised_bidirectional_beam_search 256 score_len 1 mean sum 1 {7} {6} {5}"%args.cuda)
            .format(args.task_abbr,args.proc_dir,args.outp_dir,args.model.split('-')[0]+" "+args.model,args.ner_mode,args.beam_mode,args.beam_size,args.max_distance,args.batch_size)
        )

    if args.stage<=1 and (args.stage==1 or not args.debug):
        for FOLDER in os.listdir(args.outp_dir):
            model, _, ner_mode, _, _, _, _, _, _, beam_size = FOLDER.split('.')
            if (model==args.model) and (ner_mode==args.ner_mode) and (int(beam_size)==args.beam_size):
                reserves = set([str(i) for i in args.part.split(',')]); classes = set()
                for BATCH in os.listdir(args.outp_dir+FOLDER+"/"):
                    if BATCH != "run.log":
                        part = BATCH.split('_')[0]; batch_folder = args.outp_dir+FOLDER+"/"+BATCH+"/"
                        if part in reserves:
                            classified = args.clss_dir+"P"+part+"/"; Create(classified)
                            SysCall(
                                "cp -r {0} {1}"
                                .format(batch_folder,classified)
                            )
                            classes.add(classified)
                for classified in classes:
                    SysCall(
                        ("bash scripts/post_processing.sh %s 1266 {1} None {2} {3} {4} {8} fast_unsupervised_bidirectional_beam_search 256 score_len 1 mean sum 1 {7} {6} {5}"%args.cuda)
                        .format(args.task_abbr,args.proc_dir,classified,args.model.split('-')[0]+" "+args.model,args.ner_mode,args.beam_mode,args.beam_size,args.max_distance,args.batch_size)
                    )
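    # Note (editorial): stage 2 below produces two result directories. The
    # "<RESULT>.unsort" output appears to be ranked by the raw beam-search
    # `score`, while "<RESULT>.sorted" is ranked by the contrastive re-ranking
    # distance `contrastive_dis` (see how scripts/oie/evaluate_oie.py picks the
    # confidence field based on whether the directory name ends in ".unsort/").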
"python3 scripts/ranking.py -proc_dir {0} -clss_dir {1} -dest {2}".format(args.proc_dir,args.clss_dir,RESULT+".unsort") 162 | ) 163 | SysCall( 164 | "python3 scripts/ranking.py -proc_dir {0} -clss_dir {1} -dest {2}".format(args.proc_dir,args.clss_dir,RESULT+".sorted") 165 | ) 166 | 167 | if args.stage<=3 and (args.stage==3 or not args.debug) and (args.task_meta in ['OIE_2016','WEB','NYT','PENN']): 168 | RESULT = "result/" + ".".join([args.task,args.model,args.ner_mode,f"d{args.max_distance}",f"b{args.beam_size}"]) 169 | SysCall( 170 | "python3 scripts/oie/evaluate_oie.py -dir {0} -task {1}".format(RESULT+".sorted/",args.task) 171 | ) 172 | 173 | if args.stage<=3 and (args.stage==3 or not args.debug) and (args.task_meta in ['FewRel', 'TACRED']): 174 | RESULT = "result/" + ".".join([args.task,args.model,args.ner_mode,f"d{args.max_distance}",f"b{args.beam_size}"]); Create("scripts/rc/data/") 175 | SysCall( 176 | "cp {0}P0_data.jsonl scripts/rc/data/{1}_data.jsonl".format(RESULT+".sorted/",args.task_meta.lower()) 177 | + "&& cp {0}P0_result.json scripts/rc/data/{1}_result.json".format(RESULT+".sorted/",args.task_meta.lower()) 178 | + "&& bash scripts/rc/eval_" + args.task_meta + ".sh" 179 | ) 180 | -------------------------------------------------------------------------------- /scripts/oie/benchmark.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Usage: 3 | benchmark --gold=GOLD_OIE --out=OUTPUT_FILE (--stanford=STANFORD_OIE | --ollie=OLLIE_OIE |--reverb=REVERB_OIE | --clausie=CLAUSIE_OIE | --openiefour=OPENIEFOUR_OIE | --props=PROPS_OIE | --tabbed=TABBED_OIE) [--exactMatch | --predMatch | --argMatch] [--error-file=ERROR_FILE] 4 | 5 | Options: 6 | --gold=GOLD_OIE The gold reference Open IE file (by default, it should be under ./oie_corpus/all.oie). 7 | --out-OUTPUT_FILE The output file, into which the precision recall curve will be written. 8 | --clausie=CLAUSIE_OIE Read ClausIE format from file CLAUSIE_OIE. 9 | --ollie=OLLIE_OIE Read OLLIE format from file OLLIE_OIE. 10 | --openiefour=OPENIEFOUR_OIE Read Open IE 4 format from file OPENIEFOUR_OIE. 11 | --props=PROPS_OIE Read PropS format from file PROPS_OIE 12 | --reverb=REVERB_OIE Read ReVerb format from file REVERB_OIE 13 | --stanford=STANFORD_OIE Read Stanford format from file STANFORD_OIE 14 | --tabbed=TABBED_OIE Read simple tab format file, where each line consists of: 15 | sent, prob, pred,arg1, arg2, ... 16 | --exactmatch Use exact match when judging whether an extraction is correct. 
17 | ''' 18 | import docopt 19 | import string 20 | import numpy as np 21 | from sklearn.metrics import precision_recall_curve 22 | from sklearn.metrics import auc 23 | import re 24 | import logging 25 | import pdb 26 | logging.basicConfig(level = logging.INFO) 27 | 28 | from oie_readers.stanfordReader import StanfordReader 29 | from oie_readers.ollieReader import OllieReader 30 | from oie_readers.reVerbReader import ReVerbReader 31 | from oie_readers.clausieReader import ClausieReader 32 | from oie_readers.openieFourReader import OpenieFourReader 33 | from oie_readers.propsReader import PropSReader 34 | from oie_readers.tabReader import TabReader 35 | 36 | from oie_readers.goldReader import GoldReader 37 | from matcher import Matcher 38 | from operator import itemgetter 39 | 40 | class Benchmark: 41 | ''' Compare the gold OIE dataset against a predicted equivalent ''' 42 | def __init__(self, gold_fn): 43 | ''' Load gold Open IE, this will serve to compare against using the compare function ''' 44 | gr = GoldReader() 45 | gr.read(gold_fn) 46 | self.gold = gr.oie 47 | 48 | def compare(self, predicted, matchingFunc, output_fn, error_file = None): 49 | ''' Compare gold against predicted using a specified matching function. 50 | Outputs PR curve to output_fn ''' 51 | 52 | y_true = [] 53 | y_scores = [] 54 | errors = [] 55 | 56 | correctTotal = 0 57 | unmatchedCount = 0 58 | predicted = Benchmark.normalizeDict(predicted) 59 | gold = Benchmark.normalizeDict(self.gold) 60 | 61 | for sent, goldExtractions in gold.items(): 62 | if sent not in predicted: 63 | for goldEx in goldExtractions: 64 | unmatchedCount += len(goldExtractions) 65 | correctTotal += len(goldExtractions) 66 | continue 67 | 68 | predictedExtractions = predicted[sent] 69 | 70 | for goldEx in goldExtractions: 71 | correctTotal += 1 72 | found = False 73 | 74 | for predictedEx in predictedExtractions: 75 | if output_fn in predictedEx.matched: 76 | continue 77 | 78 | if matchingFunc(goldEx, 79 | predictedEx, 80 | ignoreStopwords = True, 81 | ignoreCase = True): 82 | 83 | y_true.append(1) 84 | y_scores.append(predictedEx.confidence) 85 | predictedEx.matched.append(output_fn) 86 | found = True 87 | break 88 | 89 | if not found: 90 | errors.append(goldEx.index) 91 | unmatchedCount += 1 92 | 93 | for predictedEx in [x for x in predictedExtractions if (output_fn not in x.matched)]: 94 | y_true.append(0) 95 | y_scores.append(predictedEx.confidence) 96 | 97 | y_true = y_true 98 | y_scores = y_scores 99 | 100 | (p, r), optimal = Benchmark.prCurve(np.array(y_true), np.array(y_scores), 101 | recallMultiplier = ((correctTotal - unmatchedCount)/float(correctTotal))) 102 | logging.info("AUC: {}\n Optimal (precision, recall, F1, threshold): {}".format(auc(r, p), 103 | optimal)) 104 | 105 | if error_file: 106 | logging.info("Writing {} error indices to {}".format(len(errors), 107 | error_file)) 108 | with open(error_file, 'w') as fout: 109 | fout.write('\n'.join([str(error) 110 | for error 111 | in errors]) + '\n') 112 | 113 | with open(output_fn, 'w') as fout: 114 | fout.write('{0}\t{1}\n'.format("Precision", "Recall")) 115 | for cur_p, cur_r in sorted(zip(p, r), key = lambda x: x[1]): 116 | fout.write('{0}\t{1}\n'.format(cur_p, cur_r)) 117 | 118 | @staticmethod 119 | def prCurve(y_true, y_scores, recallMultiplier): 120 | y_scores = [score \ 121 | if not (np.isnan(score) or (not np.isfinite(score))) \ 122 | else 0 123 | for score in y_scores] 124 | 125 | precision_ls, recall_ls, thresholds = precision_recall_curve(y_true, y_scores) 126 | recall_ls 
= recall_ls * recallMultiplier 127 | optimal = max([(precision, recall, f_beta(precision, recall, beta = 1), threshold) 128 | for ((precision, recall), threshold) 129 | in zip(zip(precision_ls[:-1], recall_ls[:-1]), 130 | thresholds)], 131 | key = itemgetter(2)) 132 | 133 | return ((precision_ls, recall_ls), 134 | optimal) 135 | 136 | @staticmethod 137 | def normalizeDict(d): 138 | return dict([(Benchmark.normalizeKey(k), v) for k, v in d.items()]) 139 | 140 | @staticmethod 141 | def normalizeKey(k): 142 | return Benchmark.removePunct(str(Benchmark.PTB_unescape(k.replace(' ','')))) 143 | 144 | @staticmethod 145 | def PTB_escape(s): 146 | for u, e in Benchmark.PTB_ESCAPES: 147 | s = s.replace(u, e) 148 | return s 149 | 150 | @staticmethod 151 | def PTB_unescape(s): 152 | for u, e in Benchmark.PTB_ESCAPES: 153 | s = s.replace(e, u) 154 | return s 155 | 156 | @staticmethod 157 | def removePunct(s): 158 | return Benchmark.regex.sub('', s) 159 | 160 | regex = re.compile('[%s]' % re.escape(string.punctuation)) 161 | 162 | PTB_ESCAPES = [('(', '-LRB-'), 163 | (')', '-RRB-'), 164 | ('[', '-LSB-'), 165 | (']', '-RSB-'), 166 | ('{', '-LCB-'), 167 | ('}', '-RCB-'),] 168 | 169 | 170 | def f_beta(precision, recall, beta = 1): 171 | beta = float(beta) 172 | return (1 + pow(beta, 2)) * (precision * recall) / ((pow(beta, 2) * precision) + recall) 173 | 174 | 175 | f1 = lambda precision, recall: f_beta(precision, recall, beta = 1) 176 | 177 | 178 | 179 | 180 | if __name__ == '__main__': 181 | args = docopt.docopt(__doc__) 182 | logging.debug(args) 183 | 184 | if args['--stanford']: 185 | predicted = StanfordReader() 186 | predicted.read(args['--stanford']) 187 | 188 | if args['--props']: 189 | predicted = PropSReader() 190 | predicted.read(args['--props']) 191 | 192 | if args['--ollie']: 193 | predicted = OllieReader() 194 | predicted.read(args['--ollie']) 195 | 196 | if args['--reverb']: 197 | predicted = ReVerbReader() 198 | predicted.read(args['--reverb']) 199 | 200 | if args['--clausie']: 201 | predicted = ClausieReader() 202 | predicted.read(args['--clausie']) 203 | 204 | if args['--openiefour']: 205 | predicted = OpenieFourReader() 206 | predicted.read(args['--openiefour']) 207 | 208 | if args['--tabbed']: 209 | predicted = TabReader() 210 | predicted.read(args['--tabbed']) 211 | 212 | if args['--exactMatch']: 213 | matchingFunc = Matcher.argMatch 214 | 215 | elif args['--predMatch']: 216 | matchingFunc = Matcher.predMatch 217 | 218 | elif args['--argMatch']: 219 | matchingFunc = Matcher.argMatch 220 | 221 | else: 222 | matchingFunc = Matcher.lexicalMatch 223 | 224 | b = Benchmark(args['--gold']) 225 | out_filename = args['--out'] 226 | 227 | logging.info("Writing PR curve of {} to {}".format("DeepEx", out_filename)) 228 | b.compare(predicted = predicted.oie, 229 | matchingFunc = matchingFunc, 230 | output_fn = out_filename, 231 | error_file = args["--error-file"]) 232 | -------------------------------------------------------------------------------- /scripts/oie/evaluate_oie.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from deepex.utils import * 3 | 4 | def SysCall(command): 5 | subprocess.Popen( 6 | command, 7 | shell=True 8 | ).wait() 9 | 10 | def TopK(Ks, task): 11 | for k in Ks: 12 | with open(f"supervised-oie/supervised-oie-benchmark/systems_output/deepex.{task.lower()}.{k}.txt","w") as W: 13 | c = 0 14 | with open(f"supervised-oie/supervised-oie-benchmark/systems_output/deepex.{task.lower()}.txt","r") as R: 15 | for line in R: 16 
                    data = line.strip().split('\t')
                    if len(data) == 1:
                        c = 0; W.write(line)
                    elif len(data) == 5 and c < k:
                        c += 1; W.write(line)

def BuildEvaluationScript(Ks, task):
    config = LoadJSON(f"tasks/configs/{task}.json")
    with open(f"supervised-oie/supervised-oie-benchmark/evaluate.{task.lower()}.sh","w") as W:
        W.write(
            """
mkdir -p ./eval_data/
mkdir -p ./eval_log/
mkdir -p ./eval_data/{0}/
mkdir -p ./eval_log/{0}/
""".format(task)
        )

        for k in Ks:
            W.write(
                """
python3 benchmark.py --gold={0} --out={1} --clausie={2}
echo "{4}"
""".format(config['gold'],
           f"eval_data/{task}/deepex.{task.lower()}.{k}.dat",
           f"systems_output/deepex.{task.lower()}.{k}.txt",
           f"eval_log/{task}/deepex.{task.lower()}.{k}.log",
           f"{task} (top {k})",
          )
            )

if __name__=="__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-dir', dest='dir', type=str)
    parser.add_argument('-task', dest='task', type=str)
    args = parser.parse_args()
    reserves = set([FILE.split('_')[0] for FILE in os.listdir(args.dir)])
    assert(len(reserves)==1)
    for part in reserves:
        result = LoadJSON(args.dir+f"{part}_result.json")
        data = LoadJSON(args.dir+f"{part}_data.jsonl",jsonl=True)
        with open(f"supervised-oie/supervised-oie-benchmark/systems_output/deepex.{args.task.lower()}.txt","w") as f:
            for ID,sentence in enumerate(data,1):
                strID = '0' * (40 - len(str(ID))) + str(ID)
                f.write(sentence['text']+"\n")
                if strID in result.keys():
                    for triple in result[strID]:
                        f.write(
                            str(ID)+'\t'+
                            ('"'+sentence['text'][triple['subject_char_span'][0]:triple['subject_char_span'][1]]+'"')+'\t'+
                            ('"'+triple['relation']+'"')+'\t'+
                            ('"'+sentence['text'][triple['object_char_span'][0]:triple['object_char_span'][1]]+'"')+'\t'+
                            str(triple['score'] if args.dir.endswith(".unsort/") else -triple['contrastive_dis'])+'\n'
                        )
    K = [3] if args.task=='OIE_2016' else [1]
    TopK(K,task=args.task)
    BuildEvaluationScript(K,task=args.task)
    SysCall(
        f"cp -rf scripts/oie/* supervised-oie/supervised-oie-benchmark/"
    )
    SysCall(
        f"cd supervised-oie/supervised-oie-benchmark/ && bash evaluate.{args.task.lower()}.sh"
    )

--------------------------------------------------------------------------------
/scripts/oie/matcher.py:
--------------------------------------------------------------------------------
import string
from nltk.translate.bleu_score import sentence_bleu
import nltk
from nltk.corpus import stopwords
try:
    # download the stopwords corpus on first use if it is not already present
    stopwords.words('english')
except LookupError:
    nltk.download('stopwords')

class Matcher:
    @staticmethod
    def bowMatch(ref, ex, ignoreStopwords, ignoreCase):
        """
        A binary function testing for exact lexical match (ignoring ordering) between reference
        and predicted extraction
        """
        s1 = ref.bow()
        s2 = ex.bow()
        if ignoreCase:
            s1 = s1.lower()
            s2 = s2.lower()

        s1Words = s1.split(' ')
        s2Words = s2.split(' ')

        if ignoreStopwords:
            s1Words = Matcher.removeStopwords(s1Words)
            s2Words = Matcher.removeStopwords(s2Words)

        return sorted(s1Words) == sorted(s2Words)

    @staticmethod
    def predMatch(ref, ex, ignoreStopwords, ignoreCase):
        """
        Return whether gold and predicted extractions agree on the predicate
        """
        s1 = ref.elementToStr(ref.pred)
        s2 = 
ex.elementToStr(ex.pred) 39 | if ignoreCase: 40 | s1 = s1.lower() 41 | s2 = s2.lower() 42 | 43 | s1Words = s1.split(' ') 44 | s2Words = s2.split(' ') 45 | 46 | if ignoreStopwords: 47 | s1Words = Matcher.removeStopwords(s1Words) 48 | s2Words = Matcher.removeStopwords(s2Words) 49 | 50 | return s1Words == s2Words 51 | 52 | 53 | @staticmethod 54 | def argMatch(ref, ex, ignoreStopwords, ignoreCase): 55 | """ 56 | Return whehter gold and predicted extractions agree on the arguments 57 | """ 58 | sRef = ' '.join([ref.elementToStr(elem) for elem in ref.args]) 59 | sEx = ' '.join([ex.elementToStr(elem) for elem in ex.args]) 60 | 61 | count = 0 62 | 63 | for w1 in sRef: 64 | for w2 in sEx: 65 | if w1 == w2: 66 | count += 1 67 | 68 | # We check how well does the extraction lexically cover the reference 69 | # Note: this is somewhat lenient as it doesn't penalize the extraction for 70 | # being too long 71 | coverage = float(count) / len(sRef) 72 | 73 | 74 | return coverage > Matcher.LEXICAL_THRESHOLD 75 | 76 | @staticmethod 77 | def bleuMatch(ref, ex, ignoreStopwords, ignoreCase): 78 | sRef = ref.bow() 79 | sEx = ex.bow() 80 | bleu = sentence_bleu(references = [sRef.split(' ')], hypothesis = sEx.split(' ')) 81 | return bleu > Matcher.BLEU_THRESHOLD 82 | 83 | @staticmethod 84 | def lexicalMatch(ref, ex, ignoreStopwords, ignoreCase): 85 | sRef = ref.bow().split(' ') 86 | sEx = ex.bow().split(' ') 87 | count = 0 88 | 89 | for w1 in sRef: 90 | for w2 in sEx: 91 | if w1 == w2: 92 | count += 1 93 | 94 | # We check how well does the extraction lexically cover the reference 95 | # Note: this is somewhat lenient as it doesn't penalize the extraction for 96 | # being too long 97 | coverage = float(count) / len(sRef) 98 | 99 | 100 | return coverage > Matcher.LEXICAL_THRESHOLD 101 | 102 | @staticmethod 103 | def removeStopwords(ls): 104 | return [w for w in ls if w.lower() not in Matcher.stopwords] 105 | 106 | # CONSTANTS 107 | BLEU_THRESHOLD = 0.4 108 | LEXICAL_THRESHOLD = 0.5 # Note: changing this value didn't change the ordering of the tested systems 109 | stopwords = stopwords.words('english') + list(string.punctuation) 110 | -------------------------------------------------------------------------------- /scripts/oie/oie_readers/extraction.py: -------------------------------------------------------------------------------- 1 | from sklearn.preprocessing.data import binarize 2 | from oie_readers.argument import Argument 3 | from operator import itemgetter 4 | from collections import defaultdict 5 | import nltk 6 | import itertools 7 | import logging 8 | import numpy as np 9 | import pdb 10 | 11 | class Extraction: 12 | def __init__(self, pred, head_pred_index, sent, confidence, question_dist = '', index = -1): 13 | self.pred = pred 14 | self.head_pred_index = head_pred_index 15 | self.sent = sent 16 | self.args = [] 17 | self.confidence = confidence 18 | self.matched = [] 19 | self.questions = {} 20 | self.indsForQuestions = defaultdict(lambda: set()) 21 | self.is_mwp = False 22 | self.question_dist = question_dist 23 | self.index = index 24 | 25 | def distArgFromPred(self, arg): 26 | assert(len(self.pred) == 2) 27 | dists = [] 28 | for x in self.pred[1]: 29 | for y in arg.indices: 30 | dists.append(abs(x - y)) 31 | 32 | return min(dists) 33 | 34 | def argsByDistFromPred(self, question): 35 | return sorted(self.questions[question], key = lambda arg: self.distArgFromPred(arg)) 36 | 37 | def addArg(self, arg, question = None): 38 | self.args.append(arg) 39 | if question: 40 | self.questions[question] = 
/scripts/oie/oie_readers/extraction.py:
--------------------------------------------------------------------------------
1 | from oie_readers.argument import Argument
2 | from operator import itemgetter
3 | from collections import defaultdict
4 | import nltk
5 | import itertools
6 | import logging
7 | import numpy as np
8 |
9 | class Extraction:
10 |     def __init__(self, pred, head_pred_index, sent, confidence, question_dist = '', index = -1):
11 |         self.pred = pred
12 |         self.head_pred_index = head_pred_index
13 |         self.sent = sent
14 |         self.args = []
15 |         self.confidence = confidence
16 |         self.matched = []
17 |         self.questions = {}
18 |         self.indsForQuestions = defaultdict(lambda: set())
19 |         self.is_mwp = False
20 |         self.question_dist = question_dist
21 |         self.index = index
22 |
23 |     def distArgFromPred(self, arg):
24 |         assert(len(self.pred) == 2)
25 |         dists = []
26 |         for x in self.pred[1]:
27 |             for y in arg.indices:
28 |                 dists.append(abs(x - y))
29 |
30 |         return min(dists)
31 |
32 |     def argsByDistFromPred(self, question):
33 |         return sorted(self.questions[question], key = lambda arg: self.distArgFromPred(arg))
34 |
35 |     def addArg(self, arg, question = None):
36 |         self.args.append(arg)
37 |         if question:
38 |             self.questions[question] = self.questions.get(question,[]) + [Argument(arg)]
39 |
40 |     def noPronounArgs(self):
41 |         for (a, _) in self.args:
42 |             tokenized_arg = nltk.word_tokenize(a)
43 |             if len(tokenized_arg) == 1:
44 |                 _, pos_tag = nltk.pos_tag(tokenized_arg)[0]
45 |                 if ('PRP' in pos_tag):
46 |                     return False
47 |         return True
48 |
49 |     def isContiguous(self):
50 |         return all([indices for (_, indices) in self.args])
51 |
52 |     def toBinary(self):
53 |         ''' Try to represent this extraction's arguments as binary
54 |         If fails, this function will return an empty list. '''
55 |
56 |         ret = [self.elementToStr(self.pred)]
57 |
58 |         if len(self.args) == 2:
59 |             return ret + [self.elementToStr(arg) for arg in self.args]
60 |
61 |         return []
62 |
63 |         # NOTE: the fallback below is unreachable because of the early return
64 |         # above; it is kept verbatim from the upstream benchmark.
65 |         if not self.isContiguous():
66 |             return []
67 |
68 |         binarized = self.binarizeByIndex()
69 |
70 |         if binarized:
71 |             return ret + binarized
72 |
73 |         return []
74 |
75 |
76 |     def elementToStr(self, elem, print_indices = True):
77 |         ''' formats an extraction element (pred or arg) as a raw string
78 |         removes indices and trailing spaces '''
79 |         if print_indices:
80 |             return str(elem)
81 |         if isinstance(elem, str):
82 |             return elem
83 |         if isinstance(elem, tuple):
84 |             ret = elem[0].rstrip().lstrip()
85 |         else:
86 |             ret = ' '.join(elem.words)
87 |         assert ret, "empty element? {0}".format(elem)
88 |         return ret
89 |
90 |     def binarizeByIndex(self):
91 |         extraction = [self.pred] + self.args
92 |         markPred = [(w, ind, i == 0) for i, (w, ind) in enumerate(extraction)]
93 |         # the key was a Python 2 tuple-parameter lambda; index explicitly for Python 3
94 |         sortedExtraction = sorted(markPred, key = lambda elem: elem[1][0])
95 |         s = ' '.join(['{1} {0} {1}'.format(self.elementToStr(elem), SEP) if elem[2] else self.elementToStr(elem) for elem in sortedExtraction])
96 |         binArgs = [a for a in s.split(SEP) if a.rstrip().lstrip()]
97 |
98 |         if len(binArgs) == 2:
99 |             return binArgs
100 |
101 |         return []
102 |
103 |     def bow(self):
104 |         return ' '.join([self.elementToStr(elem) for elem in [self.pred] + self.args])
105 |
106 |     def getSortedArgs(self):
107 |         if self.question_dist:
108 |             return self.sort_args_by_distribution()
109 |         ls = []
110 |         for q, args in self.questions.items():
111 |             if (len(args) != 1):
112 |                 logging.debug("Not one argument: {}".format(args))
113 |                 continue
114 |             arg = args[0]
115 |             indices = list(self.indsForQuestions[q].union(arg.indices))
116 |             if not indices:
117 |                 logging.debug("Empty indexes for arg {} -- backing to zero".format(arg))
118 |                 indices = [0]
119 |             ls.append(((arg, q), indices))
120 |         return [a for a, _ in sorted(ls,
121 |                                      key = lambda pair: min(pair[1]))]
122 |
123 |     def question_prob_for_loc(self, question, loc):
124 |         gen_question = generalize_question(question)
125 |         q_dist = self.question_dist[gen_question]
126 |         logging.debug("distribution of {}: {}".format(gen_question,
127 |                                                       q_dist))
128 |
129 |         return float(q_dist.get(loc, 0)) / \
130 |             sum(q_dist.values())
131 |
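    # Worked example for question_prob_for_loc (hypothetical counts): if the
    # generalized question maps to q_dist = {0: 8, 2: 2}, then the probability
    # of this question's argument appearing at location 0 is 8 / (8 + 2) = 0.8.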
132 |     def sort_args_by_distribution(self):
133 |         INF_LOC = 100
134 |
135 |         ret = {INF_LOC: []}
136 |         logging.debug("sorting: {}".format(self.questions))
137 |
138 |         logging.debug("probs for subject: {}".format([(q, self.question_prob_for_loc(q, 0))
139 |                                                       for (q, _) in self.questions.items()]))
140 |
141 |         subj_question, subj_args = max(self.questions.items(),
142 |                                        key = lambda q: self.question_prob_for_loc(q[0], 0))
143 |
144 |         ret[0] = [(subj_args[0], subj_question)]
145 |
146 |         for (question, args) in sorted([(q, a)
147 |                                         for (q, a) in self.questions.items() if (q not in [subj_question])],
148 |                                        key = lambda q: \
149 |                                        sum(self.question_dist[generalize_question(q[0])].values()),
150 |                                        reverse = True):
151 |             gen_question = generalize_question(question)
152 |             arg = args[0]
153 |             assigned_flag = False
154 |             for (loc, count) in sorted(self.question_dist[gen_question].items(),
155 |                                        key = lambda c: c[1],
156 |                                        reverse = True):
157 |                 if loc not in ret:
158 |                     ret[loc] = [(arg, question)]
159 |                     assigned_flag = True
160 |                     break
161 |
162 |             if not assigned_flag:
163 |                 logging.debug("Couldn't find an open assignment for {}".format((arg, gen_question)))
164 |                 ret[INF_LOC].append((arg, question))
165 |
166 |         logging.debug("Linearizing arg list: {}".format(ret))
167 |
168 |         # the key was a Python 2 tuple-parameter lambda; index explicitly for Python 3
169 |         return [arg
170 |                 for (_, arg_ls) in sorted(ret.items(),
171 |                                           key = lambda kv: int(kv[0]))
172 |                 for arg in arg_ls]
173 |
174 |
175 |     # NOTE: this __str__ is shadowed by the second definition near the end of
176 |     # the class; kept verbatim from the upstream benchmark.
177 |     def __str__(self):
178 |         pred_str = self.elementToStr(self.pred)
179 |         return '{}\t{}\t{}'.format(self.get_base_verb(pred_str),
180 |                                    self.compute_global_pred(pred_str,
181 |                                                             self.questions.keys()),
182 |                                    '\t'.join([escape_special_chars(self.augment_arg_with_question(self.elementToStr(arg),
183 |                                                                                                   question))
184 |                                               for arg, question in self.getSortedArgs()]))
185 |
186 |     def get_base_verb(self, surface_pred):
187 |         return surface_pred.split(' ')[-1]
188 |
189 |
190 |     def compute_global_pred(self, surface_pred, questions):
191 |         from operator import itemgetter
192 |         split_surface = surface_pred.split(' ')
193 |
194 |         if len(split_surface) > 1:
195 |             verb = split_surface[-1]
196 |             ret = split_surface[:-1]
197 |         else:
198 |             verb = split_surface[0]
199 |             ret = []
200 |
201 |         # materialize the maps: under Python 3 a bare map() is a one-shot
202 |         # iterator, which would silently leave pps and obj2s empty below
203 |         split_questions = list(map(lambda question: question.split(' '),
204 |                                    questions))
205 |
206 |         preds = list(map(normalize_element,
207 |                          map(itemgetter(QUESTION_TRG_INDEX),
208 |                              split_questions)))
209 |         if len(set(preds)) > 1:
210 |             ret.append(verb)
211 |
212 |         if len(set(preds)) == 1:
213 |             ret.append(preds[0])
214 |
215 |         pps = list(map(normalize_element,
216 |                        map(itemgetter(QUESTION_PP_INDEX),
217 |                            split_questions)))
218 |
219 |         obj2s = list(map(normalize_element,
220 |                          map(itemgetter(QUESTION_OBJ2_INDEX),
221 |                              split_questions)))
222 |
223 |         if (len(set(pps)) == 1):
224 |             self.is_mwp = True
225 |             ret.append(pps[0])
226 |
227 |         return " ".join(ret).strip()
228 |
229 |
230 |     def augment_arg_with_question(self, arg, question):
231 |         wh, aux, sbj, trg, obj1, pp, obj2 = map(normalize_element,
232 |                                                 question.split(' ')[:-1])
233 |
234 |         if (not self.is_mwp) and pp and (not obj2):
235 |             if not(arg.startswith("{} ".format(pp))):
236 |                 return " ".join([pp,
237 |                                  arg])
238 |
239 |         return arg
240 |
241 |     def clusterScore(self, cluster):
242 |         logging.debug("*-*-*- Cluster: {}".format(cluster))
243 |
244 |         arr = np.array([x for ls in cluster for x in ls])
245 |         centroid = np.sum(arr)/arr.shape[0]
246 |         logging.debug("Centroid: {}".format(centroid))
247 |
248 |         return np.average([max([abs(x - centroid) for x in ls]) for ls in cluster])
249 |
250 |     def resolveAmbiguity(self):
251 |
252 |         elements = [self.pred] \
253 |             + [(s, indices)
254 |                for (s, indices)
255 |                in self.args
256 |                if indices]
257 |         logging.debug("Resolving ambiguity in: {}".format(elements))
258 |
259 |         all_combinations = list(itertools.product(*map(itemgetter(1), elements)))
260 |         logging.debug("Number of combinations: {}".format(len(all_combinations)))
261 |
262 |         # wrap in list(): a zip object is not subscriptable under Python 3
263 |         resolved_elements = list(zip(map(itemgetter(0), elements),
264 |                                      min(all_combinations,
265 |                                          key = lambda cluster: self.clusterScore(cluster))))
266 |         logging.debug("Resolved elements = {}".format(resolved_elements))
267 |
268 |         self.pred = resolved_elements[0]
269 |         self.args = resolved_elements[1:]
270 |
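    # Worked example for clusterScore (hypothetical indices): for the cluster
    # [[1, 2], [10]] the centroid is (1 + 2 + 10) / 3 = 4.33, the per-element
    # worst deviations are max(3.33, 2.33) = 3.33 and 5.67, and the score is
    # their mean, 4.5; resolveAmbiguity keeps the combination minimizing it.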
271 |     def conll(self, external_feats = []):
272 |         # external_feats is concatenated into a list below, so its default must
273 |         # be a list (the upstream default of {} would raise a TypeError if used)
274 |         return '\n'.join(["\t".join(map(str,
275 |                                         [i, w] + \
276 |                                         list(self.pred) + \
277 |                                         [self.head_pred_index] + \
278 |                                         external_feats + \
279 |                                         [self.get_label(i)]))
280 |                           for (i, w)
281 |                           in enumerate(self.sent.split(" "))]) + '\n'
282 |
283 |     def get_label(self, index):
284 |         ent = [(elem_ind, elem)
285 |                for (elem_ind, elem)
286 |                in enumerate(map(itemgetter(1),
287 |                                 [self.pred] + self.args))
288 |                if index in elem]
289 |
290 |         if not ent:
291 |             return "O"
292 |
293 |         if len(ent) > 1:
294 |             logging.warning("Index {} appears in more than one element: {}".\
295 |                 format(index,
296 |                        "\t".join(map(str,
297 |                                      [ent,
298 |                                       self.sent,
299 |                                       self.pred,
300 |                                       self.args]))))
301 |
302 |
303 |         elem_ind, elem = min(ent, key = lambda ls: len(ls[1]))
304 |
305 |         prefix = "P" if elem_ind == 0 else "A{}".format(elem_ind - 1)
306 |
307 |         suffix = "B" if index == elem[0] else "I"
308 |
309 |         return "{}-{}".format(prefix, suffix)
310 |
311 |     # NOTE: this second __str__ overrides the richer one defined above;
312 |     # kept verbatim from the upstream benchmark.
313 |     def __str__(self):
314 |         return '{0}\t{1}'.format(self.elementToStr(self.pred,
315 |                                                    print_indices = True),
316 |                                  '\t'.join([self.elementToStr(arg)
317 |                                             for arg
318 |                                             in self.args]))
319 |
320 | flatten = lambda l: [item for sublist in l for item in sublist]
321 |
322 |
323 | def normalize_element(elem):
324 |     return elem.replace("_", " ") \
325 |         if (elem != "_")\
326 |         else ""
327 |
328 | def escape_special_chars(s):
329 |     return s.replace('\t', '\\t')
330 |
331 |
332 | def generalize_question(question):
333 |     wh, aux, sbj, trg, obj1, pp, obj2 = question.split(' ')[:-1]
334 |     return ' '.join([wh, sbj, obj1])
335 |
336 |
337 |
338 | SEP = ';;;'
339 | QUESTION_TRG_INDEX = 3
340 | QUESTION_PP_INDEX = 5
341 | QUESTION_OBJ2_INDEX = 6
342 |
--------------------------------------------------------------------------------
/scripts/oie/oie_readers/goldReader.py:
--------------------------------------------------------------------------------
1 | from oie_readers.oieReader import OieReader
2 | from oie_readers.extraction import Extraction
3 | from collections import defaultdict
4 |
5 | class GoldReader(OieReader):
6 |
7 |     default_filename = './oie_corpus/all.oie'
8 |
9 |     def __init__(self):
10 |         self.name = 'Gold'
11 |
12 |     def read(self, fn):
13 |         d = defaultdict(lambda: [])
14 |         with open(fn) as fin:
15 |             for line_ind, line in enumerate(fin):
16 |                 data = line.strip().split('\t')
17 |                 text, rel = data[:2]
18 |                 args = data[2:]
19 |                 confidence = 1
20 |
21 |                 curExtraction = Extraction(pred = rel,
22 |                                            head_pred_index = None,
23 |                                            sent = text,
24 |                                            confidence = float(confidence),
25 |                                            index = line_ind)
26 |                 for arg in args:
27 |                     curExtraction.addArg(arg)
28 |
29 |                 d[text].append(curExtraction)
30 |         self.oie = d
31 |
32 |
33 | if __name__ == '__main__' :
34 |     g = GoldReader()
35 |     g.read('../oie_corpus/all.oie')  # read() takes no includeNominal flag
36 |     d = g.oie
37 |     e = list(d.items())[0]
38 |     print(e[1][0].bow())
39 |     print(g.count())
40 |
--------------------------------------------------------------------------------
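A minimal usage sketch for GoldReader, assuming a gold file in the tab-separated layout read() expects (sentence, predicate, then one column per argument); the path is hypothetical:

    g = GoldReader()
    g.read("oie_corpus/all.oie")
    for sent, extractions in g.oie.items():
        print(sent, len(extractions))  # one Extraction per gold row of this sentence
        break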
/scripts/post_processing.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional
3 | from transformers.hf_argparser import HfArgumentParser
4 |
5 | from deepex.model import Distillation, Eval
6 |
7 |
8 | @dataclass
9 | class Arguments:
10 |     input_dir: str = field(
11 |         default='input', metadata={"help": "input dir"}
12 |     )
13 |     filepath: str = field(
14 |         default='output', metadata={"help": "output result file path"}
15 |     )
16 |     topk: Optional[int] = field(
17 |         default=None, metadata={"help": "topk"}
18 |     )
19 |     dedup_ranking_type: str = field(
20 |         default='freq', metadata={"help": "deduplication ranking type"}
21 |     )
22 |     sent_dedup_type: str = field(
23 |         default='entity_pair', metadata={"help": "sentence deduplication type"}
24 |     )
25 |     doc_dedup_type: str = field(
26 |         default='whole', metadata={"help": "doc deduplication type"}
27 |     )
28 |
29 | if __name__ == '__main__':
30 |     parser = HfArgumentParser(Arguments)
31 |     [args] = parser.parse_args_into_dataclasses()
32 |     simple_distil = Distillation(args.input_dir, args.filepath)
33 |     simple_distil.deduplicate_for_eval_fast(args.filepath, args.topk, args.dedup_ranking_type, args.sent_dedup_type, args.doc_dedup_type)
34 |     evaluator = Eval()
35 |     evaluator.eval_number_of_triplets_with_docid(args.filepath)
36 |     print('total triplets: {}'.format(evaluator.num_triplets))
37 |
--------------------------------------------------------------------------------
/scripts/post_processing.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # NOTE: the positional arguments mirror processing.sh so both stages can be
3 | # launched with the same argument list; only a subset is used in this script.
4 |
5 | CUDA_VISIBLE_DEVICES=$1
6 | master_port=$2
7 | data_dir=$3
8 | input_dir=$4
9 | output_dir=$5
10 | model_type=$6
11 | model_name_or_path=$7
12 | data_aug=$8
13 | per_device_eval_batch_size=$9
14 | generation_type=${10}
15 | max_length=${11}
16 | dedup_ranking_type=${12}
17 | add_extra_entity=${13}
18 | search_attention_head_type=${14}
19 | search_ranking_type=${15}
20 | sentence=${16}
21 | dist_const=${17}
22 | beam_size=${18}
23 | beam_mode=${19}
24 | IFS=', ' read -r -a cuda_arr <<< "$CUDA_VISIBLE_DEVICES"
25 | nproc_per_node=${#cuda_arr[@]}
26 | output_dir_ext=${output_dir}${model_name_or_path}.${generation_type}.${data_aug}.${dedup_ranking_type}.${add_extra_entity}.${search_attention_head_type}.${search_ranking_type}.${sentence}.${dist_const}.${beam_size} # output folder containing the results
27 |
28 | export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
29 | export TOKENIZERS_PARALLELISM=true
30 |
31 | sent_dedup_type="entity_pair"
32 | doc_dedup_type="whole"
33 |
34 | python scripts/post_processing.py --input_dir=${output_dir} \
35 |     --filepath=${output_dir}result.json \
36 |     --dedup_ranking_type=${dedup_ranking_type} \
37 |     --sent_dedup_type=${sent_dedup_type} \
38 |     --doc_dedup_type=${doc_dedup_type}
39 |
--------------------------------------------------------------------------------
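The deduplication step can also be driven directly, mirroring the call post_processing.sh makes; a minimal sketch with hypothetical paths:

    import subprocess

    subprocess.run([
        "python", "scripts/post_processing.py",
        "--input_dir=output/",            # directory holding the generator's raw results
        "--filepath=output/result.json",  # deduplicated result file to write
        "--dedup_ranking_type=freq",
        "--sent_dedup_type=entity_pair",
        "--doc_dedup_type=whole",
    ], check=True)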
/scripts/processing.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | CUDA_VISIBLE_DEVICES=$1
4 | master_port=$2
5 | data_dir=$3
6 | input_dir=$4
7 | output_dir=$5
8 | model_type=$6
9 | model_name_or_path=$7
10 | data_aug=$8
11 | per_device_eval_batch_size=$9
12 | generation_type=${10}
13 | max_length=${11}
14 | dedup_ranking_type=${12}
15 | add_extra_entity=${13}
16 | search_attention_head_type=${14}
17 | search_ranking_type=${15}
18 | sentence=${16}
19 | dist_const=${17}
20 | beam_size=${18}
21 | beam_mode=${19}
22 | IFS=', ' read -r -a cuda_arr <<< "$CUDA_VISIBLE_DEVICES"
23 | nproc_per_node=${#cuda_arr[@]}
24 | output_dir_ext=${output_dir}${model_name_or_path}.${generation_type}.${data_aug}.${dedup_ranking_type}.${add_extra_entity}.${search_attention_head_type}.${search_ranking_type}.${sentence}.${dist_const}.${beam_size} # output folder containing the results
25 |
26 | export OMP_NUM_THREADS=1
27 | export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
28 | export TOKENIZERS_PARALLELISM=true
29 |
30 | python -m torch.distributed.launch --nproc_per_node $nproc_per_node --master_port=$master_port scripts/generator.py \
31 |     --data_dir=${data_dir} \
32 |     --input_dir=${input_dir} \
33 |     --output_dir=${output_dir_ext} \
34 |     --model_type=$model_type \
35 |     --model_name_or_path=$model_name_or_path \
36 |     --data_aug=$data_aug \
37 |     --per_device_eval_batch_size=$per_device_eval_batch_size \
38 |     --generation_type=$generation_type \
39 |     --search_cand_type=entity \
40 |     --beam_size=$beam_size \
41 |     --search_max_len=$max_length \
42 |     --search_min_len=3 \
43 |     --search_layer_id=-1 \
44 |     --search_attention_head_type=$search_attention_head_type \
45 |     --search_ranking_type=$search_ranking_type \
46 |     --max_length=$max_length \
47 |     --dedup_ranking_type=$dedup_ranking_type \
48 |     --add_extra_entity=$add_extra_entity \
49 |     --sentence=$sentence \
50 |     --dist_const=$dist_const \
51 |     --beam_mode=$beam_mode
--------------------------------------------------------------------------------
/scripts/ranking.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | from deepex.utils import *
4 | from bert_contrastive import Reranking
5 | from requests import get
6 | import torch
7 | import argparse
8 |
9 | def IP():
10 |     return get('https://api.ipify.org').text
11 |
12 | def Thresholding(data, score_thres=0.01, len_thres=20):
13 |     s = [0 for _ in range(2048)]  # histogram of relation token lengths (debugging aid)
14 |     with torch.no_grad():
15 |         for (docid,triples) in tqdm.tqdm(list(data.items())):
16 |             sieved_triples = []
17 |             for triple in sorted(triples,key=lambda x:x['sentence']):
18 |                 s[len(triple['relation'].split(' '))] += 1
19 |                 if (
20 |                     triple['score']>=score_thres
21 |                     and len(triple['relation'].split(' '))<=len_thres
22 |                 ):
23 |                     sieved_triples.append(triple)
24 |             data[docid] = sieved_triples
25 |
26 | if __name__=="__main__":
27 |     parser = argparse.ArgumentParser()
28 |     parser.add_argument("-proc_dir", dest="proc_dir", type=str)
29 |     parser.add_argument("-clss_dir", dest="clss_dir", type=str)
30 |     parser.add_argument("-dest", dest="dest", type=str)
31 |     parser.add_argument("-score_thres", dest="score_thres", type=float, default=0.005)
32 |     parser.add_argument("-len_thres", dest="len_thres", type=int, default=2048)
33 |     parser.add_argument("-scoring_model_path", dest="scoring_model_path", type=str, default="Magolor/deepex-ranking-model",
34 |         choices=[
35 |             "Magolor/deepex-ranking-model",
36 |         ]
37 |     )
38 |     args = parser.parse_args()
39 |     Clear(args.dest)
40 |     mentions = {}
41 |
42 |     for FOLDER in os.listdir(args.clss_dir):
43 |         result = LoadJSON(args.clss_dir+FOLDER+"/result.json")
44 |         if args.dest.endswith("sorted"):
45 |             Reranking(result, MODEL_FOLDER=args.scoring_model_path)
46 |         SaveJSON(result,args.dest+f"/{FOLDER}_result.json")
47 |         mentions[FOLDER] = {}
48 |     for DATA_FILE in os.listdir(args.proc_dir):
49 |         if Suffix(DATA_FILE)=="jsonl":
50 |             data = LoadJSON(args.proc_dir+DATA_FILE,jsonl=True)
51 |             SaveJSON(data,args.dest+f"/{Prefix(DATA_FILE)}_data.jsonl",jsonl=True)
52 |         elif DATA_FILE.startswith("cachedmentions"):
53 |             data = torch.load(args.proc_dir+DATA_FILE)
54 |             mentions["P"+DATA_FILE.split('_')[1]].update(data)
55 |     for FOLDER in mentions:
56 |         SaveJSON({str(k[0])+'-'+str(k[1]):v for k,v in mentions[FOLDER].items()},args.dest+f"/{FOLDER}_mentions.json")
--------------------------------------------------------------------------------
/scripts/rc/dataset_preparation.py:
-------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import jsonlines 4 | from string_matcher import UnLemmatizeStringMatcher, LemmatizeStringMatcher 5 | import logging 6 | logger = logging.getLogger("spacy") 7 | logger.setLevel(logging.ERROR) 8 | 9 | def get_relation_candidates(item,lemmatized_matcher,unlemmatized_matcher): 10 | text = item['text'] 11 | 12 | un_relation_candidates = unlemmatized_matcher(text) 13 | relation_candidates = lemmatized_matcher(text) 14 | 15 | charspan_elem_dict = {} 16 | for i, elem in enumerate(relation_candidates): 17 | charspan_elem_dict[str(elem['char_span'])] = elem 18 | 19 | un_charspan_elem_dict = {} 20 | for i, un_elem in enumerate(un_relation_candidates): 21 | un_charspan_elem_dict[str(un_elem['char_span'])] = un_elem 22 | 23 | merged_alias_charspan = list(set(charspan_elem_dict.keys()) | set(un_charspan_elem_dict.keys())) 24 | 25 | merged_relation_candidates = [] 26 | for chspan in merged_alias_charspan: 27 | if chspan in charspan_elem_dict.keys() and chspan in un_charspan_elem_dict.keys(): 28 | merged_relation = list(set(charspan_elem_dict[chspan]["relation"]) | set(un_charspan_elem_dict[chspan]["relation"])) 29 | elem = charspan_elem_dict[chspan] 30 | elem["relation"] = merged_relation 31 | merged_relation_candidates.append(elem) 32 | elif chspan in charspan_elem_dict.keys(): 33 | merged_relation_candidates.append(charspan_elem_dict[chspan]) 34 | else: 35 | merged_relation_candidates.append(un_charspan_elem_dict[chspan]) 36 | return merged_relation_candidates 37 | 38 | def Prepare(dataset): 39 | lemmatized_matcher = LemmatizeStringMatcher(f'{dataset.lower()}_aliases_lemmatized.json') 40 | unlemmatized_matcher = UnLemmatizeStringMatcher(f'{dataset.lower()}_aliases_unlemmatized.json') 41 | if dataset=='FewRel': 42 | dev_relations = ['crosses', 'original language of film or TV show', 'competition class', 'part of', 'sport', 'constellation', 'position played on team / speciality', 'located in or next to body of water', 'voice type', 'follows', 'spouse', 'military rank', 'mother', 'member of', 'child', 'main subject'] 43 | data_dict = json.load(open("../../data/FewRel/val_wiki.json")) 44 | pid2name = json.load(open("../../data/FewRel/pid2name.json")) 45 | index = 0 46 | with jsonlines.open(f"../../data/{dataset}/data.jsonl", 'w') as w: 47 | for k, vs in data_dict.items(): 48 | for v in vs: 49 | item = {} 50 | item["id"] = str(index) 51 | item["title"] = v["h"][0] 52 | item["answer"] = v["t"][0] 53 | item["subject_spans"] = [v["h"][2][0]] 54 | item["object_spans"] = [v["t"][2][0]] 55 | item["tokens"] = v["tokens"] 56 | item["text"] = ' '.join(v["tokens"]) 57 | item["true_relation"] = pid2name[k][0] 58 | 59 | item['rel_candidates'] = [] 60 | for elem in get_relation_candidates(item,lemmatized_matcher,unlemmatized_matcher): 61 | elemrel = [] 62 | for r in elem["relation"]: 63 | if r in dev_relations: 64 | elemrel.append(r) 65 | elem["relation"] = elemrel 66 | if len(elem["relation"]) > 0: 67 | item['rel_candidates'].append(elem) 68 | 69 | w.write(item) 70 | index += 1 71 | elif dataset=='TACRED': 72 | data_list = json.load(open("../../data/TACRED/test.json")) 73 | index = 0 74 | with jsonlines.open("../../data/TACRED/data.jsonl", 'w') as w: 75 | for v in data_list: 76 | item = {} 77 | item["id"] = str(index) 78 | item["title"] = ' '.join(v["token"][int(v["subj_start"]):int(v["subj_end"])+1]) 79 | item["answer"] = ' '.join(v["token"][int(v["obj_start"]):int(v["obj_end"])+1]) 80 | 
item["subject_spans"] = [[i for i in range(int(v["subj_start"]), int(v["subj_end"])+1)]] 81 | item["object_spans"] = [[i for i in range(int(v["obj_start"]), int(v["obj_end"])+1)]] 82 | item["tokens"] = v["token"] 83 | item["text"] = ' '.join(v["token"]) 84 | item["true_relation"] = v["relation"] 85 | 86 | item['rel_candidates'] = [] 87 | for elem in get_relation_candidates(item,lemmatized_matcher,unlemmatized_matcher): 88 | if len(elem["relation"]) > 0: 89 | item['rel_candidates'].append(elem) 90 | 91 | w.write(item) 92 | index += 1 93 | 94 | if __name__=="__main__": 95 | parser = argparse.ArgumentParser() 96 | parser.add_argument("-t", "--task", dest="task", type=str, default='FewRel', 97 | choices=[ 98 | 'FewRel', 99 | 'TACRED', 100 | ], 101 | help = "The task to be run" 102 | ) 103 | args = parser.parse_args() 104 | Prepare(args.task) -------------------------------------------------------------------------------- /scripts/rc/eval_FewRel.sh: -------------------------------------------------------------------------------- 1 | echo "Run relation classification experiments" 2 | echo "***************************************" 3 | cd scripts/rc/ 4 | python3 post_process.py --task=FewRel 5 | python3 evaluation.py --task=FewRel 6 | cd ../.. 7 | echo "***************************************" 8 | -------------------------------------------------------------------------------- /scripts/rc/eval_TACRED.sh: -------------------------------------------------------------------------------- 1 | echo "Run relation classification experiments" 2 | echo "***************************************" 3 | cd scripts/rc/ 4 | python3 post_process.py --task=TACRED 5 | python3 evaluation.py --task=TACRED 6 | cd ../.. 7 | echo "***************************************" 8 | -------------------------------------------------------------------------------- /scripts/rc/evaluation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import tqdm 4 | import csv 5 | import argparse 6 | import jsonlines 7 | import subprocess 8 | from collections import defaultdict 9 | import heapq 10 | 11 | def SysCall(command): 12 | subprocess.Popen( 13 | command, 14 | shell=True 15 | ).wait() 16 | 17 | def Evaluate(dataset, top_k): 18 | with open(f"data/{dataset.lower()}_result.json", 'r') as f: 19 | result = json.load(f) 20 | 21 | with open(f'data/{dataset.lower()}_id_alias2relations_dict.json', 'r') as f: 22 | id_alias2relations_dict = json.load(f) 23 | 24 | our_result = {} 25 | for k in result.keys(): 26 | min_dis = 100 27 | min_elem = None 28 | contrastive_diss = [] 29 | 30 | for elem in result[k]: 31 | contrastive_dis = elem['contrastive_dis'] 32 | contrastive_diss.append(contrastive_dis) 33 | 34 | min_dis_list = map(contrastive_diss.index, heapq.nsmallest(top_k, contrastive_diss)) 35 | 36 | our_result[str(int(k))] = [] 37 | for index in min_dis_list: 38 | our_result[str(int(k))].append(result[k][index]['relation']) 39 | 40 | val_data = {} 41 | with open(f"data/{dataset.lower()}_processed.jsonl", 'r') as f: 42 | for line in jsonlines.Reader(f): 43 | index = line['id'] 44 | rels = [] 45 | for elem in line['rel_candidates']: 46 | rels += elem['relation'] 47 | covered = False 48 | if line['true_relation'] in rels: 49 | covered = True 50 | if dataset=='FewRel' and line['true_relation'] == "main subject": 51 | if "part of" in rels: 52 | covered = True 53 | val_data[index] = {"true_relation": line['true_relation'], "covered": covered, "text": line['text'], "head": 
line["title"], "tail": line["answer"], "alias2relation": id_alias2relations_dict[index]} 54 | 55 | analysis = {} 56 | false_analysis = {} 57 | not_in_text = {} 58 | 59 | accuracy = 0 60 | 61 | Not_in_text = [] 62 | 63 | rel_correct = defaultdict(int) 64 | rel_all = defaultdict(int) 65 | 66 | for k, vs in our_result.items(): 67 | correct = False 68 | rel_all[val_data[k]["true_relation"]] += 1 69 | for v in vs: 70 | if not v in val_data[k]["alias2relation"].keys(): 71 | not_in_text[k] = {} 72 | Not_in_text.append({"id": k, 'val_data[k]["alias2relation"]': val_data[k]["alias2relation"], 'predict': v}) 73 | not_in_text[k]["predict"] = {"alias": v, "relations": None} 74 | not_in_text[k]["rel_candidates"] = val_data[k]["alias2relation"] 75 | not_in_text[k]["truth"] = val_data[k]["true_relation"] 76 | not_in_text[k]["is_correct"] = correct 77 | not_in_text[k]["is_covered"] = val_data[k]["covered"] 78 | not_in_text[k]["text"] = val_data[k]["text"] 79 | not_in_text[k]["head"] = val_data[k]["head"] 80 | not_in_text[k]["tail"] = val_data[k]["tail"] 81 | 82 | if val_data[k]["true_relation"] == "no_relation": 83 | correct += True 84 | accuracy += 1 85 | 86 | analysis[k] = {} 87 | analysis[k]["truth"] = val_data[k]["true_relation"] 88 | analysis[k]["is_correct"] = correct 89 | analysis[k]["is_covered"] = val_data[k]["covered"] 90 | analysis[k]["text"] = val_data[k]["text"] 91 | analysis[k]["head"] = val_data[k]["head"] 92 | analysis[k]["tail"] = val_data[k]["tail"] 93 | analysis[k]["predict"] = {"alias": v, "relations": None} 94 | rel_correct[val_data[k]["true_relation"]] += 1 95 | else: 96 | try: 97 | true_rel = val_data[k]["true_relation"].split(':')[1].replace('_', ' ') 98 | except: 99 | true_rel = '' 100 | 101 | if true_rel in val_data[k]["alias2relation"][v] or val_data[k]["true_relation"] in val_data[k]["alias2relation"][v] or (val_data[k]["true_relation"]=="main subject" and "part of" in val_data[k]["alias2relation"][v]): 102 | analysis[k] = {} 103 | correct = True 104 | accuracy += 1 105 | analysis[k]["truth"] = val_data[k]["true_relation"] 106 | analysis[k]["is_correct"] = correct 107 | analysis[k]["is_covered"] = val_data[k]["covered"] 108 | analysis[k]["text"] = val_data[k]["text"] 109 | analysis[k]["head"] = val_data[k]["head"] 110 | analysis[k]["tail"] = val_data[k]["tail"] 111 | analysis[k]["predict"] = {"alias": v, "relations": val_data[k]["alias2relation"][v]} 112 | 113 | rel_correct[val_data[k]["true_relation"]] += 1 114 | else: 115 | false_analysis[k] = {} 116 | false_analysis[k]["truth"] = val_data[k]["true_relation"] 117 | false_analysis[k]["is_correct"] = correct 118 | false_analysis[k]["is_covered"] = val_data[k]["covered"] 119 | false_analysis[k]["text"] = val_data[k]["text"] 120 | false_analysis[k]["head"] = val_data[k]["head"] 121 | false_analysis[k]["tail"] = val_data[k]["tail"] 122 | false_analysis[k]["predict"] = {"alias": v, "relations": val_data[k]["alias2relation"][v]} 123 | if correct: 124 | break 125 | 126 | no_relation = list(set(val_data.keys()) -set(our_result.keys())) 127 | 128 | not_gen = {} 129 | 130 | for k in no_relation: 131 | not_gen[k] = {} 132 | not_gen[k]["truth"] = val_data[k]["true_relation"] 133 | not_gen[k]["is_covered"] = val_data[k]["covered"] 134 | not_gen[k]["rel_candidates"] = val_data[k]["alias2relation"] 135 | not_gen[k]["text"] = val_data[k]["text"] 136 | not_gen[k]["head"] = val_data[k]["head"] 137 | not_gen[k]["tail"] = val_data[k]["tail"] 138 | 139 | recall = accuracy / len(val_data) 140 | percision = accuracy / len(our_result) 141 | f1 = 
/scripts/rc/post_process.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 | import json
3 | import argparse
4 | import jsonlines
5 | from dataset_preparation import get_relation_candidates
6 | from string_matcher import UnLemmatizeStringMatcher, LemmatizeStringMatcher
7 |
8 | def get_processed_output(dataset):
9 |     lemmatized_matcher = LemmatizeStringMatcher(f'{dataset.lower()}_aliases_lemmatized.json')
10 |     unlemmatized_matcher = UnLemmatizeStringMatcher(f'{dataset.lower()}_aliases_unlemmatized.json')
11 |     dev_relations = ['crosses', 'original language of film or TV show', 'competition class', 'part of', 'sport', 'constellation', 'position played on team / speciality', 'located in or next to body of water', 'voice type', 'follows', 'spouse', 'military rank', 'mother', 'member of', 'child', 'main subject']
12 |     with jsonlines.open(f'data/{dataset.lower()}_processed.jsonl', mode='w') as writer:
13 |         with open(f"data/{dataset.lower()}_data.jsonl", "r", encoding="utf8") as f:
14 |             for item in tqdm.tqdm([i for i in jsonlines.Reader(f)]):
15 |                 item['rel_candidates'] = []
16 |                 for elem in get_relation_candidates(item, lemmatized_matcher, unlemmatized_matcher):
17 |                     elemrel = []
18 |                     for r in elem["relation"]:
19 |                         if r in dev_relations:
20 |                             elemrel.append(r)
21 |                     elem["relation"] = elemrel
22 |                     if len(elem["relation"]) > 0:
23 |                         item['rel_candidates'].append(elem)
24 |
25 |                 writer.write(item)
26 |
27 | def get_id_alias2relations_dict(dataset):
28 |     id_alias2relations_dict = {}
29 |     with open(f'data/{dataset.lower()}_processed.jsonl', 'r') as f:
30 |         for line in jsonlines.Reader(f):
31 |             alias2relations_dict = {}
32 |             text = line["text"]
33 |             for elem in line['rel_candidates']:
34 |                 span = elem['char_span']
35 |                 rel_candidates = []
36 |                 for i in range(span[0], span[1]):
37 |                     rel_candidates.append(text[i])
38 |                 rel_candidate = ''.join(rel_candidates)
39 |                 if rel_candidate not in alias2relations_dict.keys():
40 |                     alias2relations_dict[rel_candidate] = []
41 |                 alias2relations_dict[rel_candidate] = list(set(alias2relations_dict[rel_candidate]) | set(elem['relation']))
42 |             id_alias2relations_dict[int(line["id"])] = alias2relations_dict
43 |
44 |     with open(f'data/{dataset.lower()}_id_alias2relations_dict.json', 'w') as f:
45 |         json.dump(id_alias2relations_dict, f)
46 |
47 | if __name__=="__main__":
48 |     parser = argparse.ArgumentParser()
49 |     parser.add_argument("-t", "--task", dest="task", type=str, default='FewRel',
50 |         choices=[
51 |             'FewRel',
52 |             'TACRED',
53 |         ],
54 |         help = "The task to be run"
55 |     )
56 |     args = parser.parse_args()
57 |
58 |     get_processed_output(args.task)
59 |     get_id_alias2relations_dict(args.task)
60 |
--------------------------------------------------------------------------------
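For reference, the id_alias2relations_dict built above has the following shape (all values invented): for each example id, every alias surface form found in the text maps to the merged list of relations it may express:

    id_alias2relations_dict = {
        0: {
            "spouse": ["spouse"],                   # hypothetical alias -> relations
            "member of": ["member of", "part of"],
        },
    }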
/scripts/rc/prep_FewRel.sh:
--------------------------------------------------------------------------------
1 | rm -rf ./data/FewRel
2 | mkdir -p ./data/FewRel/
3 | cd ./data/FewRel/
4 | if [ ! -f "./pid2name.json" ];then
5 |     wget https://raw.githubusercontent.com/thunlp/FewRel/master/data/pid2name.json
6 | fi
7 | if [ ! -f "./val_wiki.json" ];then
8 |     wget https://raw.githubusercontent.com/thunlp/FewRel/master/data/val_wiki.json
9 | fi
10 | cd ../..
11 | cd scripts/rc/
12 | echo "Preparing FewRel dataset..."
13 | python dataset_preparation.py --task=FewRel
14 | cd ../..
15 | echo "Done."
16 |
--------------------------------------------------------------------------------
/scripts/rc/prep_TACRED.sh:
--------------------------------------------------------------------------------
1 | # Please first download the TACRED dataset from [This Link](https://catalog.ldc.upenn.edu/LDC2018T24). The downloaded file should be named `tacred_LDC2018T24.tgz`.
2 | rm -rf tacred
3 | tar zxvf tacred_LDC2018T24.tgz
4 | mkdir -p ./data/TACRED
5 | mv ./tacred/data/json/test.json ./data/TACRED/test.json
6 | cd scripts/rc/
7 | echo "Preparing TACRED dataset..."
8 | python dataset_preparation.py --task=TACRED
9 | cd ../..
10 | echo "Done."
11 |
--------------------------------------------------------------------------------
/scripts/rc/string_matcher.py:
--------------------------------------------------------------------------------
1 | from flashtext import KeywordProcessor
2 |
3 | import spacy
4 | from datetime import datetime
5 | from tqdm import tqdm
6 | from collections import defaultdict
7 |
8 | import json
9 | import re
10 |
11 | nlp = spacy.load('en_core_web_sm', disable=['tagger', 'parser', 'textcat', 'ner'])
12 |
13 | import nltk
14 | try:
15 |     from nltk.corpus import stopwords
16 |     STOPWORDS = stopwords.words('english')
17 | except LookupError:
18 |     nltk.download('stopwords')
19 |     from nltk.corpus import stopwords
20 |     STOPWORDS = stopwords.words('english')
21 |
22 | def remove_stopwords(s):
23 |     return " ".join([w for w in s.split() if w.lower() not in STOPWORDS])
24 |
25 | class LemmatizeHelper(object):
26 |     def __init__(self):
27 |         self.data = dict()
28 |
29 |     def lemmatize_relation(self, relation):
30 |         # NOTE: stopwords are deliberately kept here (unlike in
31 |         # lemmatize_relation_with_time) so ns2os stays aligned with the raw string
32 |         _r = relation
33 |         result, ns2os = [], []
34 |         offset = -1
35 |         for w in nlp(_r):
36 |             word = w.lemma_.lower()
37 |             result.append(word)
38 |             new_span = [offset + 1, offset + 1 + len(word)]
39 |             old_span = [w.idx, w.idx + len(w.text)]
40 |             ns2os.append([new_span, old_span])
41 |             offset += (1 + len(word))
42 |         return ' '.join(result), ns2os
43 |
44 |     def lemmatize_relation_with_time(self, relation):
45 |         start = datetime.now()
46 |         _r = remove_stopwords(relation)
47 |         mid = datetime.now()
48 |         if len(_r) == 0:
49 |             _r = relation
50 |         res = ' '.join([w.lemma_.lower() for w in nlp(_r)])
51 |         return mid - start, datetime.now() - mid, res
52 |
53 |     def map(self, relation):
54 |         lemmatized, ns2os = self.lemmatize_relation(relation)
55 |         self.data[relation] = lemmatized
56 |         return lemmatized, ns2os
57 |
58 |
59 | class LemmatizeStringMatcher(object):
60 |     def __init__(self, file):
61 |         self.helper = LemmatizeHelper()
62 |         self.o2w = json.load(open(file))
63 |         self.processor = KeywordProcessor(case_sensitive=False)
64 |         keywords = [k for k in self.o2w.keys() if k != '']
65 |         self.processor.add_keywords_from_list(keywords)
66 |
67 |     def __call__(self, raw_string):
68 |         lemmatized_string, ns2os = self.helper.map(raw_string)
69 |         keywords_found = self.processor.extract_keywords(lemmatized_string, span_info=True)
70 |         candidates = []
71 |         for keyword_tuple in keywords_found:
72 |             mention, start, end = keyword_tuple
73 |             relation = list(self.o2w[mention].keys())
74 
| pos_start, pos_end = None, None 75 | for i in range(len(ns2os)): 76 | if pos_start is None and ns2os[i][0][0] >= start: 77 | pos_start = i 78 | if pos_end is None and (i + 1 == len(ns2os) or ns2os[i + 1][0][0] >= end): 79 | pos_end = i 80 | break 81 | if pos_start is None or pos_end is None: 82 | continue 83 | candidates.append( 84 | {"aliase": mention, "relation": relation, "len": len(mention.split(' ')), 85 | "char_span": [ns2os[pos_start][1][0], ns2os[pos_end][1][1]]}) 86 | candidates = sorted(candidates, key=lambda x: x['len'], reverse=True) 87 | return candidates 88 | 89 | 90 | class UnLemmatizeStringMatcher(object): 91 | def __init__(self, file): 92 | self.a2r = json.load(open(file)) 93 | self.processor = KeywordProcessor(case_sensitive=False) 94 | self.processor.add_keywords_from_list(list(self.a2r.keys())) 95 | 96 | def __call__(self, raw_string): 97 | keywords_found = self.processor.extract_keywords(raw_string, span_info=True) 98 | candidates = [] 99 | for keyword_tuple in keywords_found: 100 | mention, start, end = keyword_tuple 101 | relation = self.a2r[mention] 102 | candidates.append( 103 | {"aliase": mention, "relation": relation, "len": len(mention.split(' ')), "char_span": [start, end]}) 104 | candidates = sorted(candidates, key=lambda x: x['len'], reverse=True) 105 | return candidates 106 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="deepex", 5 | version="0.0.1", 6 | author="Chenguang Wang, Xiao Liu, Zui Chen, Haoyun Hong, Jie Tang, Dawn Song", 7 | author_email="25714264+cgraywang@users.noreply.github.com", 8 | description="Zero-Shot Information Extraction as a Unified Text-to-Triple Translation", 9 | long_description=open("README.md", "r", encoding="utf-8").read(), 10 | long_description_content_type="text/markdown", 11 | keywords="NLP deep learning zero-shot information extraction", 12 | license="Apache", 13 | url="https://github.com/cgraywang/deepex", 14 | package_dir={"": "src"}, 15 | packages=find_packages("src"), 16 | setup_requires=[ 17 | 'setuptools>=18.0', 18 | ], 19 | python_requires=">=3.7.0", 20 | classifiers=[ 21 | "Development Status :: 0", 22 | "Intended Audience :: Developers", 23 | "Intended Audience :: Education", 24 | "Intended Audience :: Science/Research", 25 | ], 26 | ) -------------------------------------------------------------------------------- /src/deepex/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import * 2 | from .model import * 3 | from .args import * 4 | from .utils import * -------------------------------------------------------------------------------- /src/deepex/args.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import json 3 | import logging 4 | from dataclasses import dataclass, field 5 | from typing import Any, Dict, Optional, Tuple 6 | 7 | import torch 8 | 9 | from transformers import MODEL_WITH_LM_HEAD_MAPPING 10 | from transformers.training_args import is_torch_tpu_available 11 | from transformers.file_utils import cached_property, is_torch_available, torch_required 12 | 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys()) 17 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 18 | 19 | try: 20 | import 
torch_xla.core.xla_model as xm 21 | except ImportError: 22 | pass 23 | 24 | @dataclass 25 | class ModelArguments: 26 | 27 | model_name_or_path: Optional[str] = field( 28 | default=None, 29 | metadata={ 30 | "help": "The model checkpoint for weights initialization. Leave None if you want to train a model from scratch." 31 | }, 32 | ) 33 | model_type: Optional[str] = field( 34 | default=None, 35 | metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, 36 | ) 37 | tokenizer_name: Optional[str] = field( 38 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 39 | ) 40 | cache_dir: Optional[str] = field( 41 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} 42 | ) 43 | generation_type: str = field( 44 | default='unsupervised_last_layer', metadata={"help": "kb generation mode"} 45 | ) 46 | compute_loss: bool = field( 47 | default=False, metadata={"help": "whether to compute loss."} 48 | ) 49 | search_n: int = field( 50 | default=None, 51 | metadata={ 52 | "help": "number of return triplets for each example" 53 | }, 54 | ) 55 | beam_size: int = field( 56 | default=2, 57 | metadata={ 58 | "help": "beam size" 59 | }, 60 | ) 61 | search_max_len: int = field( 62 | default=20, 63 | metadata={ 64 | "help": "sequence max len of the search" 65 | }, 66 | ) 67 | search_min_len: int = field( 68 | default=3, 69 | metadata={ 70 | "help": "sequence min len of the search" 71 | }, 72 | ) 73 | search_score_threshold: float = field( 74 | default=0.0, 75 | metadata={ 76 | "help": "score threshold of the search" 77 | }, 78 | ) 79 | search_layer_id: int = field( 80 | default=-1, 81 | metadata={ 82 | "help": "use the attention weights of layer id" 83 | }, 84 | ) 85 | search_attention_head_type: str = field( 86 | default='max', 87 | metadata={ 88 | "help": "use the max/mean head's weight (max, mean)" 89 | }, 90 | ) 91 | search_cand_type: str = field( 92 | default='word', 93 | metadata={ 94 | "help": "the search candidate type" 95 | }, 96 | ) 97 | beam_mode: str = field( 98 | default="ie", metadata={"help": "beam mode."} 99 | ) 100 | search_ranking_type: str = field( 101 | default='sum', 102 | metadata={ 103 | "help": "the search ranking type (sum, mean)" 104 | }, 105 | ) 106 | local_model_name_or_path: str = field( 107 | default=None, 108 | metadata={ 109 | "help": "the local model path" 110 | }, 111 | ) 112 | cand_min_len: int = field( 113 | default=3, 114 | metadata={ 115 | "help": "candidate min len" 116 | }, 117 | ) 118 | sentence: int = field( 119 | default=1, metadata={"help": "whether to split sentence."} 120 | ) 121 | dedup_ranking_type: str = field( 122 | default='freq', 123 | metadata={ 124 | "help": "the search ranking type (freq, score, score_freq, score_freq_len)" 125 | }, 126 | ) 127 | add_extra_entity: int = field( 128 | default=1, metadata={"help": "whether to add first and last word as entity in the input."} 129 | ) 130 | dist_const: int = field( 131 | default=2, metadata={"help": "distance constraint"} 132 | ) 133 | 134 | @dataclass 135 | class DataTrainingArguments: 136 | 137 | max_length: int = field( 138 | default=None, 139 | metadata={ 140 | "help": "Use in tokenizer.batch_encode_plus." 141 | }, 142 | ) 143 | data_aug: Optional[str] = field( 144 | default='ner', 145 | metadata={ 146 | "help": "ner, np." 
147 | }, 148 | ) 149 | data_dir: str = field( 150 | default=None, 151 | metadata={ 152 | "help": "data dir" 153 | }, 154 | ) 155 | input_dir: str = field( 156 | default=None, 157 | metadata={ 158 | "help": "info dir" 159 | }, 160 | ) 161 | -------------------------------------------------------------------------------- /src/deepex/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .generator_utils import * 2 | from .re_data import * 3 | from .collator import * 4 | from .np import * 5 | from .rc import * -------------------------------------------------------------------------------- /src/deepex/data/collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Optional, Tuple, Dict, NewType, Any 3 | 4 | InputDataClass = NewType("InputDataClass", Any) 5 | 6 | def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]: 7 | if not isinstance(features[0], dict): 8 | features = [vars(f) for f in features] 9 | 10 | first = features[0] 11 | batch = {} 12 | 13 | if "label" in first and first["label"] is not None: 14 | label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"] 15 | dtype = torch.long if isinstance(label, int) else torch.float 16 | batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype) 17 | elif "label_ids" in first and first["label_ids"] is not None: 18 | if isinstance(first["label_ids"], torch.Tensor): 19 | batch["labels"] = torch.stack([f["label_ids"] for f in features]) 20 | else: 21 | dtype = torch.long if type(first["label_ids"][0]) is int else torch.float 22 | batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype) 23 | 24 | for k, v in first.items(): 25 | if k not in ("label", "label_ids", "entity_ids", "head_entity_ids", "tail_entity_ids", "relation_entity_ids", "docid", "offset") and v is not None and not isinstance(v, str): 26 | if isinstance(v, torch.Tensor): 27 | batch[k] = torch.stack([f[k] for f in features]) 28 | else: 29 | batch[k] = torch.tensor([f[k] for f in features], dtype=torch.long) 30 | elif k in ("entity_ids", "head_entity_ids", "tail_entity_ids", "relation_entity_ids", "docid", "offset", "text"): 31 | batch[k] = [f[k] for f in features] 32 | else: 33 | pass 34 | return batch -------------------------------------------------------------------------------- /src/deepex/data/generator_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union 2 | from collections import defaultdict 3 | import time 4 | import sys 5 | import os 6 | import string 7 | import json 8 | import random 9 | 10 | import numpy as np 11 | import spacy 12 | from spacy.lang.en import STOP_WORDS 13 | from spacy.lang.char_classes import LIST_PUNCT, LIST_ELLIPSES, LIST_QUOTES, LIST_CURRENCY 14 | from spacy.tokens import Doc 15 | from tqdm import tqdm 16 | 17 | from ..utils import * 18 | from ..utils import * 19 | 20 | 21 | class WhitespaceTokenizer(object): 22 | def __init__(self, vocab): 23 | self.vocab = vocab 24 | 25 | def __call__(self, text): 26 | words = text.split(' ') 27 | spaces = [True] * len(words) 28 | return Doc(self.vocab, words=words, spaces=spaces) 29 | 30 | 31 | class MentionGenerator(): 32 | pass 33 | 34 | 35 | def get_empty_candidates(): 36 | return { 37 | "candidate_spans": [[-1, -1]], 38 | "candidate_entities": [["@@PADDING@@"]], 39 | "candidate_entity_priors": 
[[1.0]], 40 | "tokenized_text": None, 41 | "candidate_positions": [[-1, -1]], 42 | } 43 | 44 | STOP_SYMBOLS = set().union(LIST_PUNCT, LIST_ELLIPSES, LIST_QUOTES, LIST_CURRENCY) 45 | 46 | 47 | def span_filter_func(span: List[str]): 48 | if span[0] in STOP_WORDS or span[-1] in STOP_WORDS: 49 | return False 50 | 51 | if any([c in STOP_SYMBOLS for c in span]): 52 | return False 53 | return True 54 | -------------------------------------------------------------------------------- /src/deepex/data/np.py: -------------------------------------------------------------------------------- 1 | import spacy 2 | 3 | from .generator_utils import WhitespaceTokenizer, span_filter_func, get_empty_candidates 4 | from ..utils import * 5 | 6 | class NPMentionGenerator: 7 | 8 | def __init__(self): 9 | spacy.load('en_core_web_sm', disable=['tagger', 'parser', 'textcat']) 10 | self.tokenizer = spacy.load('en_core_web_sm') 11 | self.whitespace_tokenizer = spacy.load('en_core_web_sm') 12 | self.whitespace_tokenizer.tokenizer = WhitespaceTokenizer(self.whitespace_tokenizer.vocab) 13 | 14 | def get_mentions_raw_text(self, text: str, whitespace_tokenize=False, extra=None): 15 | if whitespace_tokenize: 16 | tokens = self.whitespace_tokenizer(text) 17 | else: 18 | self.tokenizer.max_length = 1000000000 19 | tokens = self.tokenizer(text) 20 | 21 | _tokens = [t.text for t in tokens] 22 | spans_to_candidates = {} 23 | spans_to_positions = {} 24 | 25 | for cand in tokens.noun_chunks: 26 | spans_to_candidates[(cand.start, cand.end-1)] = [(None, cand.text, 1.0)] 27 | spans_to_positions[(cand.start, cand.end-1)] = [cand.start_char, cand.end_char] 28 | 29 | spans = [] 30 | entities = [] 31 | priors = [] 32 | positions = [] 33 | for span, candidates in spans_to_candidates.items(): 34 | spans.append(list(span)) 35 | entities.append([x[1] for x in candidates]) 36 | mention_priors = [x[2] for x in candidates] 37 | 38 | sum_priors = sum(mention_priors) 39 | priors.append([x/sum_priors for x in mention_priors]) 40 | 41 | positions.append(spans_to_positions[span]) 42 | ret = { 43 | "tokenized_text": _tokens, 44 | "candidate_spans": spans, 45 | "candidate_entities": entities, 46 | "candidate_entity_priors": priors, 47 | "candidate_positions": positions, 48 | 49 | "head_candidate_spans": [], 50 | "head_candidate_entities": [], 51 | "head_candidate_entity_priors": [], 52 | "head_candidate_positions": [], 53 | 54 | "tail_candidate_spans": [], 55 | "tail_candidate_entities": [], 56 | "tail_candidate_entity_priors": [], 57 | "tail_candidate_positions": [], 58 | 59 | "relation_candidate_spans": [], 60 | "relation_candidate_entities": [], 61 | "relation_candidate_entity_priors": [], 62 | "relation_candidate_positions": [], 63 | } 64 | 65 | if len(spans) == 0: 66 | ret.update(get_empty_candidates()) 67 | 68 | return ret 69 | -------------------------------------------------------------------------------- /src/deepex/data/rc.py: -------------------------------------------------------------------------------- 1 | import spacy 2 | 3 | from .generator_utils import WhitespaceTokenizer 4 | from ..utils import * 5 | 6 | from ..utils import * 7 | from .np import NPMentionGenerator 8 | 9 | class RCMentionGenerator: 10 | 11 | def __init__(self, dataset='FewRel'): 12 | self.dataset = {record['id']:record for record in LoadJSON(f"data/{dataset}/data.jsonl",jsonl=True)} 13 | for key, record in self.dataset.items(): 14 | self.dataset[key]['rel'] = {} 15 | for relation in self.dataset[key]['rel_candidates']: 16 | for rname in relation['relation']: 17 | if 
rname not in self.dataset[key]['rel'].keys(): 18 | self.dataset[key]['rel'][rname] = [] 19 | self.dataset[key]['rel'][rname].append(relation) 20 | self.tokenizer = spacy.load('en_core_web_sm', disable=['tagger', 'parser', 'ner', 'textcat']) 21 | self.whitespace_tokenizer = spacy.load('en_core_web_sm', disable=['tagger', 'parser', 'ner', 'textcat']) 22 | self.whitespace_tokenizer.tokenizer = WhitespaceTokenizer(self.whitespace_tokenizer.vocab) 23 | 24 | def get_mentions_raw_text(self, text: str, whitespace_tokenize=False, extra=None): 25 | docid, offset = str(int(extra[0])), extra[1]; data = self.dataset[docid] 26 | 27 | tokens = data['tokens'] 28 | 29 | entities = []; idx = 0 30 | for i,word in enumerate(tokens): 31 | entities.append([[i,i],word,1.0,[idx,idx+len(word)]]); idx += len(word)+1 32 | 33 | head_ents = [] 34 | for ss in data['subject_spans']: 35 | ents = [ent for ent in entities if ent[0][0]+offset in ss] 36 | if len(ents)==0: 37 | continue 38 | new_ent = [ 39 | [min([ent[0][0] for ent in ents]),max([ent[0][1] for ent in ents])], 40 | [' '.join([ent[1] for ent in ents])], [1.0], 41 | [min([ent[3][0] for ent in ents]),max([ent[3][1] for ent in ents])], 42 | ] 43 | flag = True 44 | for ent1 in head_ents: 45 | if not (ent1[3][1] <= new_ent[3][0] or new_ent[3][1] <= ent1[3][0]): 46 | flag = False; break 47 | if flag: 48 | head_ents.append(new_ent) 49 | 50 | tail_ents = [] 51 | for ss in data['object_spans']: 52 | ents = [ent for ent in entities if ent[0][0]+offset in ss] 53 | if len(ents)==0: 54 | continue 55 | new_ent = [ 56 | [min([ent[0][0] for ent in ents]),max([ent[0][1] for ent in ents])], 57 | [' '.join([ent[1] for ent in ents])], [1.0], 58 | [min([ent[3][0] for ent in ents]),max([ent[3][1] for ent in ents])], 59 | ] 60 | flag = True 61 | for ent1 in tail_ents: 62 | if not (ent1[3][1] <= new_ent[3][0] or new_ent[3][1] <= ent1[3][0]): 63 | flag = False; break 64 | if flag: 65 | tail_ents.append(new_ent) 66 | all_ents = head_ents + tail_ents 67 | 68 | rel_ents = [] 69 | for rname, rels in data['rel'].items(): 70 | for rel in rels: 71 | rel_words = [ent for ent in entities if not (rel['char_span'][1] <= ent[3][0]+offset or ent[3][1]+offset <= rel['char_span'][0])] 72 | if len(rel_words)==0: 73 | continue 74 | rel_ent = [ 75 | [min([ent[0][0] for ent in rel_words]),max([ent[0][1] for ent in rel_words])], 76 | [' '.join([ent[1] for ent in rel_words])], [1.0], 77 | [min([ent[3][0] for ent in rel_words]),max([ent[3][1] for ent in rel_words])], 78 | ] 79 | flag = True 80 | if flag: 81 | rel_ents.append(rel_ent) 82 | ret = { 83 | "tokenized_text": tokens, 84 | "candidate_spans": [], 85 | "candidate_entities": [], 86 | "candidate_entity_priors": [], 87 | "candidate_positions": [], 88 | 89 | "head_candidate_spans": [head_ent[0] for head_ent in head_ents], 90 | "head_candidate_entities": [head_ent[1] for head_ent in head_ents], 91 | "head_candidate_entity_priors": [head_ent[2] for head_ent in head_ents], 92 | "head_candidate_positions": [head_ent[3] for head_ent in head_ents], 93 | 94 | "tail_candidate_spans": [tail_ent[0] for tail_ent in tail_ents], 95 | "tail_candidate_entities": [tail_ent[1] for tail_ent in tail_ents], 96 | "tail_candidate_entity_priors": [tail_ent[2] for tail_ent in tail_ents], 97 | "tail_candidate_positions": [tail_ent[3] for tail_ent in tail_ents], 98 | 99 | "relation_candidate_spans": [rel_ent[0] for rel_ent in rel_ents], 100 | "relation_candidate_entities": [rel_ent[1] for rel_ent in rel_ents], 101 | "relation_candidate_entity_priors": [rel_ent[2] for rel_ent in 
rel_ents], 102 | "relation_candidate_positions": [rel_ent[3] for rel_ent in rel_ents], 103 | } 104 | 105 | return ret 106 | -------------------------------------------------------------------------------- /src/deepex/data/re_data.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | from zipfile import ZipFile 5 | from bisect import bisect, bisect_left 6 | from html.parser import HTMLParser 7 | from dataclasses import dataclass, field 8 | from filelock import FileLock 9 | from typing import List, Optional, Tuple, Dict, NewType, Any 10 | import xml.etree.ElementTree as ET 11 | import re 12 | from collections import namedtuple 13 | import json 14 | import math 15 | import itertools 16 | from tqdm import tqdm 17 | 18 | import spacy 19 | from spacy.lang.en import English 20 | import numpy as np 21 | import torch 22 | from torch.utils.data.dataset import Dataset 23 | from .text_handler import TextHandler, re_pronouns 24 | 25 | logger = logging.getLogger(__name__) 26 | logger.setLevel(logging.INFO) 27 | logger.addHandler(logging.StreamHandler()) 28 | 29 | Entity = namedtuple('Entity', 'name, span, score') 30 | 31 | 32 | @dataclass 33 | class InputExample: 34 | docid: str 35 | text: str 36 | offset: int 37 | 38 | 39 | @dataclass(frozen=True) 40 | class InputFeatures: 41 | docid: str 42 | offset: int 43 | input_ids: List[int] 44 | attention_mask: Optional[List[int]] = None 45 | token_type_ids: Optional[List[int]] = None 46 | special_tokens_mask: Optional[List[int]] = None 47 | entity_ids: List[Entity] = None 48 | head_entity_ids: List[Entity] = None 49 | tail_entity_ids: List[Entity] = None 50 | relation_entity_ids: List[Entity] = None 51 | text: str = "" 52 | 53 | 54 | class SequentialDataset(Dataset): 55 | def __init__(self, filepaths, 56 | tokenizer, 57 | mention_generator, 58 | max_seq_length, 59 | overwrite_cache: Optional[bool] = False): 60 | if len(filepaths) == 0: 61 | self.features = [] 62 | else: 63 | logger.addHandler(logging.FileHandler(os.path.join('/'.join(filepaths[0].split('/')[:-2]), 64 | 'run_kbp_{}_{}.log'.format(tokenizer.__class__.__name__, 65 | mention_generator.__class__.__name__)))) 66 | self.features = [] 67 | for filepath in filepaths: 68 | dataset = REDataset(tokenizer, 69 | mention_generator, 70 | max_seq_length, 71 | overwrite_cache) 72 | self.features.extend(dataset.features) 73 | 74 | def __len__(self): 75 | return len(self.features) 76 | 77 | def __getitem__(self, i) -> InputFeatures: 78 | return self.features[i] 79 | 80 | class REDataset: 81 | def __init__( 82 | self, 83 | filedir, 84 | index, 85 | tokenizer, 86 | mention_generator, 87 | max_seq_length, 88 | example_batch_size=2048, 89 | overwrite_cache: Optional[bool] = False, 90 | ): 91 | self.filedir = filedir 92 | self.index = index 93 | self.max_seq_length = max_seq_length 94 | self.overwrite_cache = overwrite_cache 95 | self.use_coref = False 96 | self.text_handler = TextHandler(index=self.index, use_coref=self.use_coref, DIR=filedir) 97 | self.processor = Processor(tokenizer, self.text_handler, mention_generator, example_batch_size) 98 | 99 | def generate_batched_datasets(self): 100 | for i, self.features in enumerate( 101 | tqdm(self.processor._convert_batch_examples_to_features( 102 | self.filedir, self.index, self.overwrite_cache, 103 | max_length=self.max_seq_length, use_coref=self.use_coref 104 | ), desc='process feature files...')): 105 | logger.debug('features size {}'.format(len(self.features))) 106 | yield 
DatasetWrapper(self.features) 107 | 108 | class DatasetWrapper(Dataset): 109 | def __init__( 110 | self, 111 | features, 112 | ): 113 | self.features = features 114 | 115 | def __len__(self): 116 | return len(self.features) 117 | 118 | def __getitem__(self, i) -> InputFeatures: 119 | return self.features[i] 120 | 121 | class Processor: 122 | def __init__(self, tokenizer, text_handler, mention_generator, example_batch_size=2048): 123 | self.tokenizer = tokenizer 124 | self.text_handler = text_handler 125 | self.mention_generator = mention_generator 126 | self.example_batch_size = example_batch_size 127 | self.examples = [] 128 | self.features = [] 129 | 130 | def overlap_span(self, span0, span1, tokenizer): 131 | return span1[1] > span0[0] and span1[0] < span0[1] 132 | 133 | def _create_batch_examples(self): 134 | last_dir_name = None 135 | file_cnt = 0 136 | for i, (text, offset, dir_name, filename) in enumerate(tqdm(self.text_handler, desc='create batch examples...')): 137 | logger.debug('text: {}'.format(text)) 138 | logger.debug('offset: {}'.format(offset)) 139 | logger.debug('dir_name: {}'.format(dir_name)) 140 | logger.debug('filename: {}'.format(filename)) 141 | if last_dir_name != dir_name: 142 | file_cnt += 1 143 | last_dir_name = dir_name 144 | self.examples.append(InputExample(docid=dir_name, text=text, offset=offset)) 145 | if (i+1) % self.example_batch_size == 0: 146 | logger.debug('processed number of sentences/samples {}'.format(i+1)) 147 | yield self.examples 148 | self.examples = [] 149 | logger.debug('cleaned example size {}'.format(len(self.examples))) 150 | if len(self.examples) != 0: 151 | yield self.examples 152 | self.examples = [] 153 | 154 | def _convert_to_coref(self, name, span): 155 | coref = self.text_handler.get_coref(span) 156 | if coref and self.text_handler.cur_text[coref[1][0]:coref[1][1]].strip(' ').lower() in re_pronouns: 157 | logger.debug('org name: {}'.format(name)) 158 | name = coref[0].strip('\n') 159 | logger.debug('coref name: {}'.format(name)) 160 | logger.debug('org span: {}'.format(str(span))) 161 | span = coref[1] 162 | logger.debug('coref span: {}'.format(str(span))) 163 | return name, span 164 | 165 | def _convert_batch_examples_to_features(self, filedir, index, overwrite_cache, use_coref=False, 166 | max_length: Optional[int] = None): 167 | for i, self.examples in enumerate(tqdm(self._create_batch_examples(), desc='convert batch examples to features...')): 168 | logger.debug('example size {}'.format(len(self.examples))) 169 | cached_features_file = os.path.join( 170 | filedir, 171 | "cached_{}_{}_{}_{}_{}_{}_{}".format( 172 | index, self.tokenizer.__class__.__name__, self.mention_generator.__class__.__name__, max_length, i, 173 | use_coref, self.example_batch_size 174 | ), 175 | ) 176 | cached_mentions_file = os.path.join( 177 | filedir, 178 | "cachedmentions_{}_{}_{}_{}_{}_{}_{}".format( 179 | index, self.tokenizer.__class__.__name__, self.mention_generator.__class__.__name__, max_length, i, 180 | use_coref, self.example_batch_size 181 | ), 182 | ) 183 | lock_path = cached_features_file + ".lock" 184 | with FileLock(lock_path): 185 | 186 | if os.path.exists(cached_features_file) and not overwrite_cache: 187 | start = time.time() 188 | try: 189 | if os.path.getsize(cached_features_file) == 0: 190 | self.features = [] 191 | logger.debug( 192 | f"Skipping features from cached file {cached_features_file} [took %.3f s]", time.time() - start 193 | ) 194 | else: 195 | self.features = torch.load(cached_features_file) 196 | logger.debug( 197 | 
f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start 198 | ) 199 | except: 200 | self.features = [] 201 | else: 202 | logger.debug(f"Creating features from dataset file at {index} {i}") 203 | if max_length is None: 204 | max_length = self.tokenizer.max_len 205 | batch_encoding = self.tokenizer.batch_encode_plus( 206 | [example.text for example in self.examples], 207 | max_length=max_length, 208 | padding="max_length", 209 | truncation=True, 210 | return_special_tokens_mask=True, 211 | return_offsets_mapping=True 212 | ) 213 | all_mentions = {} 214 | for i in range(len(self.examples)): 215 | inputs = {k: batch_encoding[k][i] for k in batch_encoding} 216 | mentions = self.mention_generator.get_mentions_raw_text(self.examples[i].text,extra=(self.examples[i].docid,self.examples[i].offset)) 217 | all_mentions[(self.examples[i].docid,self.examples[i].offset)] = mentions 218 | logger.debug(('candidate entities: {}'.format(str(mentions['candidate_entities'])))) 219 | 220 | entity_ids = [] 221 | for j, encoding_span in enumerate(batch_encoding['offset_mapping'][i]): 222 | if encoding_span[0] == 0 and encoding_span[1] == 0: 223 | entity_ids.append(Entity(name='$NIL$', span=[-1, -1], score=1.0)) 224 | continue 225 | has_entity = False 226 | logger.debug('encoding_span: {} name: {}'.format(encoding_span, 227 | self.tokenizer.convert_ids_to_tokens( 228 | batch_encoding['input_ids'][i][j]))) 229 | for m, (name, raw_span) in enumerate( 230 | zip(mentions['candidate_entities'], mentions['candidate_positions'])): 231 | if raw_span[0] == -1 and raw_span[1] == -1: 232 | continue 233 | logger.debug('raw_span: {} name: {}'.format(raw_span, name)) 234 | if self.overlap_span(encoding_span, raw_span, self.tokenizer): 235 | char_span = [raw_span[0] + self.examples[i].offset, 236 | raw_span[1] + self.examples[i].offset] 237 | char_name = name[0] 238 | if use_coref: 239 | char_name, char_span = self._convert_to_coref(char_name, char_span) 240 | entity_ids.append(Entity(name=char_name, span=char_span, 241 | score=1.0)) 242 | has_entity = True 243 | break 244 | if not has_entity: 245 | entity_ids.append(Entity(name='$NIL$', span=[-1, -1], score=1.0)) 246 | 247 | head_entity_ids = [] 248 | for j, encoding_span in enumerate(batch_encoding['offset_mapping'][i]): 249 | if encoding_span[0] == 0 and encoding_span[1] == 0: 250 | head_entity_ids.append(Entity(name='$NIL$', span=[-1, -1], score=1.0)) 251 | continue 252 | has_entity = False 253 | logger.debug('encoding_span: {} name: {}'.format(encoding_span, 254 | self.tokenizer.convert_ids_to_tokens( 255 | batch_encoding['input_ids'][i][j]))) 256 | for m, (name, raw_span) in enumerate( 257 | zip(mentions['head_candidate_entities'], mentions['head_candidate_positions'])): 258 | if raw_span[0] == -1 and raw_span[1] == -1: 259 | continue 260 | logger.debug('raw_span: {} name: {}'.format(raw_span, name)) 261 | if self.overlap_span(encoding_span, raw_span, self.tokenizer): 262 | char_span = [raw_span[0] + self.examples[i].offset, 263 | raw_span[1] + self.examples[i].offset] 264 | char_name = name[0] 265 | if use_coref: 266 | char_name, char_span = self._convert_to_coref(char_name, char_span) 267 | head_entity_ids.append(Entity(name=char_name, span=char_span, 268 | score=1.0)) 269 | has_entity = True 270 | break 271 | if not has_entity: 272 | head_entity_ids.append(Entity(name='$NIL$', span=[-1, -1], score=1.0)) 273 | 274 | tail_entity_ids = [] 275 | for j, encoding_span in enumerate(batch_encoding['offset_mapping'][i]): 276 | if 
encoding_span[0] == 0 and encoding_span[1] == 0: 277 | tail_entity_ids.append(Entity(name='$NIL$', span=[-1, -1], score=1.0)) 278 | continue 279 | has_entity = False 280 | logger.debug('encoding_span: {} name: {}'.format(encoding_span, 281 | self.tokenizer.convert_ids_to_tokens( 282 | batch_encoding['input_ids'][i][j]))) 283 | for m, (name, raw_span) in enumerate( 284 | zip(mentions['tail_candidate_entities'], mentions['tail_candidate_positions'])): 285 | if raw_span[0] == -1 and raw_span[1] == -1: 286 | continue 287 | logger.debug('raw_span: {} name: {}'.format(raw_span, name)) 288 | if self.overlap_span(encoding_span, raw_span, self.tokenizer): 289 | char_span = [raw_span[0] + self.examples[i].offset, 290 | raw_span[1] + self.examples[i].offset] 291 | char_name = name[0] 292 | if use_coref: 293 | char_name, char_span = self._convert_to_coref(char_name, char_span) 294 | tail_entity_ids.append(Entity(name=char_name, span=char_span, 295 | score=1.0)) 296 | has_entity = True 297 | break 298 | if not has_entity: 299 | tail_entity_ids.append(Entity(name='$NIL$', span=[-1, -1], score=1.0)) 300 | 301 | relation_entity_ids = [] 302 | for j, encoding_span in enumerate(batch_encoding['offset_mapping'][i]): 303 | if encoding_span[0] == 0 and encoding_span[1] == 0: 304 | relation_entity_ids.append(Entity(name='$NIL$', span=[-1, -1], score=1.0)) 305 | continue 306 | has_entity = False 307 | logger.debug('encoding_span: {} name: {}'.format(encoding_span, 308 | self.tokenizer.convert_ids_to_tokens( 309 | batch_encoding['input_ids'][i][j]))) 310 | for m, (name, raw_span) in enumerate( 311 | zip(mentions['relation_candidate_entities'], mentions['relation_candidate_positions'])): 312 | if raw_span[0] == -1 and raw_span[1] == -1: 313 | continue 314 | logger.debug('raw_span: {} name: {}'.format(raw_span, name)) 315 | if self.overlap_span(encoding_span, raw_span, self.tokenizer): 316 | char_span = [raw_span[0] + self.examples[i].offset, 317 | raw_span[1] + self.examples[i].offset] 318 | char_name = name[0] 319 | if use_coref: 320 | char_name, char_span = self._convert_to_coref(char_name, char_span) 321 | relation_entity_ids.append(Entity(name=char_name, span=char_span, 322 | score=1.0)) 323 | has_entity = True 324 | break 325 | if not has_entity: 326 | relation_entity_ids.append(Entity(name='$NIL$', span=[-1, -1], score=1.0)) 327 | 328 | inputs['docid'] = self.examples[i].docid 329 | inputs['entity_ids'] = entity_ids 330 | inputs['head_entity_ids'] = head_entity_ids 331 | inputs['tail_entity_ids'] = tail_entity_ids 332 | inputs['relation_entity_ids'] = relation_entity_ids 333 | inputs['offset'] = self.examples[i].offset 334 | inputs['text'] = self.examples[i].text 335 | inputs.pop('offset_mapping') 336 | 337 | feature = InputFeatures(**inputs) 338 | self.features.append(feature) 339 | start = time.time() 340 | if len(self.features) == 0: 341 | logger.debug( 342 | f"Empty features to cached file {cached_features_file} [took %.3f s]", time.time() - start 343 | ) 344 | torch.save(self.features, cached_features_file) 345 | torch.save(all_mentions, cached_mentions_file) 346 | logger.debug( 347 | "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start 348 | ) 349 | yield self.features 350 | self.features = [] 351 | logger.debug('cleaned features size {}'.format(len(self.features))) -------------------------------------------------------------------------------- /src/deepex/data/text_handler.py: -------------------------------------------------------------------------------- 1 
| import xml.etree.ElementTree as ET
2 | from datetime import datetime
3 | import re
4 | import spacy
5 | import os
6 | from os.path import abspath, dirname, join, exists
7 | from collections import defaultdict
8 | import json
9 | import codecs
10 | import csv
11 | from tqdm import tqdm
12 | from spacy.lang.en import English
13 | 
14 | re_pronouns = {'i', 'we', 'you', 'he', 'she', 'it', 'they',
15 |                'me', 'us', 'you', 'him', 'her', 'them',
16 |                'my', 'our', 'your', 'his', 'their', 'its',
17 |                'mine', 'ours', 'yours', 'hers', 'theirs',
18 |                'myself', 'ourselves', 'yourself', 'herself', 'himself', 'themselves', 'itself'}
19 | 
20 | class TextHandler(object):
21 |     def __init__(self, index, use_coref=False, DIR=""):
22 |         self.index = index
23 |         self.use_coref = use_coref
24 |         self.input = codecs.open(join(DIR, 'P{}.jsonl'.format(self.index)), 'r', 'utf-8')
25 |         self.nlp = English()
26 |         self.nlp.add_pipe('sentencizer')
27 |         if use_coref:
28 |             import neuralcoref  # lazy import: only needed (and installed) when coreference is enabled
29 |             neuralcoref.add_to_pipe(self.nlp)
30 |         self.cur_doc = None
31 |         self.cur_coref = None
32 |         self.cur_text = None
33 | 
34 |     def gen_coref(self):
35 |         self.cur_coref = defaultdict(dict)
36 |         for cluster in self.cur_doc._.coref_clusters:
37 |             main_entity, main_span = cluster.main.text, [cluster.main.start_char, cluster.main.end_char]
38 |             for mention in cluster.mentions:
39 |                 self.cur_coref[mention.start_char][mention.end_char] = [main_entity, main_span]
40 | 
41 |     def get_coref(self, span):
42 |         if self.cur_coref.get(span[0]):
43 |             return self.cur_coref.get(span[0]).get(span[1])
44 | 
45 |     def __iter__(self):
46 |         num_of_dir = 0
47 |         for i, line in enumerate(self.input):
48 |             doc = json.loads(line)
49 | 
50 |             full_text, title, _id = doc['text'], doc['title'], doc['id']
51 |             full_text = re.sub(r'\(\(.*?\)\)', lambda m: ' ' * len(m.group()), full_text)
52 |             full_text = re.sub(r'\(.*?\)', lambda m: ' ' * len(m.group()), full_text)
53 |             self.cur_text = full_text
54 |             text = self.nlp(full_text)
55 |             self.cur_doc = text
56 |             if self.use_coref:
57 |                 self.gen_coref()
58 |             num_of_dir += 1
59 |             for sentence in text.sents:
60 |                 yield sentence.text, full_text.find(sentence.text), (None if _id is None else ('0' * (40 - len(_id)) + _id)), title
61 | 
-------------------------------------------------------------------------------- /src/deepex/model/__init__.py: --------------------------------------------------------------------------------
1 | from .kgm import *
2 | from .distillation import *
3 | from .eval import *
-------------------------------------------------------------------------------- /src/deepex/model/distillation.py: --------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import itertools
4 | import copy
5 | import warnings
6 | from tqdm import tqdm
7 | import re
8 | from requests import get
9 | 
10 | class Distillation:
11 |     def __init__(self, input_dir, filepath):
12 |         self.input_dir = input_dir
13 |         self.filepath = filepath
14 |         self.search_res = []
15 |         self.corpus_dedup_triplets = {}
16 | 
17 |     def merge_search_res(self, search_res, global_search_res):
18 |         for k, v in search_res.items():
19 |             if k not in global_search_res:
20 |                 global_search_res[k] = v
21 |             else:
22 |                 global_search_res[k].extend(v)
23 | 
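# Each generation batch writes its own 'search_res.json' shard under a subdirectory
# of input_dir, mapping docid -> list of per-sentence search results; load_search_res
# below walks those subdirectories and folds every shard into one dict via
# merge_search_res above. A minimal usage sketch (paths are hypothetical):
#     d = Distillation('output/batches', 'output/corpus_triplets.json')
#     d.deduplicate(topk=100, dedup_ranking_type='score')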
24 |     def load_search_res(self):
25 |         search_res = {}
26 |         for f in os.listdir(self.input_dir):
27 |             if os.path.isdir(os.path.join(self.input_dir, f)):
28 |                 print(os.path.join(os.path.join(self.input_dir, f), 'search_res.json'))
29 |                 self.merge_search_res(
30 |                     json.load(open(os.path.join(os.path.join(self.input_dir, f), 'search_res.json'), 'r')), search_res)
31 |         return search_res
32 | 
33 |     def rank_entity_seqs_with_score_freq(self, x, dedup_ranking_type):
34 |         if dedup_ranking_type == 'freq':
35 |             return {k: v for k, v in sorted(x.items(),
36 |                                             key=lambda item: item[1][0], reverse=True)}
37 |         elif dedup_ranking_type == 'score':
38 |             return {k: v for k, v in sorted(x.items(),
39 |                                             key=lambda item: item[1][1], reverse=True)}
40 |         elif dedup_ranking_type == 'score_freq':
41 |             return {k: v for k, v in sorted(x.items(),
42 |                                             key=lambda item: item[1][1] / item[1][0], reverse=True)}
43 |         elif dedup_ranking_type == 'score_freq_len':
44 |             warnings.warn(
45 |                 'use score_len instead! score_freq_len is not recommended since it incorporates spans that were extended as a result of the continuous relation span constraint')
46 |             return {k: v for k, v in sorted(x.items(),
47 |                                             key=lambda item: item[1][1] / (
48 |                                                 item[1][0] * len(item[0].strip().split(' '))), reverse=True)}
49 |         elif dedup_ranking_type == 'score_len':
50 |             return {k: v for k, v in sorted(x.items(),
51 |                                             key=lambda item: item[1][1] / item[1][3], reverse=True)}
52 |         else:
53 |             raise ValueError('support (freq, score, score_freq, score_freq_len, score_len)')
54 | 
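# rank_entity_seqs_with_score_freq above and rank_entity_seqs_with_attached_score
# below share one value layout per triplet key 'h [SEP] r [SEP] t': v[0] is the
# frequency, v[1] the accumulated attention score, and v[3] (present only in the
# per-sentence results coming from kgm.py) appears to be a length term used by
# 'score_len'. The attached-score variant additionally returns [v, ranking_score]
# so downstream code can re-sort without recomputing the ranking key.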
55 |     def rank_entity_seqs_with_attached_score(self, x, dedup_ranking_type):
56 |         if dedup_ranking_type == 'freq':
57 |             return {k: [v, v[0]] for k, v in sorted(x.items(),
58 |                                                     key=lambda item: item[1][0], reverse=True)}
59 |         elif dedup_ranking_type == 'score':
60 |             return {k: [v, v[1]] for k, v in sorted(x.items(),
61 |                                                     key=lambda item: item[1][1], reverse=True)}
62 |         elif dedup_ranking_type == 'score_freq':
63 |             return {k: [v, v[1] / v[0]] for k, v in sorted(x.items(),
64 |                                                            key=lambda item: item[1][1] / item[1][0], reverse=True)}
65 |         elif dedup_ranking_type == 'score_freq_len':
66 |             warnings.warn(
67 |                 'use score_len instead! score_freq_len is not recommended since it incorporates spans that were extended as a result of the continuous relation span constraint')
68 |             return {k: [v, v[1] / (v[0] * len(k.strip().split(' ')))] for k, v in sorted(x.items(),
69 |                                                                                          key=lambda item: item[1][1] / (
70 |                                                                                              item[1][0] * len(
71 |                                                                                                  item[0].strip().split(
72 |                                                                                                      ' '))),
73 |                                                                                          reverse=True)}
74 |         elif dedup_ranking_type == 'score_len':
75 |             return {k: [v, v[1] / v[3]] for k, v in sorted(x.items(),
76 |                                                            key=lambda item: item[1][1] / item[1][3], reverse=True)}
77 |         else:
78 |             raise ValueError('support (freq, score, score_freq, score_freq_len, score_len)')
79 | 
80 |     def deduplicate(self, topk=100, dedup_ranking_type='freq'):
81 |         self.search_res = self.load_search_res()
82 |         for res in itertools.chain(*self.search_res.values()):  # search_res maps docid -> per-sentence results, so iterate the results rather than the keys
83 |             if topk is None:
84 |                 sent_dedup_triplets = res[1]['deduplicated:']
85 |             else:
86 |                 sent_dedup_triplets = {k: v for k, v in itertools.islice(res[1]['deduplicated:'].items(), topk)}
87 |             for k, v in sent_dedup_triplets.items():
88 |                 triplet = k.strip()
89 |                 freq = v[0]
90 |                 score = v[1]
91 |                 if triplet not in self.corpus_dedup_triplets:
92 |                     self.corpus_dedup_triplets[triplet] = [freq, score]
93 |                 else:
94 |                     self.corpus_dedup_triplets[triplet][0] += freq
95 |                     self.corpus_dedup_triplets[triplet][1] += score
96 |         self.corpus_dedup_triplets = self.rank_entity_seqs_with_score_freq(self.corpus_dedup_triplets,
97 |                                                                            dedup_ranking_type)
98 |         json.dump(self.corpus_dedup_triplets, open(self.filepath, 'w'))
99 | 
100 |     def remove_non_ascii(self, text):
101 |         return re.sub(r'[^\x00-\x7F]+', ' ', text).strip()
102 | 
103 |     def convert_to_eval_format(self, k_triplet, v_triplet, return_reverse=True, remove_relation_non_ascii=True):
104 |         h_r_t = k_triplet.split('[SEP]')
105 |         h = h_r_t[0].strip()
106 |         h_span = v_triplet[2][0]
107 |         r = h_r_t[1].strip()
108 |         t = h_r_t[2].strip()
109 |         t_span = v_triplet[2][1]
110 |         if remove_relation_non_ascii:
111 |             r = self.remove_non_ascii(r)
112 |         if len(r) == 0:
113 |             return None
114 |         if return_reverse:
115 |             return {"subject": h, "subject_char_span": h_span, "relation": r, "object": t, "object_char_span": t_span}, \
116 |                    {"subject": t, "subject_char_span": t_span, "relation": r, "object": h, "object_char_span": h_span}
117 |         return {"subject": h, "subject_char_span": h_span, "relation": r, "object": t, "object_char_span": t_span}
118 | 
119 |     def deduplicate_for_eval_fast(self, filepath, topk=None, dedup_ranking_type='freq', sent_dedup_type='entity_pair',
120 |                                   doc_dedup_type='whole', return_reverse=True):
121 | 
122 |         def existstriplet(cand_triplet, existset, dedup_type):
123 |             triplet = copy.deepcopy(cand_triplet)
124 |             triplet.pop('score')
125 |             triplet.pop('sentence')
126 |             triplet.pop('offset')
127 |             if dedup_type == 'entity_pair':
128 |                 triplet.pop('relation')
129 |             elif dedup_type == 'whole':
130 |                 pass
131 |             else:
132 |                 raise ValueError('support entity_pair or whole')
133 |             if triplet not in existset:
134 |                 existset.append(triplet)
135 |                 return False
136 |             return True
137 | 
138 |         dedup_triplets = {}
139 |         dedup_triplets_with_sent = {}
140 |         for f in tqdm(os.listdir(self.input_dir), desc='deduplicating batch'):
141 |             if os.path.isdir(os.path.join(self.input_dir, f)):
142 |                 print(os.path.join(os.path.join(self.input_dir, f), 'search_res.json'))
143 |                 search_res = json.load(open(os.path.join(os.path.join(self.input_dir, f), 'search_res.json'), 'r'))
144 |                 for k, v in tqdm(search_res.items(), desc='deduplicating doc'):
145 |                     sent_dedup_triplets = []
146 |                     sent_dedup_triplets_with_sent = []
147 |                     for res in v:
148 |                         cands = []
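# Each res is one per-sentence search result: res[0] is the '$input_txt:$'-prefixed
# subword string and res[1]['deduplicated:'] maps 'h [SEP] r [SEP] t' keys to their
# stats (see assign_search_result in kgm.py). The islice below optionally keeps only
# the topk candidates per sentence before re-ranking and conversion to eval format.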
149 |                         if topk is None:
150 |                             raw_per_sent_dedup_triplets = res[1]['deduplicated:']
151 |                         else:
152 |                             raw_per_sent_dedup_triplets = {k: v for k, v in
153 |                                                            itertools.islice(res[1]['deduplicated:'].items(), topk)}
154 |                         raw_per_sent_dedup_triplets = self.rank_entity_seqs_with_attached_score(
155 |                             raw_per_sent_dedup_triplets, dedup_ranking_type)
156 |                         for k_triplet, v_triplet in raw_per_sent_dedup_triplets.items():
157 |                             eval_format = self.convert_to_eval_format(k_triplet, v_triplet[0], return_reverse)
158 |                             if eval_format is None:
159 |                                 continue
160 |                             if return_reverse:
161 |                                 sent_dedup_triplets.append(eval_format[0])
162 |                                 sent_dedup_triplets.append(eval_format[1])
163 |                                 eval_format_sent = copy.deepcopy(eval_format)
164 |                                 eval_format_sent[0]['sentence'] = res[0]
165 |                                 eval_format_sent[0]['score'] = v_triplet[1]
166 |                                 eval_format_sent[0]['offset'] = v_triplet[0][4]
167 |                                 eval_format_sent[1]['sentence'] = res[0]
168 |                                 eval_format_sent[1]['score'] = v_triplet[1]
169 |                                 eval_format_sent[1]['offset'] = v_triplet[0][4]
170 |                                 sent_dedup_triplets_with_sent.append(eval_format_sent[0])
171 |                                 sent_dedup_triplets_with_sent.append(eval_format_sent[1])
172 |                             else:
173 |                                 sent_dedup_triplets.append(eval_format)
174 |                                 eval_format_sent = copy.deepcopy(eval_format)
175 |                                 eval_format_sent['sentence'] = res[0]
176 |                                 eval_format_sent['score'] = v_triplet[1]
177 |                                 eval_format_sent['offset'] = v_triplet[0][4]
178 |                                 sent_dedup_triplets_with_sent.append(eval_format_sent)
179 |                     if k not in dedup_triplets:
180 |                         dedup_triplets[k] = sent_dedup_triplets
181 |                     else:
182 |                         dedup_triplets[k].extend(sent_dedup_triplets)
183 |                     if k not in dedup_triplets_with_sent:
184 |                         dedup_triplets_with_sent[k] = sent_dedup_triplets_with_sent
185 |                     else:
186 |                         dedup_triplets_with_sent[k].extend(sent_dedup_triplets_with_sent)
187 | 
188 |         for k, v in tqdm(dedup_triplets_with_sent.items(), desc='sorting'):
189 |             dedup_triplets_with_sent[k] = [e for e in sorted(v, key=lambda item: item['score'], reverse=True)]
190 |         for docid, cand_triplets in tqdm(dedup_triplets_with_sent.items(), desc='merging doc'):
191 |             sent_dedup_triplets_with_sent = []
192 |             existset = []
193 |             for cand_triplet in cand_triplets:
194 |                 if not existstriplet(cand_triplet, existset, doc_dedup_type):  # doc-level dedup: since the list is score-sorted, the first instance kept is the highest-scored one
195 |                     sent_dedup_triplets_with_sent.append(cand_triplet)
196 |             dedup_triplets_with_sent[docid] = sent_dedup_triplets_with_sent
197 |         json.dump(dedup_triplets_with_sent, open(filepath, 'w'))
198 | 
-------------------------------------------------------------------------------- /src/deepex/model/eval.py: --------------------------------------------------------------------------------
1 | import json
2 | 
3 | class Eval:
4 |     def __init__(self):
5 |         self.num_triplets = 0
6 | 
7 |     def eval_number_of_triplets(self, filepath):
8 |         if filepath.endswith('.json'):
9 |             res = json.load(open(filepath, 'r'))
10 |             self.num_triplets = len(res)
11 |         else:
12 |             raise ValueError('the result format should be json')
13 | 
14 |     def eval_number_of_triplets_with_docid(self, filepath):
15 |         self.num_triplets = 0
16 |         res = json.load(open(filepath, 'r'))
17 |         for k, v in res.items():
18 |             self.num_triplets += len(v)
19 | 
-------------------------------------------------------------------------------- /src/deepex/model/kgm.py: --------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | from tqdm.auto import tqdm, trange
4 | from typing import Callable, Dict, List, Optional, Tuple
5 | import copy
6 | import time
7 | import numpy as np
8 | from itertools import islice
9 | import warnings
10 | from
torch.multiprocessing import Pool, set_start_method 11 | from functools import partial 12 | 13 | import torch 14 | from torch.utils.data.dataloader import DataLoader 15 | 16 | from transformers import GPT2TokenizerFast, BertTokenizerFast, GPT2Tokenizer 17 | from transformers.training_args import is_torch_tpu_available 18 | from transformers.trainer_utils import EvalPrediction, PredictionOutput 19 | 20 | if is_torch_tpu_available(): 21 | import torch_xla.core.xla_model as xm 22 | import torch_xla.debug.metrics as met 23 | import torch_xla.distributed.parallel_loader as pl 24 | 25 | from ..utils import * 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | def layer_attention(attention, layer_id): 31 | if layer_id == -100: 32 | all_attention = torch.stack(attention, dim=0) 33 | return all_attention.mean(dim=0) 34 | return attention[layer_id][:, :, :, :] 35 | 36 | def transform_layer_attention(attention, type): 37 | if type == 'mean': 38 | return attention.mean(1) 39 | elif type == 'max': 40 | return attention.max(1).values 41 | elif type == 'sum': 42 | return attention.sum(1) 43 | else: 44 | raise ValueError('support mean max sum') 45 | 46 | 47 | def convert_tokens_to_string(tokens, tokenizer): 48 | if isinstance(tokenizer, BertTokenizerFast): 49 | out_string = " ".join(tokens).replace(" ##", "").strip() 50 | return out_string 51 | elif isinstance(tokenizer, GPT2TokenizerFast) or isinstance(tokenizer, GPT2Tokenizer): 52 | text = "".join(tokens) 53 | text = bytearray([tokenizer.byte_decoder[c] for c in text]).decode("utf-8", errors=tokenizer.errors) 54 | return text 55 | else: 56 | raise ValueError('only gpt2 and bert tokenizer') 57 | 58 | def find_seq_offsets(seq_id, batch, inputs, tokenizer, begin_seq_id, end_seq_id): 59 | if 'input_ids' in inputs: 60 | inputs = inputs['input_ids'] 61 | subword = convert_seq_id_to_subword(seq_id, batch, inputs, tokenizer) 62 | pre_offset = 0 63 | if subword.startswith('##'): 64 | pre_offset = 1 65 | for pre_id in range(seq_id - 1, begin_seq_id - 1, -1): 66 | presubword = convert_seq_id_to_subword(pre_id, batch, inputs, tokenizer) 67 | if not presubword.startswith('##'): 68 | break 69 | pre_offset += 1 70 | next_offset = 0 71 | for next_id in range(seq_id + 1, end_seq_id + 1, 1): 72 | nextsubword = convert_seq_id_to_subword(next_id, batch, inputs, tokenizer) 73 | if not nextsubword.startswith('##'): 74 | break 75 | next_offset += 1 76 | return pre_offset, next_offset 77 | 78 | 79 | def is_same_span(span0, span1): 80 | return span0[0] == span1[0] and span0[1] == span1[1] 81 | 82 | 83 | def find_rids(seq, batch, inputs): 84 | first_rid = seq[1] 85 | last_rid = seq[-2] 86 | entity_ids = [inputs['entity_ids'][batch][ind] for ind in seq] 87 | h_span = entity_ids[0].span 88 | t_span = entity_ids[-1].span 89 | for i in range(1, len(entity_ids) - 2, 1): 90 | if is_same_span(h_span, entity_ids[i].span): 91 | first_rid = seq[i + 1] 92 | else: 93 | break 94 | for i in range(len(entity_ids) - 2, 1, -1): 95 | if is_same_span(t_span, entity_ids[i].span): 96 | last_rid = seq[i - 1] 97 | else: 98 | break 99 | if first_rid > last_rid: 100 | return None, None 101 | return first_rid, last_rid 102 | 103 | 104 | def convert_to_triplet_relation(seq, batch, inputs, tokenizer): 105 | hid = seq[0] 106 | tid = seq[-1] 107 | first_rid, last_rid = find_rids(seq, batch, inputs) 108 | if first_rid is None or last_rid is None: 109 | return None 110 | first_rid_pre_offset, first_rid_next_offset = find_seq_offsets(first_rid, batch, inputs, tokenizer, hid, tid) 111 | 
last_rid_pre_offset, last_rid_next_offset = find_seq_offsets(last_rid, batch, inputs, tokenizer, hid, tid) 112 | first_pruned_rid = first_rid 113 | last_pruned_rid = last_rid 114 | if first_rid - first_rid_pre_offset <= hid: 115 | first_pruned_rid = first_rid + first_rid_next_offset + 1 116 | if last_rid + last_rid_next_offset >= tid: 117 | last_pruned_rid = last_rid - last_rid_pre_offset - 1 118 | if first_pruned_rid > last_pruned_rid: 119 | return None 120 | return convert_tokens_to_string( 121 | tokenizer.convert_ids_to_tokens(inputs['input_ids'][batch][first_pruned_rid:last_pruned_rid+1]), tokenizer) 122 | 123 | 124 | def convert_tokens_to_triplet(seq, batch, inputs, tokenizer, beam_mode='IE'): 125 | if len(seq) < 3: 126 | return None, None 127 | if beam_mode=='RC': 128 | entity_ids = [inputs['head_entity_ids'][batch][seq[0]]] + [inputs['relation_entity_ids'][batch][ind] for ind in seq[1:-1]] + [inputs['tail_entity_ids'][batch][seq[-1]]] 129 | else: 130 | entity_ids = [inputs['entity_ids'][batch][ind] for ind in seq] 131 | h = entity_ids[0].name.title() 132 | t = entity_ids[-1].name.title() 133 | h_span = entity_ids[0].span 134 | t_span = entity_ids[-1].span 135 | if is_same_span(h_span, t_span): 136 | return None, None 137 | h_t_spans = [h_span, t_span] 138 | if beam_mode=='RC': 139 | r = entity_ids[1].name 140 | else: 141 | r = convert_to_triplet_relation(seq, batch, inputs, tokenizer) 142 | if r is None: 143 | return None, None 144 | return h + ' [SEP] ' + r + ' [SEP] ' + t, h_t_spans 145 | 146 | 147 | def search_candidate_gen(seq, batch, inputs, tokenizer, model_args): 148 | cand_type = model_args.search_cand_type 149 | if cand_type == 'word': 150 | return convert_tokens_to_string( 151 | [convert_seq_id_to_subword(ind, batch, inputs, tokenizer) for ind in seq], tokenizer) 152 | elif cand_type == 'entity': 153 | if 'docid' in inputs: 154 | docid = inputs['docid'][batch] 155 | else: 156 | docid = -1 157 | if 'offset' in inputs: 158 | offset = inputs['offset'][batch] 159 | else: 160 | offset = -1 161 | triplet, head_tail_spans = convert_tokens_to_triplet(seq, batch, inputs, tokenizer, model_args.beam_mode) 162 | if triplet is None or head_tail_spans is None: 163 | return triplet, None, head_tail_spans, docid, offset 164 | return triplet.strip(), \ 165 | convert_tokens_to_string([convert_seq_id_to_subword(ind, batch, inputs, tokenizer) 166 | for ind in seq], tokenizer).strip(), head_tail_spans, docid, offset 167 | else: 168 | raise ValueError('candidate type can only be word or entity') 169 | 170 | 171 | def filter_cand_by_min_len(k, model_args): 172 | if len(k.strip().split(' ')) >= model_args.cand_min_len: 173 | return False 174 | return True 175 | 176 | def rank_entity_seqs_with_score_freq(x, model_args): 177 | 178 | if model_args.dedup_ranking_type == 'freq': 179 | return {k: v for k, v in sorted(x.items(), 180 | key=lambda item: item[1][0], reverse=True) 181 | if not filter_cand_by_min_len(k, model_args)} 182 | elif model_args.dedup_ranking_type == 'score': 183 | return {k: v for k, v in sorted(x.items(), 184 | key=lambda item: item[1][1], reverse=True) 185 | if not filter_cand_by_min_len(k, model_args)} 186 | elif model_args.dedup_ranking_type == 'score_freq': 187 | return {k: v for k, v in sorted(x.items(), 188 | key=lambda item: item[1][1] / item[1][0], reverse=True) 189 | if not filter_cand_by_min_len(k, model_args)} 190 | elif model_args.dedup_ranking_type == 'score_freq_len': 191 | warnings.warn( 192 | 'use score_len instead! 
score_freq_len is not recommended since it incorporates extended as a result of continous relation span constrain') 193 | return {k: v for k, v in sorted(x.items(), 194 | key=lambda item: item[1][1] / ( 195 | item[1][0] * len(item[0].strip().split(' '))), 196 | reverse=True) 197 | if not filter_cand_by_min_len(k, model_args)} 198 | elif model_args.dedup_ranking_type == 'score_len': 199 | return {k: v for k, v in sorted(x.items(), 200 | key=lambda item: item[1][1] / item[1][3], reverse=True) 201 | if not filter_cand_by_min_len(k, model_args)} 202 | else: 203 | raise ValueError('support (freq, score, score_freq, score_freq_len, score_len)') 204 | 205 | 206 | def assign_search_result(search_res, subword_input, docid, entity_seqs_with_score_freq, subword_seqs_with_scores, 207 | model_args): 208 | if docid is None: 209 | return 210 | subword_input = '$input_txt:$ ' + subword_input 211 | if docid not in search_res: 212 | search_res[docid] = [ 213 | [subword_input, 214 | {'deduplicated:': rank_entity_seqs_with_score_freq(entity_seqs_with_score_freq, model_args)}]] 215 | else: 216 | search_res[docid].append( 217 | [subword_input, 218 | {'deduplicated:': rank_entity_seqs_with_score_freq(entity_seqs_with_score_freq, model_args)}]) 219 | 220 | 221 | def search_results_gen(res_indices, model_args, inputs, tokenizer): 222 | tic = time.time() 223 | special_tokens = set(tokenizer.special_tokens_map.values()) 224 | search_res = {} 225 | pre_b = -1 226 | pre_docid = None 227 | subword_seqs_with_scores = [] 228 | entity_seqs_with_score_freq = {} 229 | docid = None 230 | for seq in res_indices: 231 | cur_b = seq[3] 232 | if model_args.beam_mode!="RC": 233 | seq[0] = seq[0] if seq[0][0]= min_len and len(seq[0]) <= max_len: 278 | if search_ranking_type == 'mean': 279 | seq[1] = seq[1] / len(seq[0]) 280 | if seq[1] > score_threshold: 281 | filter_sort_res.append(seq) 282 | filter_sort_res = sorted(sorted(filter_sort_res, key=lambda tup: tup[1], reverse=True), key=lambda tup: tup[3]) 283 | dict_filter_sort_res = {} 284 | for tup in filter_sort_res: 285 | if tup[3] not in dict_filter_sort_res: 286 | dict_filter_sort_res[tup[3]] = [] 287 | dict_filter_sort_res[tup[3]].append(tup) 288 | filter_sort_res = [] 289 | for k, v in dict_filter_sort_res.items(): 290 | if n is not None and n != 'None': 291 | filter_sort_res.extend(v[:n]) 292 | else: 293 | filter_sort_res.extend(v) 294 | return filter_sort_res 295 | 296 | 297 | def entity_sent_gen_per_sample(attention, b, inputs, tokenizer, model_args, prefix=""): 298 | eid_sids = [seq_id for seq_id in range(attention.size()[1]) 299 | if inputs['%sentity_ids'%prefix][b][seq_id].name != '$NIL$' 300 | and inputs['special_tokens_mask'][b][seq_id].item() == 0 301 | and convert_tokens_to_string( 302 | convert_seq_id_to_subword(seq_id, b, inputs, tokenizer), 303 | tokenizer).strip() not in '!=?'] 304 | if model_args.add_extra_entity: 305 | non_special_tokens_mask_indices = (inputs['special_tokens_mask'][b] == 0).nonzero(as_tuple=False) 306 | if len(non_special_tokens_mask_indices)>0: 307 | first_token_id = non_special_tokens_mask_indices[0].item() 308 | if first_token_id not in eid_sids: 309 | eid_sids.append(first_token_id) 310 | if len(non_special_tokens_mask_indices)>1: 311 | last_token_id = non_special_tokens_mask_indices[-1].item() - 1 312 | if last_token_id not in eid_sids: 313 | eid_sids.append(last_token_id) 314 | if len(eid_sids) < 1: 315 | return None, None 316 | eid_sids = sorted(eid_sids) 317 | if model_args.sentence: 318 | if '%sentity_ids'%prefix in inputs: 319 | 
split_indices = [seq_id for seq_id in range(attention.size()[1]) 320 | if convert_tokens_to_string( 321 | convert_seq_id_to_subword(seq_id, b, inputs, tokenizer), tokenizer).strip() in '!=?' and 322 | convert_tokens_to_string( 323 | convert_seq_id_to_subword(seq_id, b, inputs, tokenizer), tokenizer).strip() != ''] 324 | sent_eid_sids = [] 325 | for i in range(-1, len(split_indices)): 326 | sent_eid_sid = [] 327 | if model_args.add_extra_entity: 328 | if 0 <= i < len(split_indices)-1: 329 | sent_eid_sid.extend([split_indices[i] + 1, split_indices[i + 1] - 1]) 330 | for j in range(len(eid_sids)): 331 | if i == -1: 332 | if len(split_indices)==0 or eid_sids[j] < split_indices[0]: 333 | if eid_sids[j] not in sent_eid_sid: 334 | sent_eid_sid.append(eid_sids[j]) 335 | elif i == len(split_indices) - 1: 336 | if eid_sids[j] > split_indices[i]: 337 | if eid_sids[j] not in sent_eid_sid: 338 | sent_eid_sid.append(eid_sids[j]) 339 | else: 340 | if eid_sids[j] > split_indices[i] and eid_sids[j] < split_indices[i + 1]: 341 | if eid_sids[j] not in sent_eid_sid: 342 | sent_eid_sid.append(eid_sids[j]) 343 | sent_eid_sids.append(sorted(sent_eid_sid)) 344 | if len(sent_eid_sid) >= 1: 345 | eid_sids.append(sorted(sent_eid_sid)[-1]) 346 | else: 347 | raise ValueError('entity ids must be provided in input to use the generation algs') 348 | return sorted(eid_sids), sent_eid_sids 349 | else: 350 | return sorted(eid_sids), [eid_sids] 351 | 352 | def segment_location(a, u, v): 353 | return (a1 and ( 373 | (direction== 'left' and indices[ind].item() >= v) 374 | or (direction=='right' and indices[ind].item() <= v) 375 | or cross_segment_check(indices[ind].item()+offset,v+offset,node,bound+offset) 376 | )) 377 | or indices[ind].item()+offset in c[0] 378 | ): 379 | continue 380 | c_new = copy.deepcopy(c) 381 | c_new[0].append(indices[ind].item()+offset) 382 | c_new[1] += vals[ind].item() 383 | c_new[2] = False 384 | c_new[3] = b 385 | beam_new.append(c_new) 386 | tempk += 1 387 | else: 388 | c[2] = True 389 | beam_new.append(c) 390 | beam = sorted(beam_new, key=lambda tup: tup[1]/len(tup[0]), reverse=True)[:topk] 391 | return beam 392 | 393 | def fast_bidirectional_beam_search_alg(attention, n, topk, max_len, min_len, score_threshold, inputs, tokenizer, model_args): 394 | if model_args.beam_mode=="IE": 395 | res = [] 396 | for b in range(attention.size()[0]): 397 | eid_sids, sent_eid_sids = entity_sent_gen_per_sample(attention, b, inputs, tokenizer, model_args) 398 | if eid_sids is None or sent_eid_sids is None: 399 | continue 400 | offset = eid_sids[0] 401 | pruned_attention = attention[b][offset:eid_sids[-1] + 1, offset:eid_sids[-1] + 1] 402 | if 'gpt2' in model_args.model_name_or_path: 403 | pruned_attention_t = pruned_attention.transpose(0, 1).triu(diagonal=1) 404 | pruned_attention = pruned_attention + pruned_attention_t 405 | vals, indices = pruned_attention.sort(descending=True) 406 | for sent_eid_sid in sent_eid_sids: 407 | for i in range(len(sent_eid_sid)): 408 | u = sent_eid_sid[i] 409 | for j in range(i - 1, i - 1 - model_args.dist_const, -1): 410 | if j < 0: 411 | break 412 | v = sent_eid_sid[j] 413 | left_cur_res = fast_unidirectional_beam_search_helper(u, offset, vals, indices, topk, b, 'left', v) 414 | res.extend(left_cur_res) 415 | for j in range(i + 1, i + 1 + model_args.dist_const, 1): 416 | if j > len(sent_eid_sid) - 1: 417 | break 418 | v = sent_eid_sid[j] 419 | right_cur_res = fast_unidirectional_beam_search_helper(u, offset, vals, indices, topk, b, 'right', v) 420 | res.extend(right_cur_res) 421 | 
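# IE mode: pool every candidate produced by the left/right beam expansions above,
# then let filter_sort_result apply the length and score thresholds and keep the
# top-n sequences per batch element.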
return filter_sort_result(res, n, max_len, min_len, score_threshold, model_args.search_ranking_type)
422 |     elif model_args.beam_mode=="RC":
423 |         model_args.add_extra_entity = False; res = []
424 |         for b in range(attention.size()[0]):
425 |             head_eid_sids, head_sent_eid_sids = entity_sent_gen_per_sample(attention, b, inputs, tokenizer, model_args, prefix="head_")
426 |             tail_eid_sids, tail_sent_eid_sids = entity_sent_gen_per_sample(attention, b, inputs, tokenizer, model_args, prefix="tail_")
427 |             relation_eid_sids, relation_sent_eid_sids = entity_sent_gen_per_sample(attention, b, inputs, tokenizer, model_args, prefix="relation_")
428 |             if head_eid_sids is None or head_sent_eid_sids is None or tail_eid_sids is None or tail_sent_eid_sids is None or relation_eid_sids is None or relation_sent_eid_sids is None:
429 |                 continue
430 |             offset = min(head_eid_sids[0],tail_eid_sids[0],relation_eid_sids[0]); bound = max(head_eid_sids[-1],tail_eid_sids[-1],relation_eid_sids[-1])
431 |             pruned_attention = attention[b][offset:bound + 1, offset:bound + 1]
432 |             if 'gpt2' in model_args.model_name_or_path:
433 |                 pruned_attention_t = pruned_attention.transpose(0, 1).triu(diagonal=1)
434 |                 pruned_attention = pruned_attention + pruned_attention_t
435 |             for (head_sent_eid_sid,relation_sent_eid_sid,tail_sent_eid_sid) in zip(head_sent_eid_sids, relation_sent_eid_sids, tail_sent_eid_sids):
436 |                 heads = []
437 |                 for k,i in enumerate(head_sent_eid_sid):
438 |                     if inputs['head_entity_ids'][b][i].name=="$NIL$":  # skip placeholder (non-entity) positions
439 |                         continue
440 |                     new = True; head = []
441 |                     for p,j in enumerate(head_sent_eid_sid):
442 |                         if inputs['head_entity_ids'][b][i].span==inputs['head_entity_ids'][b][j].span:
443 |                             if p < k:
444 |                                 new = False; break
445 |                             else:
446 |                                 head.append(j)
447 |                     if new:
448 |                         heads.append(head)
449 | 
450 |                 tails = []
451 |                 for k,i in enumerate(tail_sent_eid_sid):
452 |                     if inputs['tail_entity_ids'][b][i].name=="$NIL$":
453 |                         continue
454 |                     new = True; tail = []
455 |                     for p,j in enumerate(tail_sent_eid_sid):
456 |                         if inputs['tail_entity_ids'][b][i].span==inputs['tail_entity_ids'][b][j].span:
457 |                             if p < k:
458 |                                 new = False; break
459 |                             else:
460 |                                 tail.append(j)
461 |                     if new:
462 |                         tails.append(tail)
463 | 
464 |                 relations = []
465 |                 for k,i in enumerate(relation_sent_eid_sid):
466 |                     if inputs['relation_entity_ids'][b][i].name=="$NIL$":
467 |                         continue
468 |                     new = True; relation = []
469 |                     for p,j in enumerate(relation_sent_eid_sid):
470 |                         if inputs['relation_entity_ids'][b][i].span==inputs['relation_entity_ids'][b][j].span:
471 |                             if p < k:
472 |                                 new = False; break
473 |                             else:
474 |                                 relation.append(j)
475 |                     if new:
476 |                         relations.append(relation)
477 | 
478 |                 def sim_beam0(head, relation, tail):
479 |                     beam_score = -1; beam = []
480 |                     for r in range(1,len(relation)+1):
481 |                         for l in range(r):
482 |                             part_rel = relation[l:r]
483 |                             rel_score = sum([pruned_attention[i-offset][j-offset] for i,j in zip(part_rel,part_rel[1:])])
484 |                             for h in head:
485 |                                 for t in tail:
486 |                                     score = float(pruned_attention[h-offset][relation[l]-offset] + rel_score + pruned_attention[relation[r-1]-offset][t-offset])
487 |                                     if score > beam_score:
488 |                                         beam_score = score; beam = [[h] + part_rel + [t], score, True, b]
489 |                     return beam
490 | 
491 |                 for head in heads:
492 |                     for tail in tails:
493 |                         cur_res = []
494 |                         for relation in relations:
495 |                             beam = sim_beam0(head, relation, tail)
496 |                             if beam!=[]:
497 |                                 cur_res.append(beam)
498 |                             beam = sim_beam0(tail, relation, head)
499 |                             if beam!=[]:  # guard before touching beam[0]; the reversed direction may return no candidate
500 |                                 beam[0][0], beam[0][-1] = beam[0][-1], beam[0][0]
501 |
cur_res.append(beam) 502 | res.extend(sorted(cur_res,key=lambda x:-x[1]/len(x[0]))[:topk*2]) 503 | return filter_sort_result(res, n, max_len, min_len, score_threshold, model_args.search_ranking_type) 504 | else: 505 | raise NotImplementedError 506 | 507 | def fast_unsupervised_bidirectional_beam_search(attention, model_args, inputs, tokenizer): 508 | tic = time.time() 509 | res_indices = fast_bidirectional_beam_search_alg(attention, model_args.search_n, 510 | model_args.beam_size, 511 | model_args.search_max_len, 512 | model_args.search_min_len, 513 | model_args.search_score_threshold, 514 | inputs, tokenizer, model_args) 515 | logger.info('search time cost {}s'.format(time.time() - tic)) 516 | return search_results_gen(res_indices, model_args, inputs, tokenizer) 517 | 518 | def convert_seq_id_to_subword(seq_id, batch, inputs, tokenizer): 519 | if not isinstance(inputs, torch.Tensor): 520 | if 'input_ids' in inputs: 521 | inputs = inputs['input_ids'] 522 | subword_id = inputs[batch][seq_id].item() 523 | subword = tokenizer.convert_ids_to_tokens([subword_id])[0] 524 | return subword 525 | 526 | 527 | def merge_search_res(search_res, global_search_res): 528 | for k, v in search_res.items(): 529 | if k not in global_search_res: 530 | global_search_res[k] = v 531 | else: 532 | global_search_res[k].extend(v) 533 | 534 | 535 | def predict_and_save_results(dataloader: DataLoader, description: str, trainer, 536 | model_args, tokenizer, prediction_loss_only: Optional[bool] = None 537 | ): 538 | if model_args.compute_loss: 539 | prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else trainer.prediction_loss_only 540 | 541 | model = trainer.model 542 | if trainer.args.n_gpu > 1: 543 | model = torch.nn.DataParallel(model) 544 | else: 545 | model = trainer.model 546 | 547 | batch_size = dataloader.batch_size 548 | logger.info("***** Running %s *****", description) 549 | logger.info(" Num examples = %d", trainer.num_examples(dataloader)) 550 | logger.info(" Batch size = %d", batch_size) 551 | if model_args.compute_loss: 552 | eval_losses: List[float] = [] 553 | preds: torch.Tensor = None 554 | label_ids: torch.Tensor = None 555 | model.eval() 556 | 557 | if is_torch_tpu_available(): 558 | dataloader = pl.ParallelLoader(dataloader, [trainer.args.device]).per_device_loader(trainer.args.device) 559 | 560 | res_dict = {} 561 | res_rel_dict = {} 562 | search_res = {} 563 | stats = {'max': -1, 'min': 1, 'sum': 0, 'num': 0, 'plot': None} 564 | for inputs in tqdm(dataloader, desc=description): 565 | if model_args.compute_loss: 566 | has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"]) 567 | entity_ids = inputs.pop('entity_ids') 568 | head_entity_ids = inputs.pop('head_entity_ids') 569 | tail_entity_ids = inputs.pop('tail_entity_ids') 570 | relation_entity_ids = inputs.pop('relation_entity_ids') 571 | special_tokens_mask = inputs.pop('special_tokens_mask') 572 | docid = inputs.pop('docid') 573 | offset = inputs.pop('offset') 574 | text = inputs.pop('text') 575 | for k, v in inputs.items(): 576 | inputs[k] = v.to(trainer.args.device) 577 | 578 | with torch.no_grad(): 579 | tic = time.time() 580 | outputs = model(**inputs) 581 | logger.info('forward time cost {}s'.format(time.time() - tic)) 582 | for k, v in inputs.items(): 583 | inputs[k] = v.cpu() 584 | inputs['entity_ids'] = entity_ids 585 | inputs['head_entity_ids'] = head_entity_ids 586 | inputs['tail_entity_ids'] = tail_entity_ids 587 | inputs['relation_entity_ids'] = 
relation_entity_ids 588 | inputs['special_tokens_mask'] = special_tokens_mask 589 | inputs['docid'] = docid 590 | inputs['offset'] = offset 591 | inputs['text'] = text 592 | if model_args.generation_type == 'fast_unsupervised_bidirectional_beam_search': 593 | attentions = transform_layer_attention(layer_attention(outputs[-1], model_args.search_layer_id), 594 | model_args.search_attention_head_type) 595 | merge_search_res(fast_unsupervised_bidirectional_beam_search(attentions, model_args, inputs, tokenizer), 596 | search_res) 597 | else: 598 | raise ValueError('search not supported') 599 | if model_args.compute_loss: 600 | if has_labels: 601 | step_eval_loss, logits = outputs[:2] 602 | eval_losses += [step_eval_loss.mean().item()] 603 | else: 604 | logits = outputs[0] 605 | if model_args.compute_loss: 606 | if not prediction_loss_only: 607 | if preds is None: 608 | preds = logits.detach() 609 | else: 610 | preds = torch.cat((preds, logits.detach()), dim=0) 611 | if inputs.get("labels") is not None: 612 | if label_ids is None: 613 | label_ids = inputs["labels"].detach() 614 | else: 615 | label_ids = torch.cat((label_ids, inputs["labels"].detach()), dim=0) 616 | if model_args.compute_loss: 617 | if trainer.args.local_rank != -1: 618 | if preds is not None: 619 | preds = trainer.distributed_concat(preds, num_total_examples=trainer.num_examples(dataloader)) 620 | if label_ids is not None: 621 | label_ids = trainer.distributed_concat(label_ids, num_total_examples=trainer.num_examples(dataloader)) 622 | elif is_torch_tpu_available(): 623 | if preds is not None: 624 | preds = xm.mesh_reduce("eval_preds", preds, torch.cat) 625 | if label_ids is not None: 626 | label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat) 627 | 628 | if preds is not None: 629 | preds = preds.cpu().numpy() 630 | if label_ids is not None: 631 | label_ids = label_ids.cpu().numpy() 632 | 633 | if trainer.compute_metrics is not None and preds is not None and label_ids is not None: 634 | metrics = trainer.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) 635 | else: 636 | metrics = {} 637 | if len(eval_losses) > 0: 638 | metrics["eval_loss"] = np.mean(eval_losses) 639 | 640 | for key in list(metrics.keys()): 641 | if not key.startswith("eval_"): 642 | metrics[f"eval_{key}"] = metrics.pop(key) 643 | res_dict = sorted(res_dict.items(), key=lambda x: x[1], reverse=True) 644 | if model_args.compute_loss: 645 | return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics), \ 646 | (res_dict, res_rel_dict, stats, search_res) 647 | return None, (res_dict, res_rel_dict, stats, search_res) -------------------------------------------------------------------------------- /src/deepex/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | import math 5 | import time 6 | import tqdm 7 | import json 8 | import shutil 9 | import logging 10 | import requests 11 | import jsonlines 12 | import subprocess 13 | import numpy as np 14 | from tqdm import tqdm 15 | 16 | def Folder(PATH): 17 | return "/".join(PATH.split('/')[:-1])+"/" 18 | 19 | def File(PATH): 20 | return PATH.split('/')[-1] 21 | 22 | def Prefix(PATH): 23 | return ".".join(PATH.split('.')[:-1]) 24 | 25 | def Suffix(PATH): 26 | return PATH.split('.')[-1] 27 | 28 | def Create(PATH): 29 | None if os.path.exists(PATH) else os.makedirs(PATH) 30 | 31 | def Delete(PATH): 32 | shutil.rmtree(PATH) if os.path.exists(PATH) else None 33 | 34 | def Clear(PATH): 35 | 
shutil.rmtree(PATH) if os.path.exists(PATH) else None; os.makedirs(PATH) 36 | 37 | 38 | def SaveJSON(object, FILE, jsonl=False, indent=None): 39 | if jsonl: 40 | with jsonlines.open(FILE, 'w') as f: 41 | for data in object: 42 | f.write(data) 43 | else: 44 | with open(FILE, 'w') as f: 45 | json.dump(object, f, indent=indent) 46 | 47 | def PrettifyJSON(PATH): 48 | if PATH[-1]=='/': 49 | for FILE in os.listdir(PATH): 50 | SaveJSON(LoadJSON(PATH+FILE),PATH+FILE,indent=4) 51 | else: 52 | SaveJSON(LoadJSON(PATH),PATH,indent=4) 53 | 54 | def LoadJSON(FILE, jsonl=False): 55 | if jsonl: 56 | with open(FILE, 'r') as f: 57 | return [data for data in jsonlines.Reader(f)] 58 | else: 59 | with open(FILE, 'r') as f: 60 | return json.load(f) 61 | 62 | def View(something, length=4096): 63 | print(str(something)[:length]+" ..." if len(str(something))>length+3 else str(something)) 64 | 65 | def ViewS(something, length=4096): 66 | return (str(something)[:length]+" ..." if len(str(something))>length+3 else str(something)) 67 | 68 | def ViewDict(something, length=4096, limit=512): 69 | print("{") 70 | for i,item in enumerate(something.items()): 71 | print("\t"+str(item[0])+": "+(ViewS(item[1])+',')) 72 | if i>=limit: 73 | print("\t..."); break 74 | print("}") 75 | 76 | def ViewDictS(something, length=4096, limit=512): 77 | s = "{\n" 78 | for i,item in enumerate(something.items()): 79 | s += "\t"+str(item[0])+": "+(ViewS(item[1])+',')+"\n" 80 | if i>=limit: 81 | s += "\t...\n"; break 82 | s += "}\n"; return s 83 | 84 | def ViewJSON(json_dict, length=4096): 85 | print(ViewS(json.dumps(json_dict,indent=4))) 86 | 87 | def ViewJSONS(json_dict, length=4096): 88 | return ViewS(json.dumps(json_dict,indent=4)) 89 | 90 | def DATE(): 91 | return time.strftime("%Y-%m-%d",time.localtime(time.time())) 92 | 93 | def CMD(command, wait=True): 94 | h = subprocess.Popen(command,shell=True); return h.wait() if wait else h 95 | 96 | def PrintConsole(*args, **kwargs): 97 | print(*args, file=sys.stdout, **kwargs) 98 | 99 | def PrintError(*args, **kwargs): 100 | print(*args, file=sys.stderr, **kwargs) 101 | 102 | def LineToFloats(line): 103 | return [float(s) for s in re.findall(r"(?