├── LICENSE
├── README.md
├── data
│   └── process_rawdata.py
├── data_setup.sh
├── entity_detection
│   ├── args.py
│   ├── evaluation.py
│   ├── model.py
│   ├── predict.py
│   ├── process.sh
│   ├── seqLabelingLoader.py
│   └── train.py
├── freebase_data
│   ├── convert.py
│   └── dump_virtuoso_data
│       └── virtuoso.ini
├── relation_ranking
│   ├── args.py
│   ├── attention.py
│   ├── model.py
│   ├── predict.py
│   ├── process.sh
│   ├── seqRankingLoader.py
│   └── train.py
├── tools
│   ├── __init__.py
│   ├── embedding.py
│   ├── qa_data.py
│   ├── utils.py
│   └── virtuoso.py
└── vocab
    ├── __init__.py
    ├── create_vocab.py
    └── dictionary.py

/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 | 
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 | 1. Definitions.
8 | 
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 | 
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 | 
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 | 
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 | 
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 | 
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 | 
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # kbqa-ar-smcnn 2 | Question answering over Freebase (single-relation). 3 | 4 | This is the source code for Question Answering over Freebase via Attentive RNN with Similarity Matrix based CNN. 
5 | ## Install packages
6 | - Python 3.5
7 | - PyTorch 0.2.0
8 | - NLTK
9 | - NLTK data (tokenizers, stopwords list)
10 | - Virtuoso
11 | 
12 | After installing Virtuoso, modify the config file at freebase_data/dump_virtuoso_data/virtuoso.ini (or copy your own virtuoso.ini file into this directory):
13 | 1. Adjust variables such as NumberOfBuffers and MaxDirtyBuffers to match your RAM. For the other variables you can follow the official documentation.
14 | 2. Add the absolute path of XXX/freebase_data/dump_virtuoso_data/ to DirsAllowed.
15 | 
16 | ## Start the Virtuoso server
17 | This may need to be run as the root user.
18 | ```
19 | virtuoso-t +foreground +configfile freebase_data/dump_virtuoso_data/virtuoso.ini &
20 | ```
21 | 
22 | ## Set up
23 | Run the setup script. This takes a long time: it fetches the datasets, preprocesses them, and dumps the Freebase triples into Virtuoso.
24 | ```
25 | sh data_setup.sh
26 | ```
27 | 
28 | ## Training
29 | - entity detection model
30 | ```
31 | cd entity_detection
32 | sh process.sh
33 | python predict.py --trained_model XXX --results_path results --save_qadata
34 | ```
35 | - relation detection model
36 | ```
37 | cd relation_ranking
38 | python seqRankingLoader.py --batch_size 64 --neg_size 50  # create training data for relation detection
39 | sh process.sh
40 | python predict.py --trained_model XXX --results_path results --predict
41 | ```
42 | 
-------------------------------------------------------------------------------- /data/process_rawdata.py: --------------------------------------------------------------------------------
1 | # This tool preprocesses the original SimpleQuestions dataset in 6 ways:
2 | # 1. convert the triple information into the fb:... format
3 | # 2. remove the escape symbols ('\\\\') from the original question
4 | # 3. tokenize the question
5 | # 4. lower-case the tokenized question
6 | # 5. add a field which indicates the number of tokens in the question
7 | # 6. sort all pairs by question length and store them to QAData.*.pkl
8 | 
9 | import multiprocessing as mp
10 | import sys, os, io, re
11 | import pickle
12 | from nltk import word_tokenize
13 | sys.path.append('../tools')
14 | from qa_data import QAData
15 | import virtuoso
16 | 
17 | split = None
18 | 
19 | def extract(line):
20 |     fields = line.strip().split('\t')
21 |     sub = 'fb:' + fields[0].split('www.freebase.com/')[-1].replace('/','.')
22 |     rel = 'fb:' + fields[1].split('www.freebase.com/')[-1].replace('/','.')
23 |     obj = 'fb:' + fields[2].split('www.freebase.com/')[-1].replace('/','.')
24 |     if sub == 'fb:m.07s9rl0':
25 |         sub = 'fb:m.02822'
26 |     if obj == 'fb:m.07s9rl0':
27 |         obj = 'fb:m.02822'
28 |     question = fields[-1].replace('\\\\','')
29 |     tokens = word_tokenize(question)
30 |     question = ' '.join(tokens).lower()
31 |     tmp = question.split(' . ')
32 |     return '. '.join(tmp), sub, rel, obj, len(tokens)-len(tmp)+1
33 | 
34 | def get_indices(src_list, pattern_list):
35 |     indices = None
36 |     for i in range(len(src_list) - len(pattern_list) + 1):  # only start positions where the pattern fits
37 |         match = 1
38 |         for j in range(len(pattern_list)):
39 |             if src_list[i+j] != pattern_list[j]:
40 |                 match = 0
41 |                 break
42 |         if match:
43 |             indices = range(i, i + len(pattern_list))
44 |             break
45 |     return indices
46 | 
47 | def query_golden_subs(data):
48 |     golden_subs = []
49 |     if data.text_subject:
50 |         # extract fields needed
51 |         relation = data.relation
52 |         subject = data.subject
53 |         text_subject = data.text_subject
54 | 
55 |         # query candidate subject ids by the subject text
56 |         candi_sub_list = virtuoso.str_query_id(text_subject)
57 | 
58 |         # add candidates to data
59 |         for candi_sub in candi_sub_list:
60 |             candi_rel_list = virtuoso.id_query_out_rel(candi_sub)
61 |             if relation in candi_rel_list:
62 |                 golden_subs.append(candi_sub)
63 | 
64 |     if len(golden_subs) == 0:
65 |         golden_subs = [data.subject]
66 | 
67 |     return golden_subs
68 | 
69 | def search_ngrams():
70 |     pass
71 | 
72 | def reverse_link(question, subject):
73 |     # get question tokens
74 |     tokens = question.split()
75 | 
76 |     # init default value of returned variables
77 |     text_subject = None
78 |     text_attention_indices = None
79 | 
80 |     # query name / alias by node_id (subject)
81 |     res_list = virtuoso.id_query_str(subject)
82 | 
83 |     # sorted by length
84 |     for res in sorted(res_list, key = lambda res: len(res), reverse = True):
85 |         pattern = r'(^|\s)(%s)($|\s)' % (re.escape(res))
86 |         if re.search(pattern, question):
87 |             text_subject = res
88 |             text_attention_indices = get_indices(tokens, res.split())
89 |             break
90 | 
91 |     return text_subject, text_attention_indices
92 | 
93 | def form_question_pattern(data):
94 |     question_pattern = None
95 |     if data.text_attention_indices:
96 |         anonymous_tokens = []
97 |         tokens = data.question.split()
98 |         anonymous_tokens.extend(tokens[:data.text_attention_indices[0]])
99 |         anonymous_tokens.append('X')
100 |         anonymous_tokens.extend(tokens[data.text_attention_indices[-1]+1:])
101 |         question_pattern = ' '.join(anonymous_tokens)
102 | 
103 |     return question_pattern
104 | 
105 | def knowledge_graph_attributes(data_list, pid = 0):
106 |     # Open log file
107 |     log_file = open('logs/log.%s.%d.txt'%(split, pid), 'w')
108 |     log_file.write('total length: %d\n' %(len(data_list)))
109 | 
110 |     succ_att_link = 0
111 |     qadata_list = []
112 |     for data_index, data_tuple in enumerate(data_list):
113 |         # Step-1: create QAData instance
114 |         data = QAData(data_tuple)
115 | 
116 |         # Step-2: reverse linking
117 |         data.text_subject, data.text_attention_indices = reverse_link(data.question, data.subject)
118 | 
119 |         # Step-3: create question pattern
120 |         data.question_pattern = form_question_pattern(data)
121 | 
122 |         qadata_list.append(data)
123 | 
124 |         # logging
125 |         if data.text_subject:
126 |             succ_att_link += 1
127 |             # log_file.write('[%d] attention: %f\n' % (data_index, succ_att_link / float(data_index+1)))
128 |             log_file.write('[%d] %s\t%s\t%s\n' %(data_index, data.question, data.subject, data.text_subject))
129 | 
130 |     pickle.dump(qadata_list, open('temp.%d.pkl'%(pid), 'wb'))
131 |     log_file.write('total: %d\tlinked: %d\trate: %f\n' % (len(data_list), succ_att_link,
132 |                    succ_att_link/max(len(data_list), 1)))  # guard against an empty slice
133 |     log_file.close()
134 | 
135 | def process(num_process, data_list):
136 |     # Make dir
137 |     if not os.path.exists('logs'):
138 |         os.mkdir('logs')
139 | 
140 |     # Split workload
141 |     length = len(data_list)
142 |     data_per_p = (length + num_process - 1) // num_process
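    # (added note) ceiling division: every example is covered and only the last
    # slice may be shorter, e.g. 10 examples over 3 processes -> slices of 4, 4, 2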
143 | 
144 |     # Spawn processes
145 |     processes = [
146 |         mp.Process(
147 |             target = knowledge_graph_attributes,
148 |             args = (
149 |                 data_list[pid*data_per_p:(pid+1)*data_per_p],
150 |                 pid
151 |             )
152 |         )
153 |         for pid in range(num_process)
154 |     ]
155 | 
156 |     # Run processes
157 |     for p in processes:
158 |         p.start()
159 | 
160 |     # Wait for the processes to finish
161 |     for p in processes:
162 |         p.join()
163 | 
164 | if __name__ == '__main__':
165 | 
166 |     if len(sys.argv) != 3:
167 |         print('python process_rawdata.py input_file num_process')
168 |         sys.exit(-1)
169 | 
170 |     in_file_path = sys.argv[1]
171 |     num_process = int(sys.argv[2])
172 | 
173 |     split = in_file_path.split('_')[-1].split('.')[0]
174 | 
175 |     in_file = io.open(in_file_path, 'r', encoding='utf8')
176 | 
177 |     data_list = []
178 |     for line in in_file:
179 |         question, sub, rel, obj, length = extract(line)
180 |         data_list.append((question, sub, rel, obj, length))
181 | 
182 |     process(num_process, sorted(data_list, key = lambda data: data[-1], reverse = True))
183 | 
184 |     # Merge all data [this will preserve the order]
185 |     new_data_list = []
186 |     for p in range(num_process):
187 |         temp_fn = 'temp.%d.pkl'%(p)
188 |         new_data_list.extend(pickle.load(open(temp_fn, 'rb')))
189 |         os.remove(temp_fn)
190 | 
191 |     pickle.dump(new_data_list, open('QAData.%s.pkl'%(split), 'wb'))
192 | 
193 |     in_file.close()
194 | 
-------------------------------------------------------------------------------- /data_setup.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # 1. download SimpleQuestions v2
4 | echo "Downloading SimpleQuestions dataset...\n"
5 | wget https://www.dropbox.com/s/tohrsllcfy7rch4/SimpleQuestions_v2.tgz
6 | 
7 | echo "Unzipping SimpleQuestions dataset...\n"
8 | tar -xvzf SimpleQuestions_v2.tgz
9 | mv SimpleQuestions_v2 freebase_data/
10 | rm SimpleQuestions_v2.tgz
11 | 
12 | echo "Downloading the names file...\n"
13 | wget https://www.dropbox.com/s/yqbesl07hsw297w/FB5M.name.txt
14 | mv FB5M.name.txt freebase_data/dump_virtuoso_data/
15 | 
16 | # 2. create KB data
17 | echo "\n\nCreate KB data...\n"
18 | cd freebase_data
19 | python convert.py SimpleQuestions_v2/freebase-subsets/freebase-FB2M.txt
20 | mv FB2M.core.txt dump_virtuoso_data/
21 | 
22 | # 3. load data into the knowledge base
23 | echo "\n\nLoad data into the knowledge base...\n"
24 | isql-vt 1111 dba dba exec="ld_dir_all('`pwd`/dump_virtuoso_data', '*', 'fb2m:');"
25 | pids=()
26 | for i in `seq 1 4`; do
27 |     isql-vt 1111 dba dba exec="rdf_loader_run();" &
28 |     pids+=($!)
29 | done
30 | for pid in ${pids[@]}; do
31 |     wait $pid
32 | done
33 | 
34 | # 4. preprocess training data
35 | cd ../data
36 | 
37 | echo "\n\nPreprocess training data...\n"
38 | python process_rawdata.py ../freebase_data/SimpleQuestions_v2/annotated_fb_data_train.txt 5
39 | python process_rawdata.py ../freebase_data/SimpleQuestions_v2/annotated_fb_data_valid.txt 2
40 | python process_rawdata.py ../freebase_data/SimpleQuestions_v2/annotated_fb_data_test.txt 5
41 | 
42 | # 5. create Vocabs
43 | cd ../vocab
44 | 
45 | echo "\n\nDownloading Embeddings...\n"
46 | wget https://nlp.stanford.edu/data/wordvecs/glove.42B.300d.zip
47 | unzip glove.42B.300d.zip
48 | rm glove.42B.300d.zip
49 | 
50 | echo "Create vocabs...\n"
51 | python create_vocab.py
52 | 
53 | # 6. create training data
54 | echo "\n\nCreating training data for entity detection...\n"
55 | cd ../entity_detection
56 | python seqLabelingLoader.py
57 | 
58 | echo "\n\nDONE!"
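# Optional sanity check (an added sketch, not part of the original script): the
# graph name 'fb2m:' matches the ld_dir_all call in step 3, so the number of
# loaded FB2M triples can be verified at any time with something like:
#   isql-vt 1111 dba dba exec="SPARQL SELECT COUNT(*) FROM <fb2m:> WHERE { ?s ?p ?o };"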
59 | 
-------------------------------------------------------------------------------- /entity_detection/args.py: --------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentParser
3 | 
4 | def get_args():
5 |     parser = ArgumentParser(description='kbqa-FB model')
6 |     parser.add_argument('--epochs', type=int, default=30)
7 |     parser.add_argument('--batch_size', type=int, default=128)
8 |     parser.add_argument('--rnn_type', type=str, default='lstm') # or use 'gru'
9 |     parser.add_argument('--d_embed', type=int, default=300)
10 |     parser.add_argument('--d_hidden', type=int, default=200)
11 |     parser.add_argument('--n_layers', type=int, default=2)
12 |     parser.add_argument('--lr', type=float, default=1e-4)
13 |     parser.add_argument('--test', action='store_true', dest='test', help='get the testing set result')
14 |     parser.add_argument('--dev', action='store_true', dest='dev', help='get the development set result')
15 |     parser.add_argument('--not_bidirectional', action='store_false', dest='birnn')
16 |     parser.add_argument('--clip_gradient', type=float, default=0.6, help='gradient clipping')
17 |     parser.add_argument('--log_every', type=int, default=300)
18 |     parser.add_argument('--dev_every', type=int, default=900)
19 |     parser.add_argument('--save_every', type=int, default=5000)
20 |     parser.add_argument('--dropout_prob', type=float, default=0.5)
21 |     parser.add_argument('--patience', type=int, default=5, help="number of epochs to wait before early stopping")
22 |     parser.add_argument('--no_cuda', action='store_false', help='do not use CUDA', dest='cuda')
23 |     parser.add_argument('--gpu', type=int, default=0, help='GPU device to use') # use -1 for CPU
24 |     parser.add_argument('--seed', type=int, default=1111, help='random seed for reproducing results')
25 |     parser.add_argument('--save_path', type=str, default='saved_checkpoints')
26 |     parser.add_argument('--vocab_file', type=str, default='../vocab/vocab.word&rel.pt')
27 |     parser.add_argument('--word_vectors', type=str, default='../vocab/glove.42B.300d.txt')
28 |     parser.add_argument('--vector_cache', type=str, default=os.path.join(os.getcwd(), '../input_vectors.pt'))
29 |     parser.add_argument('--word_normalize', action='store_true')
30 |     parser.add_argument('--train_embed', action='store_false', dest='fix_emb') # fine-tune the word embeddings
31 |     parser.add_argument('--resume_snapshot', type=str, default=None)
32 |     parser.add_argument('--train_file', type=str, default='data/train.entity_detection.pt')
33 |     parser.add_argument('--valid_file', type=str, default='data/valid.entity_detection.pt')
34 |     parser.add_argument('--test_file', type=str, default='data/test.entity_detection.pt')
35 |     # added for testing
36 |     parser.add_argument('--trained_model', type=str, default='')
37 |     parser.add_argument('--results_path', type=str, default='results')
38 |     parser.add_argument('--save_qadata', action='store_true')
39 |     args = parser.parse_args()
40 |     return args
41 | 
-------------------------------------------------------------------------------- /entity_detection/evaluation.py: --------------------------------------------------------------------------------
1 | 
2 | def get_span(label):
3 |     span = []
4 |     st = 0
5 |     en = 0
6 |     flag = False
7 |     for k in range(len(label)):
8 |         if label[k] == 1 and flag == False:
9 |             flag = True
10 |             st = k
11 |         if label[k] != 1 and flag == True:
12 |             flag = False
13 |             en = k
14 |             span.append((st, en))
15 |             st = 0
16 |             en = 0
17 |     if flag:  # a span is still open at the end of the sequence; close it
18 |         en = len(label)
19 | 
span.append((st, en)) 20 | return span 21 | 22 | def evaluation(gold, pred): 23 | right = 0 24 | predicted = 0 25 | total_en = 0 26 | for i in range(len(gold)): 27 | gold_batch = gold[i] 28 | pred_batch = pred[i] 29 | for j in range(len(gold_batch)): 30 | gold_label = gold_batch[j] 31 | pred_label = pred_batch[j] 32 | gold_span = get_span(gold_label) 33 | pred_span = get_span(pred_label) 34 | total_en += len(gold_span) 35 | predicted += len(pred_span) 36 | for item in pred_span: 37 | if item in gold_span: 38 | right += 1 39 | if predicted == 0: 40 | precision = 0 41 | else: 42 | precision = right / predicted 43 | if total_en == 0: 44 | recall = 0 45 | else: 46 | recall = right / total_en 47 | if precision + recall == 0: 48 | f1 = 0 49 | else: 50 | f1 = 2 * precision * recall / (precision + recall) 51 | return precision, recall, f1 52 | -------------------------------------------------------------------------------- /entity_detection/model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch import autograd 3 | import torch.nn.functional as F 4 | import torch 5 | import sys 6 | sys.path.append('../tools') 7 | from embedding import Embeddings 8 | 9 | class EntityDetection(nn.Module): 10 | 11 | def __init__(self, dicts, config): 12 | super(EntityDetection, self).__init__() 13 | self.config = config 14 | self.embed = Embeddings(word_vec_size=config.d_embed, dicts=dicts) 15 | if self.config.rnn_type.lower() == 'gru': 16 | self.rnn = nn.GRU(input_size=config.d_embed, hidden_size=config.d_hidden, 17 | num_layers=config.n_layers, dropout=config.dropout_prob, 18 | bidirectional=config.birnn) 19 | else: 20 | self.rnn = nn.LSTM(input_size=config.d_embed, hidden_size=config.d_hidden, 21 | num_layers=config.n_layers, dropout=config.dropout_prob, 22 | bidirectional=config.birnn) 23 | 24 | self.dropout = nn.Dropout(p=config.dropout_prob) 25 | self.relu = nn.ReLU() 26 | seq_in_size = config.d_hidden 27 | if self.config.birnn: 28 | seq_in_size *= 2 29 | 30 | self.hidden2tag = nn.Sequential( 31 | nn.Linear(seq_in_size, seq_in_size), 32 | nn.BatchNorm1d(seq_in_size), 33 | self.relu, 34 | self.dropout, 35 | nn.Linear(seq_in_size, config.n_out) 36 | ) 37 | 38 | def forward(self, batch): 39 | # shape of batch (sequence length, batch size) 40 | inputs = self.embed.forward(batch[0]) # shape (sequence length, batch_size, dimension of embedding) 41 | batch_size = inputs.size()[1] 42 | state_shape = self.config.n_cells, batch_size, self.config.d_hidden 43 | if self.config.rnn_type.lower() == 'gru': 44 | h0 = autograd.Variable(inputs.data.new(*state_shape).zero_()) 45 | outputs, ht = self.rnn(inputs, h0) 46 | else: 47 | h0 = c0 = autograd.Variable(inputs.data.new(*state_shape).zero_()) 48 | outputs, (ht, ct) = self.rnn(inputs, (h0, c0)) 49 | # shape of `outputs` - (sequence length, batch size, hidden size X num directions) 50 | tags = self.hidden2tag(outputs.view(-1, outputs.size(2))) 51 | scores = F.log_softmax(tags) 52 | return scores 53 | 54 | -------------------------------------------------------------------------------- /entity_detection/predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import torch 5 | import pickle 6 | 7 | from args import get_args 8 | from model import EntityDetection 9 | from evaluation import evaluation 10 | from seqLabelingLoader import SeqLabelingLoader 11 | sys.path.append('../tools') 12 | import virtuoso 13 | 14 | # 
please set the configuration in the file: args.py
15 | args = get_args()
16 | # set the random seed for reproducibility
17 | torch.manual_seed(args.seed)
18 | if not args.cuda:
19 |     args.gpu = -1
20 | if torch.cuda.is_available() and args.cuda:
21 |     print("Note: You are using GPU for training")
22 |     torch.cuda.set_device(args.gpu)
23 |     torch.cuda.manual_seed(args.seed)
24 | if torch.cuda.is_available() and not args.cuda:
25 |     print("Warning: You have Cuda but do not use it. You are using CPU for training")
26 | 
27 | 
28 | if not args.trained_model:
29 |     print("ERROR: You need to provide the option 'trained_model' (a path) to load the model.")
30 |     sys.exit(1)
31 | 
32 | # load word vocab for questions
33 | word_vocab = torch.load(args.vocab_file)
34 | print('load word vocab, size: %s' % len(word_vocab))
35 | 
36 | os.makedirs(args.results_path, exist_ok=True)
37 | 
38 | # load the model
39 | model = torch.load(args.trained_model, map_location=lambda storage,location: storage.cuda(args.gpu))
40 | 
41 | def predict(dataset=args.test_file, tp='test', save_qadata=args.save_qadata):
42 |     # load QAdata
43 |     qa_data_path = '../data/QAData.%s.pkl' % tp
44 |     qa_data = pickle.load(open(qa_data_path,'rb'))
45 | 
46 |     # load batch data for prediction
47 |     data_loader = SeqLabelingLoader(dataset, args.gpu)
48 |     print('load %s data, batch_num: %d\tbatch_size: %d'
49 |           %(tp, data_loader.batch_num, data_loader.batch_size))
50 | 
51 |     model.eval()
52 | 
53 |     n_correct = 0
54 |     n_correct_sub = 0
55 |     n_correct_extend = 0
56 |     n_empty = 0
57 |     linenum = 1
58 |     qa_data_idx = 0
59 | 
60 |     new_qa_data = []
61 | 
62 |     gold_list = []
63 |     pred_list = []
64 | 
65 |     for data_batch_idx, data_batch in enumerate(data_loader.next_batch(shuffle=False)):
66 |         if data_batch_idx % 50 == 0:
67 |             print(tp, data_batch_idx)
68 |         scores = model(data_batch)
69 |         n_correct += ((torch.max(scores, 1)[1].view(data_batch[1].size()).data ==
70 |                        data_batch[1].data).sum(dim=0) == data_batch[1].size()[0]).sum()
71 | 
72 |         index_tag = np.transpose(torch.max(scores, 1)[1].view(data_batch[1].size()).cpu().data.numpy())
73 |         gold_tag = np.transpose(data_batch[1].cpu().data.numpy())
74 |         index_question = np.transpose(data_batch[0].cpu().data.numpy())
75 | 
76 |         gold_list.append(np.transpose(data_batch[1].cpu().data.numpy()))
77 |         pred_list.append(index_tag)
78 | 
79 |         for i in range(data_loader.batch_size):
80 |             while qa_data_idx < len(qa_data) and not qa_data[qa_data_idx].text_subject:
81 |                 qa_data_idx += 1
82 |             if qa_data_idx >= len(qa_data):
83 |                 break
84 |             _qa_data = qa_data[qa_data_idx]
85 |             tokens = np.array(_qa_data.question.split())
86 |             pred_text = ' '.join(tokens[np.where(index_tag[i][:len(tokens)])])
87 | 
88 |             pred_sub, pred_sub_extend = get_candidate_sub(tokens, index_tag[i])
89 |             if _qa_data.subject in pred_sub:
90 |                 n_correct_sub += 1
91 |             if _qa_data.subject in pred_sub_extend:
92 |                 n_correct_extend += 1
93 |             if not pred_sub_extend:
94 |                 n_empty += 1
95 | 
96 |             if save_qadata:
97 |                 for sub in pred_sub_extend:
98 |                     rel = virtuoso.id_query_out_rel(sub)
99 |                     _qa_data.add_candidate(sub, rel)
100 |                 if hasattr(_qa_data, 'cand_rel'):
101 |                     _qa_data.remove_duplicate()
102 | 
103 |                 # if _qa_data.subject not in pred_sub_extend:
104 |                 #     _qa_data.neg_rel = virtuoso.id_query_out_rel(_qa_data.subject)
105 | 
106 |                 new_qa_data.append((_qa_data, len(_qa_data.question_pattern.split())))
107 | 
108 |             linenum += 1
109 |             qa_data_idx += 1
110 | 
111 |     total = linenum-1
112 |     accuracy = 100. * n_correct / total
113 |     print("%s\taccuracy: %8.6f\tcorrect: %d\ttotal: %d" %(tp, accuracy, n_correct, total))
114 |     P, R, F = evaluation(gold_list, pred_list)
115 |     print("Precision: {:10.6f}% Recall: {:10.6f}% F1 Score: {:10.6f}%".format(100. * P, 100. * R, 100. * F))
116 | 
117 |     sub_accuracy = 100. * n_correct_sub / total
118 |     print('subject accuracy: %8.6f\tcorrect: %d\ttotal:%d' %(sub_accuracy, n_correct_sub, total))
119 | 
120 |     extend_accuracy = 100. * n_correct_extend / total
121 |     print('extend accuracy: %8.6f\tcorrect: %d\ttotal:%d' %(extend_accuracy, n_correct_extend, total))
122 | 
123 |     print('subject not found: %8.6f\t%d' %(n_empty/total, n_empty))
124 |     print("-" * 80)
125 | 
126 |     if save_qadata:
127 |         qadata_save_path = open(os.path.join(args.results_path, 'QAData.label.%s.pkl' %(tp)), 'wb')
128 |         data_list = [data[0] for data in sorted(new_qa_data, key = lambda data: data[1],
129 |                      reverse=True)]
130 |         pickle.dump(data_list, qadata_save_path)
131 | 
132 | def get_candidate_sub(question_tokens, pred_tag):
133 |     flag = False
134 |     starts = []
135 |     ends = []
136 |     for i, tag in enumerate(pred_tag):
137 |         if tag==1 and not flag:
138 |             starts.append(i)
139 |             flag = True
140 |         elif tag==0 and flag:
141 |             if (i+1 < len(question_tokens) and pred_tag[i+1]==0) or i+1==len(question_tokens):
142 |                 ends.append(i-1)
143 |                 flag = False
144 |     if flag:
145 |         ends.append(len(question_tokens)-1)
146 | 
147 |     sub_list = []
148 |     shift = [0,-1,1,-2,2]
149 |     pred_sub = []
150 |     for left in shift:
151 |         for right in shift:
152 |             for i in range(len(starts)):
153 |                 if starts[i]+left < 0: continue
154 |                 if ends[i]+1+right > len(question_tokens): continue
155 |                 text = question_tokens[starts[i]+left:ends[i]+1+right]
156 |                 subject = virtuoso.str_query_id(' '.join(text))
157 |                 # print(text, subject)
158 |                 if left==0 and right==0:
159 |                     pred_sub = subject
160 |                 sub_list.extend(subject)
161 |             if sub_list:
162 |                 return pred_sub, sub_list
163 |     return pred_sub, sub_list
164 | 
165 | # run the model on the dev set
166 | predict(args.valid_file, "valid")
167 | 
168 | # run the model on the test set
169 | predict(args.test_file, "test")
170 | 
171 | # run the model on the train set
172 | predict(args.train_file, 'train')
173 | 
-------------------------------------------------------------------------------- /entity_detection/process.sh: --------------------------------------------------------------------------------
1 | python train.py \
2 |     --save_path saved_checkpoints \
3 |     --gpu 0 \
4 |     --d_hidden 175 \
5 |     --dropout_prob 0.36 \
6 |     --lr 0.00047992167761529993 \
7 |     --n_layers 3
8 | 
-------------------------------------------------------------------------------- /entity_detection/seqLabelingLoader.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #-*- coding: utf-8 -*-
3 | 
4 | # Author: QuYingqi
5 | # mail: cookiequ17@hotmail.com
6 | # Created Time: 2017-11-06
7 | import sys, os, io
8 | import pickle
9 | import numpy as np
10 | import torch
11 | from torch.autograd import Variable
12 | sys.path.append('../vocab')
13 | sys.path.append('../tools')
14 | 
15 | def create_seq_labeling_data(batch_size, qa_data, word_vocab, NoneLabel=0, TrueLabel=1):
16 |     file_type = qa_data.split('.')[-2]
17 |     log_file = open('data/%s.entity_detection.txt' %file_type, 'w')
18 |     seqs = []
19 |     seq_labels = []
20 |     batch_index = -1 # the index of sequence batches
21 |     seq_index = 0 
# sequence index within each batch
22 |     pad_index = word_vocab.lookup(word_vocab.pad_token)
23 | 
24 |     data_list = pickle.load(open(qa_data, 'rb'))
25 |     for data in data_list:
26 |         if not data.text_attention_indices:
27 |             continue
28 | 
29 |         tokens = data.question.split()
30 |         labels = data.text_attention_indices
31 |         log_file.write('%s\t%s\n' %(data.question, ' '.join(tokens[labels[0]:labels[-1]+1])))
32 | 
33 |         if seq_index % batch_size == 0:
34 |             seq_index = 0
35 |             batch_index += 1
36 |             seqs.append(torch.LongTensor(len(tokens), batch_size).fill_(pad_index))
37 |             seq_labels.append(torch.LongTensor(len(tokens), batch_size).fill_(NoneLabel))
38 | 
39 |         seqs[batch_index][0:len(tokens),seq_index] = torch.LongTensor(word_vocab.convert_to_index(tokens))
40 |         seq_labels[batch_index][labels[0]:labels[-1]+1, seq_index] = TrueLabel
41 |         seq_index += 1
42 | 
43 |     torch.save((seqs, seq_labels), 'data/%s.entity_detection.pt' %file_type)
44 | 
45 | 
46 | class SeqLabelingLoader():
47 | 
48 |     def __init__(self, infile, device=-1):
49 |         self.seqs, self.seq_labels = torch.load(infile)
50 |         self.batch_size = self.seqs[0].size(1)
51 |         self.batch_num = len(self.seqs)
52 | 
53 |         for i in range(self.batch_num):  # wrap in Variables; move to GPU only when device >= 0
54 |             if device >= 0:
55 |                 self.seqs[i], self.seq_labels[i] = self.seqs[i].cuda(device), self.seq_labels[i].cuda(device)
56 |             self.seqs[i], self.seq_labels[i] = Variable(self.seqs[i]), Variable(self.seq_labels[i])
57 | 
58 |     def next_batch(self, shuffle = True):
59 |         if shuffle:
60 |             indices = torch.randperm(self.batch_num)
61 |         else:
62 |             indices = range(self.batch_num)
63 |         for i in indices:
64 |             yield self.seqs[i], self.seq_labels[i]
65 | 
66 | if __name__ == '__main__':
67 |     if not os.path.exists('data'):
68 |         os.mkdir('data')
69 | 
70 |     word_vocab = torch.load('../vocab/vocab.word&rel.pt')
71 |     create_seq_labeling_data(128, '../data/QAData.valid.pkl', word_vocab)
72 |     create_seq_labeling_data(128, '../data/QAData.train.pkl', word_vocab)
73 |     create_seq_labeling_data(128, '../data/QAData.test.pkl', word_vocab)
74 | 
-------------------------------------------------------------------------------- /entity_detection/train.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | import torch.nn as nn
4 | import time
5 | import os, sys
6 | import glob
7 | import numpy as np
8 | 
9 | from args import get_args
10 | from model import EntityDetection
11 | from evaluation import evaluation
12 | from seqLabelingLoader import SeqLabelingLoader
13 | 
14 | # please set the configuration in the file: args.py
15 | args = get_args()
16 | # set the random seed for reproducibility
17 | torch.manual_seed(args.seed)
18 | if not args.cuda:
19 |     args.gpu = -1
20 | if torch.cuda.is_available() and args.cuda:
21 |     print("Note: You are using GPU for training")
22 |     torch.cuda.set_device(args.gpu)
23 |     torch.cuda.manual_seed(args.seed)
24 | if torch.cuda.is_available() and not args.cuda:
25 |     print("Warning: You have Cuda but do not use it. 
You are using CPU for training") 26 | 27 | 28 | # load data 29 | train_loader = SeqLabelingLoader(args.train_file, args.gpu) 30 | print('load train data, batch_num: %d\tbatch_size: %d' 31 | %(train_loader.batch_num, train_loader.batch_size)) 32 | valid_loader = SeqLabelingLoader(args.valid_file, args.gpu) 33 | print('load valid data, batch_num: %d\tbatch_size: %d' 34 | %(valid_loader.batch_num, valid_loader.batch_size)) 35 | 36 | # load word vocab for questions 37 | word_vocab = torch.load(args.vocab_file) 38 | print('load word vocab, size: %s' % len(word_vocab)) 39 | 40 | os.makedirs(args.save_path, exist_ok=True) 41 | 42 | # define models 43 | config = args 44 | config.n_out = 2 # I/in entity O/out of entity 45 | config.n_cells = config.n_layers 46 | 47 | if config.birnn: 48 | config.n_cells *= 2 49 | print(config) 50 | 51 | if args.resume_snapshot: 52 | model = torch.load(args.resume_snapshot, map_location=lambda storage, location: storage) 53 | else: 54 | model = EntityDetection(word_vocab, config) 55 | if args.word_vectors: 56 | if os.path.isfile(args.vector_cache): 57 | pretrained = torch.load(args.vector_cache) 58 | model.embed.word_lookup_table.weight.data.copy_(pretrained) 59 | else: 60 | pretrained = model.embed.load_pretrained_vectors(args.word_vectors, binary=False, 61 | normalize=args.word_normalize) 62 | torch.save(pretrained, args.vector_cache) 63 | print('load pretrained word vectors from %s, pretrained size: %s' %(args.word_vectors, 64 | pretrained.size())) 65 | if args.cuda: 66 | model.cuda() 67 | print("Shift model to GPU") 68 | 69 | # show model parameters 70 | for name, param in model.named_parameters(): 71 | print(name, param.size()) 72 | 73 | criterion = nn.NLLLoss() # negative log likelyhood loss function 74 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 75 | 76 | # train the model 77 | iterations = 0 78 | start = time.time() 79 | best_dev_acc = 0 80 | best_dev_F = 0 81 | num_iters_in_epoch = train_loader.batch_num 82 | patience = args.patience * num_iters_in_epoch # for early stopping 83 | iters_not_improved = 0 # this parameter is used for stopping early 84 | early_stop = False 85 | header = ' Time Epoch Iteration Progress (%Epoch) Loss Accuracy Dev/Accuracy' 86 | dev_log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:12.4f},{:12.4f}'.split(',')) 87 | log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:12.4f},{}'.split(',')) 88 | best_snapshot_prefix = os.path.join(args.save_path, 'best_snapshot') 89 | print(header) 90 | 91 | 92 | for epoch in range(1, args.epochs+1): 93 | if early_stop: 94 | print("Early stopping. Epoch: {}, Best Dev. Acc: {}".format(epoch, best_dev_acc)) 95 | break 96 | 97 | n_correct, n_total = 0, 0 98 | 99 | for batch_idx, batch in enumerate(train_loader.next_batch()): 100 | iterations += 1 101 | label = batch[1] 102 | model.train(); 103 | optimizer.zero_grad() 104 | 105 | scores = model(batch) 106 | 107 | n_correct += ((torch.max(scores, 1)[1].view(label.size()).data == label.data).sum(dim=0) \ 108 | == label.size()[0]).sum() 109 | n_total += train_loader.batch_size 110 | train_acc = 100. 
* n_correct / n_total 111 | 112 | loss = criterion(scores, label.view(-1,1)[:,0]) 113 | loss.backward() 114 | 115 | # clip the gradient 116 | torch.nn.utils.clip_grad_norm(model.parameters(), args.clip_gradient) 117 | optimizer.step() 118 | 119 | # checkpoint model periodically 120 | if iterations % args.save_every == 0: 121 | snapshot_prefix = os.path.join(args.save_path, 'snapshot') 122 | snapshot_path = snapshot_prefix + \ 123 | '_iter_{}_acc_{:.4f}_loss_{:.6f}_model.pt'.format(iterations, train_acc, loss.data[0]) 124 | torch.save(model, snapshot_path) 125 | for f in glob.glob(snapshot_prefix + '*'): 126 | if f != snapshot_path: 127 | os.remove(f) 128 | 129 | # evaluate performance on validation set periodically 130 | if iterations % args.dev_every == 0: 131 | model.eval() 132 | n_dev_correct = 0 133 | 134 | gold_list = [] 135 | pred_list = [] 136 | 137 | for valid_batch_idx, valid_batch in enumerate(valid_loader.next_batch()): 138 | valid_label = valid_batch[1] 139 | answer = model(valid_batch) 140 | n_dev_correct += ((torch.max(answer, 1)[1].view(valid_label.size()).data == \ 141 | valid_label.data).sum(dim=0) == valid_label.size()[0]).sum() 142 | index_tag = np.transpose(torch.max(answer, 1)[1].view(valid_label.size()).cpu().data.numpy()) 143 | gold_list.append(np.transpose(valid_label.cpu().data.numpy())) 144 | pred_list.append(index_tag) 145 | P, R, F = evaluation(gold_list, pred_list) 146 | 147 | 148 | dev_acc = 100. * n_dev_correct / (valid_loader.batch_num*valid_loader.batch_size) 149 | print(dev_log_template.format(time.time() - start, epoch, iterations, 150 | 1 + batch_idx, train_loader.batch_num, 151 | 100. * (1 + batch_idx) / train_loader.batch_num, 152 | loss.data[0], train_acc, dev_acc)) 153 | print("{} Precision: {:10.6f}% Recall: {:10.6f}% F1 Score: {:10.6f}%".format("Dev", 100. * P, 100. * R, 100. * F)) 154 | # update model 155 | if F > best_dev_F: 156 | best_dev_F = F 157 | iters_not_improved = 0 158 | snapshot_path = best_snapshot_prefix + \ 159 | '_iter_{}_devf1_{}_model.pt'.format(iterations, best_dev_F) 160 | 161 | # save model, delete previous 'best_snapshot' files 162 | torch.save(model, snapshot_path) 163 | for f in glob.glob(best_snapshot_prefix + '*'): 164 | if f != snapshot_path: 165 | os.remove(f) 166 | 167 | else: 168 | iters_not_improved += 1 169 | if iters_not_improved > patience: 170 | early_stop = True 171 | break 172 | 173 | # print progress message 174 | elif iterations % args.log_every == 0: 175 | print(log_template.format(time.time()-start, epoch, iterations, 1+batch_idx, 176 | train_loader.batch_num, 100. 
* (1+batch_idx)/train_loader.batch_num, 177 | loss.data[0], train_acc, ' '*12)) 178 | 179 | -------------------------------------------------------------------------------- /freebase_data/convert.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | import pickle 3 | 4 | def www2fb(in_str): 5 | out_str = 'fb:%s' % (in_str.split('www.freebase.com/')[-1].replace('/', '.')) 6 | return out_str 7 | 8 | def main(): 9 | in_fn = sys.argv[1] 10 | db = in_fn.split('-')[-1].split('.')[0] 11 | 12 | out_fn = '%s.core.txt' % (db) 13 | ent_fn = '%s.ent.pkl' % (db) 14 | rel_fn = '%s.rel.pkl' % (db) 15 | 16 | ent_dict = {} 17 | rel_dict = {} 18 | triple_dict = {} 19 | 20 | with open(in_fn) as fi: 21 | for line in fi: 22 | fields = line.strip().split('\t') 23 | sub = www2fb(fields[0]) 24 | rel = www2fb(fields[1]) 25 | objs = fields[2].split() 26 | if sub in ent_dict: 27 | ent_dict[sub] += 1 28 | else: 29 | ent_dict[sub] = 1 30 | if rel in rel_dict: 31 | rel_dict[rel] += 1 32 | else: 33 | rel_dict[rel] = 1 34 | for obj in objs: 35 | obj = www2fb(obj) 36 | triple_dict[(sub, rel, obj)] = 1 37 | if obj in ent_dict: 38 | ent_dict[obj] += 1 39 | else: 40 | ent_dict[obj] = 1 41 | 42 | pickle.dump(ent_dict, open(ent_fn, 'wb')) 43 | with open('%s.ent.txt' % (db), 'w') as fo: 44 | for k, v in sorted(ent_dict.items(), key = lambda kv: kv[1], reverse = True): 45 | fo.write(k + '\n') 46 | 47 | pickle.dump(rel_dict, open(rel_fn, 'wb')) 48 | with open('%s.rel.txt' % (db), 'w') as fo: 49 | for k, v in sorted(rel_dict.items(), key = lambda kv: kv[1], reverse = True): 50 | fo.write(k + '\n') 51 | 52 | with open(out_fn, 'w') as fo: 53 | for (sub, rel, obj) in triple_dict.keys(): 54 | fo.write('<%s>\t<%s>\t<%s>\t.\n' % (sub, rel, obj)) 55 | print(len(triple_dict)) 56 | 57 | if __name__ == '__main__': 58 | main() 59 | -------------------------------------------------------------------------------- /freebase_data/dump_virtuoso_data/virtuoso.ini: -------------------------------------------------------------------------------- 1 | ; 2 | ; virtuoso.ini 3 | ; 4 | ; Configuration file for the OpenLink Virtuoso VDBMS Server 5 | ; 6 | ; To learn more about this product, or any other product in our 7 | ; portfolio, please check out our web site at: 8 | ; 9 | ; http://virtuoso.openlinksw.com/ 10 | ; 11 | ; or contact us at: 12 | ; 13 | ; general.information@openlinksw.com 14 | ; 15 | ; If you have any technical questions, please contact our support 16 | ; staff at: 17 | ; 18 | ; technical.support@openlinksw.com 19 | ; 20 | ; 21 | ; Database setup 22 | ; 23 | [Database] 24 | DatabaseFile = /var/lib/virtuoso-opensource-6.1/db/virtuoso.db 25 | ErrorLogFile = /var/lib/virtuoso-opensource-6.1/db/virtuoso.log 26 | LockFile = /var/lib/virtuoso-opensource-6.1/db/virtuoso.lck 27 | TransactionFile = /var/lib/virtuoso-opensource-6.1/db/virtuoso.trx 28 | xa_persistent_file = /var/lib/virtuoso-opensource-6.1/db/virtuoso.pxa 29 | ErrorLogLevel = 7 30 | FileExtend = 200 31 | MaxCheckpointRemap = 2000 32 | Striping = 0 33 | TempStorage = TempDatabase 34 | 35 | [TempDatabase] 36 | DatabaseFile = /var/lib/virtuoso-opensource-6.1/db/virtuoso-temp.db 37 | TransactionFile = /var/lib/virtuoso-opensource-6.1/db/virtuoso-temp.trx 38 | MaxCheckpointRemap = 2000 39 | Striping = 0 40 | 41 | ; 42 | ; Server parameters 43 | ; 44 | [Parameters] 45 | ServerPort = 1111 46 | LiteMode = 0 47 | DisableUnixSocket = 1 48 | DisableTcpSocket = 0 49 | ;SSLServerPort = 2111 50 | ;SSLCertificate = cert.pem 51 | ;SSLPrivateKey 
= pk.pem 52 | ;X509ClientVerify = 0 53 | ;X509ClientVerifyDepth = 0 54 | ;X509ClientVerifyCAFile = ca.pem 55 | ServerThreads = 20 56 | CheckpointInterval = 60 57 | O_DIRECT = 0 58 | CaseMode = 2 59 | MaxStaticCursorRows = 5000 60 | CheckpointAuditTrail = 0 61 | AllowOSCalls = 0 62 | SchedulerInterval = 10 63 | DirsAllowed = ., /usr/share/virtuoso-opensource-6.1/vad ,XXX/freebase_data/dump_virtuoso_data/ 64 | ThreadCleanupInterval = 0 65 | ThreadThreshold = 10 66 | ResourcesCleanupInterval = 0 67 | FreeTextBatchSize = 100000 68 | SingleCPU = 0 69 | VADInstallDir = /usr/share/virtuoso-opensource-6.1/vad/ 70 | PrefixResultNames = 0 71 | RdfFreeTextRulesSize = 100 72 | IndexTreeMaps = 256 73 | MaxMemPoolSize = 200000000 74 | PrefixResultNames = 0 75 | MacSpotlight = 0 76 | IndexTreeMaps = 64 77 | ;; 78 | ;; When running with large data sets, one should configure the Virtuoso 79 | ;; process to use between 2/3 to 3/5 of free system memory and to stripe 80 | ;; storage on all available disks. 81 | ;; 82 | ;; Uncomment next two lines if there is 2 GB system memory free 83 | ; NumberOfBuffers = 170000 84 | ; MaxDirtyBuffers = 130000 85 | ;; Uncomment next two lines if there is 4 GB system memory free 86 | ; NumberOfBuffers = 340000 87 | ; MaxDirtyBuffers = 250000 88 | ;; Uncomment next two lines if there is 8 GB system memory free 89 | ; NumberOfBuffers = 680000 90 | ; MaxDirtyBuffers = 500000 91 | ;; Uncomment next two lines if there is 16 GB system memory free 92 | ; NumberOfBuffers = 1360000 93 | ; MaxDirtyBuffers = 1000000 94 | ;; Uncomment next two lines if there is 32 GB system memory free 95 | ; NumberOfBuffers = 2720000 96 | ; MaxDirtyBuffers = 2000000 97 | ;; Uncomment next two lines if there is 48 GB system memory free 98 | ; NumberOfBuffers = 4000000 99 | ; MaxDirtyBuffers = 3000000 100 | ;; Uncomment next two lines if there is 64 GB system memory free 101 | ; NumberOfBuffers = 5450000 102 | ; MaxDirtyBuffers = 4000000 103 | ;; 104 | ;; Note the default settings will take very little memory 105 | ;; but will not result in very good performance 106 | ;; 107 | NumberOfBuffers = 5450000 108 | MaxDirtyBuffers = 4000000 109 | 110 | [HTTPServer] 111 | ServerPort = 8890 112 | ServerRoot = /var/lib/virtuoso-opensource-6.1/vsp 113 | ServerThreads = 20 114 | DavRoot = DAV 115 | EnabledDavVSP = 0 116 | HTTPProxyEnabled = 0 117 | TempASPXDir = 0 118 | DefaultMailServer = localhost:25 119 | ServerThreads = 10 120 | MaxKeepAlives = 10 121 | KeepAliveTimeout = 10 122 | MaxCachedProxyConnections = 10 123 | ProxyConnectionCacheTimeout = 15 124 | HTTPThreadSize = 280000 125 | HttpPrintWarningsInOutput = 0 126 | Charset = UTF-8 127 | ;HTTPLogFile = logs/http.log 128 | 129 | [AutoRepair] 130 | BadParentLinks = 0 131 | 132 | [Client] 133 | SQL_PREFETCH_ROWS = 100 134 | SQL_PREFETCH_BYTES = 16000 135 | SQL_QUERY_TIMEOUT = 0 136 | SQL_TXN_TIMEOUT = 0 137 | ;SQL_NO_CHAR_C_ESCAPE = 1 138 | ;SQL_UTF8_EXECS = 0 139 | ;SQL_NO_SYSTEM_TABLES = 0 140 | ;SQL_BINARY_TIMESTAMP = 1 141 | ;SQL_ENCRYPTION_ON_PASSWORD = -1 142 | 143 | [VDB] 144 | ArrayOptimization = 0 145 | NumArrayParameters = 10 146 | VDBDisconnectTimeout = 1000 147 | KeepConnectionOnFixedThread = 0 148 | 149 | [Replication] 150 | ServerName = db 151 | ServerEnable = 1 152 | QueueMax = 50000 153 | 154 | ; 155 | ; Striping setup 156 | ; 157 | ; These parameters have only effect when Striping is set to 1 in the 158 | ; [Database] section, in which case the DatabaseFile parameter is ignored. 
159 | ;
160 | ; With striping, the database is spawned across multiple segments
161 | ; where each segment can have multiple stripes.
162 | ;
163 | ; Format of the lines below:
164 | ;   Segment <number> = <size>, <stripe file name> [, <stripe file name> .. ]
165 | ;
166 | ; <number> must be ordered from 1 up.
167 | ;
168 | ; The <size> is the total size of the segment which is equally divided
169 | ; across all stripes forming the segment. Its specification can be in
170 | ; gigabytes (g), megabytes (m), kilobytes (k) or in database blocks
171 | ; (b, the default)
172 | ;
173 | ; Note that the segment size must be a multiple of the database page size
174 | ; which is currently 8k. Also, the segment size must be divisible by the
175 | ; number of stripe files forming the segment.
176 | ;
177 | ; The example below creates a 200 meg database striped on two segments
178 | ; with two stripes of 50 meg and one of 100 meg.
179 | ;
180 | ; You can always add more segments to the configuration, but once
181 | ; added, do not change the setup.
182 | ;
183 | [Striping]
184 | Segment1 = 100M, db-seg1-1.db, db-seg1-2.db
185 | Segment2 = 100M, db-seg2-1.db
186 | ;...
187 | ;[TempStriping]
188 | ;Segment1 = 100M, db-seg1-1.db, db-seg1-2.db
189 | ;Segment2 = 100M, db-seg2-1.db
190 | ;...
191 | ;[Ucms]
192 | ;UcmPath =
193 | ;Ucm1 =
194 | ;Ucm2 =
195 | ;...
196 | 
197 | [Zero Config]
198 | ServerName = virtuoso
199 | ;ServerDSN = ZDSN
200 | ;SSLServerName =
201 | ;SSLServerDSN =
202 | 
203 | [Mono]
204 | ;MONO_TRACE = Off
205 | ;MONO_PATH =
206 | ;MONO_ROOT =
207 | ;MONO_CFG_DIR =
208 | ;virtclr.dll =
209 | 
210 | [URIQA]
211 | DynamicLocal = 0
212 | DefaultHost = localhost:8890
213 | 
214 | [SPARQL]
215 | ;ExternalQuerySource = 1
216 | ;ExternalXsltSource = 1
217 | ;DefaultGraph = http://localhost:8890/dataspace
218 | ;ImmutableGraphs = http://localhost:8890/dataspace
219 | ResultSetMaxRows = 10000
220 | MaxQueryCostEstimationTime = 400 ; in seconds
221 | MaxQueryExecutionTime = 300 ; in seconds
222 | DefaultQuery = select distinct ?Concept where {[] a ?Concept} LIMIT 100
223 | DeferInferenceRulesInit = 0 ; controls inference rules loading
224 | ;PingService = http://rpc.pingthesemanticweb.com/
225 | 
226 | [Plugins]
227 | LoadPath = /usr/lib/virtuoso-opensource-6.1/hosting
228 | Load1 = plain, wikiv
229 | Load2 = plain, mediawiki
230 | Load3 = plain, creolewiki
231 | ;Load4 = plain, im
232 | ;Load5 = plain, wbxml2
233 | ;Load6 = plain, hslookup
234 | ;Load7 = attach, libphp5.so
235 | ;Load8 = Hosting, hosting_php.so
236 | ;Load9 = Hosting,hosting_perl.so
237 | ;Load10 = Hosting,hosting_python.so
238 | ;Load11 = Hosting,hosting_ruby.so
239 | ;Load12 = msdtc,msdtc_sample
240 | 
-------------------------------------------------------------------------------- /relation_ranking/args.py: --------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentParser
3 | 
4 | def get_args():
5 |     parser = ArgumentParser(description='kbqa-FB model')
6 |     parser.add_argument('--epochs', type=int, default=30)
7 |     parser.add_argument('--batch_size', type=int, default=64)
8 |     parser.add_argument('--lr', type=float, default=1e-3)
9 |     parser.add_argument('--clip_gradient', type=float, default=0.6, help='gradient clipping')
10 |     parser.add_argument('--dropout_prob', type=float, default=0.3)
11 |     parser.add_argument('--word_normalize', action='store_true')
12 |     parser.add_argument('--train_embed', action='store_false', dest='fix_emb') # fine-tune the word embeddings
13 |     parser.add_argument('--neg_size', type=int, default=50, help='negative 
sampling number') 14 | parser.add_argument('--loss_margin', type=float, default=1) 15 | 16 | parser.add_argument('--rnn_type', type=str, default='lstm') # or use 'gru' 17 | parser.add_argument('--not_bidirectional', action='store_false', dest='birnn') 18 | parser.add_argument('--d_word_embed', type=int, default=300) 19 | parser.add_argument('--d_rel_embed', type=int, default=256) 20 | parser.add_argument('--d_hidden', type=int, default=256) 21 | parser.add_argument('--n_layers', type=int, default=2) 22 | 23 | parser.add_argument('--channel_size', type=int, default=8) 24 | parser.add_argument('--conv_kernel_1', type=int, default=3) 25 | parser.add_argument('--conv_kernel_2', type=int, default=3) 26 | parser.add_argument('--pool_kernel_1', type=int, default=3) 27 | parser.add_argument('--pool_kernel_2', type=int, default=3) 28 | parser.add_argument('--rel_maxlen', type=int, default=17) 29 | parser.add_argument('--seq_maxlen', type=int, default=21) 30 | 31 | parser.add_argument('--test', action='store_true', dest='test', help='get the testing set result') 32 | parser.add_argument('--dev', action='store_true', dest='dev', help='get the development set result') 33 | parser.add_argument('--log_every', type=int, default=100) 34 | parser.add_argument('--dev_every', type=int, default=300) 35 | parser.add_argument('--save_every', type=int, default=4500) 36 | parser.add_argument('--patience', type=int, default=5, help="number of epochs to wait before early stopping") 37 | parser.add_argument('--no_cuda', action='store_false', help='do not use CUDA', dest='cuda') 38 | parser.add_argument('--gpu', type=int, default=0, help='GPU device to use') # use -1 for CPU 39 | parser.add_argument('--seed', type=int, default=1111, help='random seed for reproducing results') 40 | 41 | parser.add_argument('--resume_snapshot', type=str, default=None) 42 | parser.add_argument('--save_path', type=str, default='saved_checkpoints') 43 | parser.add_argument('--vocab_file', type=str, default='../vocab/vocab.word&rel.pt') 44 | parser.add_argument('--rel_vocab_file', type=str, default='../vocab/vocab.rel.sep.pt') 45 | parser.add_argument('--word_vectors', type=str, default='../vocab/glove.42B.300d.txt') 46 | parser.add_argument('--vector_cache', type=str, default=os.path.join(os.getcwd(), '../input_vectors.pt')) 47 | parser.add_argument('--train_file', type=str, default='data/train.relation_ranking.pt') 48 | parser.add_argument('--valid_file', type=str, default='data/valid.relation_ranking.pt') 49 | parser.add_argument('--test_file', type=str, default='data/test.relation_ranking.pt') 50 | 51 | # added for testing 52 | parser.add_argument('--trained_model', type=str, default='') 53 | parser.add_argument('--results_path', type=str, default='results') 54 | parser.add_argument('--write_res', action='store_true', help='write predict results to file or not') 55 | parser.add_argument('--write_score', action='store_true') 56 | parser.add_argument('--predict', action='store_true') 57 | args = parser.parse_args() 58 | return args 59 | -------------------------------------------------------------------------------- /relation_ranking/attention.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import abc 3 | from collections import OrderedDict 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | 9 | class WordSeqAttentionModel(nn.Module): 10 | __metaclass__ = abc.ABCMeta 11 | 12 | def __init__(self, 
input_size, seq_size):
13 |         super(WordSeqAttentionModel, self).__init__()
14 |         self.input_size = input_size
15 |         self.output_size = seq_size
16 |         self.seq_size = seq_size
17 |
18 |     @abc.abstractmethod
19 |     def _score(self, x, seq):
20 |         """
21 |         Scoring function used by `attention` to weigh each sequence position against x.
22 |         :param x:
23 |         :param seq:
24 |         :return:
25 |         """
26 |         pass
27 |
28 |     def attention(self, x, seq, lengths=None):
29 |         """
30 |         :param x: (batch, dim, )
31 |         :param seq: (batch, length, dim, )
32 |         :param lengths: (batch, ) -- currently unused; no length masking is applied
33 |         :return: weight: (batch, length)
34 |         """
35 |         # Check Size
36 |         batch_size, input_size = x.size()
37 |         seq_batch_size, max_len, seq_size = seq.size()
38 |         assert batch_size == seq_batch_size
39 |
40 |         score = self._score(x, seq)
41 |
42 |         weight = F.softmax(score, dim=1)
43 |         return weight
44 |
45 |     def forward(self, x, seq, lengths=None):
46 |         """
47 |         :param x: (batch, dim, )
48 |         :param seq: (batch, length, dim, )
49 |         :param lengths: (batch, )
50 |         :return: hidden: (batch, dim)
51 |                  weight: (batch, length)
52 |         """
53 |         # (batch, length)
54 |         weight = self.attention(x, seq, lengths)
55 |         # (batch, 1, length) bmm (batch, length, dim) -> (batch, 1, dim) -> (batch, dim)
56 |         return torch.bmm(weight[:, None, :], seq).squeeze(1), weight
57 |
58 |     def check_size(self, x, seq):
59 |         batch_size, input_size = x.size()
60 |         seq_batch_size, max_len, seq_size = seq.size()
61 |         assert batch_size == seq_batch_size
62 |         assert input_size == self.input_size
63 |         assert seq_size == self.seq_size
64 |
65 |     @staticmethod
66 |     def expand_x(x, max_len):
67 |         """
68 |         :param x: (batch, input_size)
69 |         :param max_len: scalar
70 |         :return: (batch * max_len, input_size)
71 |         """
72 |         batch_size, input_size = x.size()
73 |         return torch.unsqueeze(x, 1).expand(batch_size, max_len, input_size).contiguous().view(batch_size * max_len, -1)
74 |
75 |     @staticmethod
76 |     def pack_seq(seq):
77 |         """
78 |         :param seq: (batch_size, max_len, seq_size)
79 |         :return: (batch_size * max_len, seq_size)
80 |         """
81 |         return seq.view(seq.size(0) * seq.size(1), -1)
82 |
83 | class MLPWordSeqAttention(WordSeqAttentionModel):
84 |     def __init__(self, input_size, seq_size, hidden_size=None, activation="Tanh", bias=False):
85 |         super(MLPWordSeqAttention, self).__init__(input_size=input_size, seq_size=seq_size)
86 |         self.bias = bias
87 |         if hidden_size is None:
88 |             hidden_size = (input_size + seq_size) // 2
89 |         self.hidden_size = hidden_size
90 |         component = OrderedDict()
91 |         component['layer1'] = nn.Linear(input_size + seq_size, hidden_size, bias=bias)
92 |         component['act'] = getattr(nn, activation)()
93 |         component['layer2'] = nn.Linear(hidden_size, 1, bias=bias)
94 |         self.layer = nn.Sequential(component)
95 |
96 |     def _score(self, x, seq):
97 |         """
98 |         :param x: (batch, word_dim)
99 |         :param seq: (batch, length, seq_dim)
100 |         :return: score: (batch, length, )
101 |         """
102 |         self.check_size(x, seq)
103 |
104 |         # (batch, word_dim) -> (batch * max_len, word_dim)
105 |         _x = self.expand_x(x, max_len=seq.size(1))
106 |
107 |         # (batch, max_len, seq_dim) -> (batch * max_len, seq_dim)
108 |         _seq = self.pack_seq(seq)
109 |
110 |         # (batch * max_len, word_dim) (batch * max_len, seq_dim) -> (batch * max_len, word_dim + seq_dim)
111 |         to_input = torch.cat([_x, _seq], 1)
112 |
113 |         # (batch * max_len, word_dim + seq_dim)
114 |         # -> (batch * max_len, 1)
115 |         # -> (batch * max_len, )
116 |         # -> (batch, max_len)
117 |         score = self.layer.forward(to_input).squeeze(-1).view(seq.size(0), seq.size(1))
118 |
119 |         return score
120 |
121 |
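A quick shape check for the attention module above (a minimal sketch, not part of the repo; batch size, sequence length and dimensions are illustrative, and Variable follows the PyTorch-0.3-era style used throughout this codebase):

    import torch
    from torch.autograd import Variable
    from attention import MLPWordSeqAttention

    att = MLPWordSeqAttention(input_size=256, seq_size=512)
    x = Variable(torch.randn(4, 256))        # e.g. one relation embedding per example
    seq = Variable(torch.randn(4, 21, 512))  # e.g. 21 encoded question positions
    hidden, weight = att(x, seq)
    print(hidden.size())  # torch.Size([4, 512]) - attention-weighted summary of seq
    print(weight.size())  # torch.Size([4, 21])  - one normalized weight per position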
-------------------------------------------------------------------------------- /relation_ranking/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*- coding: utf-8 -*- 3 | 4 | # Author: QuYingqi 5 | # mail: cookiequ17@hotmail.com 6 | # Created Time: 2017-12-15 7 | 8 | from torch import nn 9 | from torch.autograd import Variable 10 | import torch.nn.functional as F 11 | import torch 12 | import numpy as np 13 | import sys 14 | sys.path.append('../tools') 15 | from embedding import Embeddings 16 | from attention import MLPWordSeqAttention 17 | 18 | class RelationRanking(nn.Module): 19 | 20 | def __init__(self, word_vocab, rel_vocab, config): 21 | super(RelationRanking, self).__init__() 22 | self.config = config 23 | rel1_vocab, rel2_vocab = rel_vocab 24 | self.word_embed = Embeddings(word_vec_size=config.d_word_embed, dicts=word_vocab) 25 | self.rel1_embed = Embeddings(word_vec_size=config.d_rel_embed, dicts=rel1_vocab) 26 | self.rel2_embed = Embeddings(word_vec_size=config.d_rel_embed, dicts=rel2_vocab) 27 | 28 | if self.config.rnn_type.lower() == 'gru': 29 | self.rnn = nn.GRU(input_size=config.d_word_embed, hidden_size=config.d_hidden, 30 | num_layers=config.n_layers, dropout=config.dropout_prob, 31 | bidirectional=config.birnn, 32 | batch_first=True) 33 | else: 34 | self.rnn = nn.LSTM(input_size=config.d_word_embed, hidden_size=config.d_hidden, 35 | num_layers=config.n_layers, dropout=config.dropout_prob, 36 | bidirectional=config.birnn, 37 | batch_first=True) 38 | 39 | self.dropout = nn.Dropout(p=config.dropout_prob) 40 | seq_in_size = config.d_hidden 41 | if self.config.birnn: 42 | seq_in_size *= 2 43 | 44 | self.question_attention = MLPWordSeqAttention(input_size=config.d_rel_embed, seq_size=seq_in_size) 45 | 46 | self.bilinear = nn.Bilinear(seq_in_size, config.d_rel_embed, 1, bias=False) 47 | 48 | self.seq_out = nn.Sequential( 49 | self.dropout, 50 | nn.Linear(seq_in_size, config.d_rel_embed) 51 | ) 52 | 53 | self.conv = nn.Sequential( 54 | nn.Conv2d(1, config.channel_size, (config.conv_kernel_1, config.conv_kernel_2), stride=1, 55 | padding=(config.conv_kernel_1//2, config.conv_kernel_2//2)), #channel_in=1, channel_out=8, kernel_size=3*3 56 | nn.ReLU(True)) 57 | 58 | self.seq_maxlen = config.seq_maxlen + (config.conv_kernel_1 + 1) % 2 59 | self.rel_maxlen = config.rel_maxlen + (config.conv_kernel_2 + 1) % 2 60 | 61 | self.pooling = nn.MaxPool2d((config.seq_maxlen, 1), 62 | stride=(config.seq_maxlen, 1), padding=0) 63 | 64 | self.pooling2 = nn.MaxPool2d((1, config.rel_maxlen), 65 | stride=(1, config.rel_maxlen), padding=0) 66 | 67 | self.fc = nn.Sequential( 68 | nn.Linear(config.rel_maxlen * config.channel_size, 20), 69 | nn.ReLU(), 70 | nn.Dropout(p=config.dropout_prob), 71 | nn.Linear(20, 1)) 72 | 73 | self.fc1 = nn.Sequential( 74 | nn.Linear(config.seq_maxlen * config.channel_size, 20), 75 | nn.ReLU(), 76 | nn.Dropout(p=config.dropout_prob), 77 | nn.Linear(20,1)) 78 | 79 | self.fc2 = nn.Sequential( 80 | nn.Linear(4, 1)) 81 | 82 | 83 | def question_encoder(self, inputs): 84 | ''' 85 | :param inputs: (batch, dim1) 86 | ''' 87 | batch_size = inputs.size(0) 88 | state_shape = self.config.n_cells, batch_size, self.config.d_hidden 89 | if self.config.rnn_type.lower() == 'gru': 90 | h0 = Variable(inputs.data.new(*state_shape).zero_()) 91 | outputs, ht = self.rnn(inputs, h0) 92 | else: 93 | h0 = c0 = Variable(inputs.data.new(*state_shape).zero_()) 94 | outputs, (ht, ct) = self.rnn(inputs, (h0, c0)) 95 | # shape of 
`outputs` - (batch size, sequence length, hidden size X num directions)
96 |         outputs = outputs.contiguous()
97 |         return outputs
98 |
99 |     def cal_score(self, outputs, seqs_len, rel_embed, pos=None):
100 |         '''
101 |         :param rel_embed: (batch, dim2) or (neg_size, batch, dim2)
102 |         return: (neg_size, batch, 1)
103 |         '''
104 |         batch_size = outputs.size(0)
105 |         if pos:  # scoring the positive relation: `pos` carries neg_size and rel_embed is (batch, dim2)
106 |             neg_size = pos
107 |         else:
108 |             neg_size, batch_size, embed_size = rel_embed.size()
109 |             seq_len, seq_emb_size = outputs.size()[1:]
110 |             outputs = outputs.unsqueeze(0).expand(neg_size, batch_size, seq_len,
111 |                 seq_emb_size).contiguous().view(neg_size*batch_size, seq_len, -1)
112 |             rel_embed = rel_embed.view(neg_size * batch_size, -1)
113 |             seqs_len = seqs_len.unsqueeze(0).expand(neg_size, batch_size).contiguous().view(neg_size*batch_size)
114 |         # `weight` - (batch, length)
115 |         seq_att, weight = self.question_attention.forward(rel_embed, outputs)
116 |         # `seq_encode` - (batch, hidden size X num directions)
117 |         seq_encode = self.seq_out(seq_att)
118 |
119 |         # `score` - (batch, 1) or (neg_size * batch, 1)
120 |         score = torch.sum(seq_encode * rel_embed, 1, keepdim=True)
121 |
122 |         if pos:
123 |             score = score.unsqueeze(0).expand(neg_size, batch_size, 1)
124 |         else:
125 |             score = score.view(neg_size, batch_size, 1)
126 |         return score
127 |
128 |     def matchPyramid(self, seq, rel, seq_len, rel_len):
129 |         '''
130 |         param:
131 |             seq: (batch, _seq_len, embed_size)
132 |             rel: (batch, _rel_len, embed_size)
133 |             seq_len: (batch,)
134 |             rel_len: (batch,)
135 |         return:
136 |             score: (batch, 1)
137 |         '''
138 |         batch_size = seq.size(0)
139 |
140 |         rel_trans = torch.transpose(rel, 1, 2)
141 |         # (batch, 1, seq_len, rel_len); the small epsilon keeps all-padding rows from dividing by zero
142 |         seq_norm = torch.sqrt(torch.sum(seq*seq, dim=2, keepdim=True)) + 1e-8
143 |         rel_norm = torch.sqrt(torch.sum(rel_trans*rel_trans, dim=1, keepdim=True)) + 1e-8
144 |         cross = torch.bmm(seq/seq_norm, rel_trans/rel_norm).unsqueeze(1)
145 |
146 |         # (batch, channel_size, seq_len, rel_len)
147 |         conv1 = self.conv(cross)
148 |         channel_size = conv1.size(1)
149 |
150 |         # (batch, seq_maxlen)
151 |         # (batch, rel_maxlen)
152 |         dpool_index1, dpool_index2 = self.dynamic_pooling_index(seq_len, rel_len, self.seq_maxlen,
153 |             self.rel_maxlen)
154 |         dpool_index1 = dpool_index1.unsqueeze(1).unsqueeze(-1).expand(batch_size, channel_size,
155 |             self.seq_maxlen, self.rel_maxlen)
156 |         dpool_index2 = dpool_index2.unsqueeze(1).unsqueeze(2).expand_as(dpool_index1)
157 |         conv1_expand = torch.gather(conv1, 2, dpool_index1)
158 |         conv1_expand = torch.gather(conv1_expand, 3, dpool_index2)
159 |
160 |         # (batch, channel_size, p_size1, p_size2)
161 |         pool1 = self.pooling(conv1_expand).view(batch_size, -1)
162 |
163 |         # (batch, 1)
164 |         out = self.fc(pool1)
165 |
166 |         pool2 = self.pooling2(conv1_expand).view(batch_size, -1)
167 |         out2 = self.fc1(pool2)
168 |
169 |         return out, out2
170 |
171 |     def dynamic_pooling_index(self, len1, len2, max_len1, max_len2):
172 |         def dpool_index_(batch_idx, len1_one, len2_one, max_len1, max_len2):
173 |             stride1 = 1.0 * max_len1 / len1_one
174 |             stride2 = 1.0 * max_len2 / len2_one
175 |             idx1_one = [int(i/stride1) for i in range(max_len1)]
176 |             idx2_one = [int(i/stride2) for i in range(max_len2)]
177 |             return idx1_one, idx2_one
178 |         batch_size = len(len1)
179 |         index1, index2 = [], []
180 |         for i in range(batch_size):
181 |             idx1_one, idx2_one = dpool_index_(i, len1[i], len2[i], max_len1, max_len2)
182 |             index1.append(idx1_one)
183 |             index2.append(idx2_one)
184 |         index1 = torch.LongTensor(index1)
185 |         index2 = torch.LongTensor(index2)
186 |         if self.config.cuda:
187 |             index1 = index1.cuda()
188 |             index2 = index2.cuda()
189 |         return Variable(index1), Variable(index2)
190 |
191 |
192 |     def forward(self, batch):
193 |         # shape of seqs (batch size, sequence length)
194 |         seqs, seq_len, pos_rel1, pos_rel2, neg_rel1, neg_rel2, pos_rel, pos_rel_len, neg_rel, neg_rel_len = batch
195 |
196 |         # shape (batch_size, sequence length, dimension of embedding)
197 |         inputs = self.word_embed.forward(seqs)
198 |         outputs = self.question_encoder(inputs)
199 |
200 |         # shape (batch_size, dimension of rel embedding)
201 |         pos_rel1_embed = self.rel1_embed.word_lookup_table(pos_rel1)
202 |         pos_rel2_embed = self.rel2_embed.word_lookup_table(pos_rel2)
203 |         pos_rel1_embed = self.dropout(pos_rel1_embed)
204 |         pos_rel2_embed = self.dropout(pos_rel2_embed)
205 |         # shape (neg_size, batch_size, dimension of rel embedding)
206 |         neg_rel1_embed = self.rel1_embed.word_lookup_table(neg_rel1)
207 |         neg_rel2_embed = self.rel2_embed.word_lookup_table(neg_rel2)
208 |         neg_rel1_embed = self.dropout(neg_rel1_embed)
209 |         neg_rel2_embed = self.dropout(neg_rel2_embed)
210 |
211 |         neg_size, batch, neg_len = neg_rel.size()
212 |         # shape of `score` - (neg_size, batch_size, 1)
213 |         pos_score1 = self.cal_score(outputs, seq_len, pos_rel1_embed, neg_size)
214 |         pos_score2 = self.cal_score(outputs, seq_len, pos_rel2_embed, neg_size)
215 |         neg_score1 = self.cal_score(outputs, seq_len, neg_rel1_embed)
216 |         neg_score2 = self.cal_score(outputs, seq_len, neg_rel2_embed)
217 |
218 |         # (batch, len, emb_size)
219 |         pos_embed = self.word_embed.forward(pos_rel)
220 |         # (batch, 20)
221 |         pos_score3, pos_score4 = self.matchPyramid(inputs, pos_embed, seq_len, pos_rel_len)
222 |         # (neg_size, batch, 20)
223 |         pos_score3 = pos_score3.unsqueeze(0).expand(neg_size, batch, pos_score3.size(1))
224 |         pos_score4 = pos_score4.unsqueeze(0).expand(neg_size, batch, pos_score4.size(1))
225 |
226 |         # (neg_size*batch, len, emb_size)
227 |         neg_embed = self.word_embed.forward(neg_rel.view(-1, neg_len))
228 |         seqs_embed = inputs.unsqueeze(0).expand(neg_size, batch, inputs.size(1),
229 |             inputs.size(2)).contiguous().view(-1, inputs.size(1), inputs.size(2))
230 |         # (neg_size*batch,)
231 |         neg_rel_len = neg_rel_len.view(-1)
232 |         seq_len = seq_len.unsqueeze(0).expand(neg_size, batch).contiguous().view(-1)
233 |         # (neg_size*batch, 20)
234 |         neg_score3, neg_score4 = self.matchPyramid(seqs_embed, neg_embed, seq_len, neg_rel_len)
235 |         # (neg_size, batch, 20)
236 |         neg_score3 = neg_score3.view(neg_size, batch, neg_score3.size(1))
237 |         neg_score4 = neg_score4.view(neg_size, batch, neg_score4.size(1))
238 |
239 |         pos_concat = torch.cat((pos_score1, pos_score2, pos_score3, pos_score4), 2)
240 |         neg_concat = torch.cat((neg_score1, neg_score2, neg_score3, neg_score4), 2)
241 |         pos_score = self.fc2(pos_concat).squeeze(-1)
242 |         neg_score = self.fc2(neg_concat).squeeze(-1)
243 |
244 |         return pos_score, neg_score
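dynamic_pooling_index is what lets matchPyramid pool variable-length interaction maps into a fixed grid: it repeats the indices of the valid region so that torch.gather stretches it to (seq_maxlen, rel_maxlen) before max pooling. The arithmetic for one sequence of true length 5 stretched to a padded length of 10 (a standalone sketch of the same computation, with illustrative numbers):

    max_len1, len1_one = 10, 5
    stride1 = 1.0 * max_len1 / len1_one  # 2.0
    print([int(i / stride1) for i in range(max_len1)])
    # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4] - every valid position appears twice,
    # so gather() duplicates real features instead of pooling over padding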
--------------------------------------------------------------------------------
/relation_ranking/predict.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 | import torch
5 | import pickle
6 |
7 | from args import get_args
8 | from seqRankingLoader import *
9 | sys.path.append('../others')
10 | sys.path.append('../tools')
11 | import virtuoso
12 |
13 | # please set the configuration in the file : args.py
14 | args = get_args()
15 | # set the random seed for reproducibility
16 | torch.manual_seed(args.seed)
17 | if not args.cuda:
18 |     args.gpu = -1
19 | if torch.cuda.is_available() and args.cuda:
20 |     print("Note: You are using GPU for training")
21 |     torch.cuda.set_device(args.gpu)
22 |     torch.cuda.manual_seed(args.seed)
23 | if torch.cuda.is_available() and not args.cuda:
24 |     print("Warning: You have Cuda but do not use it. You are using CPU for training")
25 |
26 |
27 | if not args.trained_model:
28 |     print("ERROR: You need to provide an option 'trained_model' path to load the model.")
29 |     sys.exit(1)
30 |
31 | # load word vocab for questions, relation vocab for relations
32 | word_vocab = torch.load(args.vocab_file)
33 | print('load word vocab, size: %s' % len(word_vocab))
34 | rel_vocab = torch.load(args.rel_vocab_file)
35 | print('load relation vocab, size: %s' %len(rel_vocab))
36 |
37 | os.makedirs(args.results_path, exist_ok=True)
38 |
39 | # load the model
40 | model = torch.load(args.trained_model, map_location=lambda storage,location: storage.cuda(args.gpu))
41 |
42 | def evaluate(dataset = args.test_file, tp = 'test'):
43 |
44 |     # load batch data for predict
45 |     data_loader = SeqRankingLoader(dataset, args.gpu)
46 |     print('load %s data, batch_num: %d\tbatch_size: %d'
47 |           %(tp, data_loader.batch_num, data_loader.batch_size))
48 |
49 |     model.eval()
50 |     n_correct = 0
51 |
52 |     for data_batch_idx, data_batch in enumerate(data_loader.next_batch(shuffle=False)):
53 |         # the model already fuses its four component scores via fc2, returning one
54 |         # (neg_size, batch) matrix each for the positive and the negative relations
55 |         pos_score, neg_score = model(data_batch)
56 |         neg_size, batch_size = pos_score.size()
57 |         n_correct += (torch.sum(torch.gt(pos_score, neg_score), 0).data == neg_size).sum()
58 |
59 |     total = data_loader.batch_num*data_loader.batch_size
60 |     accuracy = 100. * n_correct / (total)
61 |     print("%s\taccuracy: %8.6f\tcorrect: %d\ttotal: %d" %(tp, accuracy, n_correct, total))
62 |     print("-" * 80)
63 |
64 | def rel_pruned(neg_score, data):
65 |     neg_rel = data.cand_rel
66 |     pred_rel_scores = sorted(zip(neg_rel, neg_score), key=lambda i:i[1], reverse=True)
67 |     pred_rel = pred_rel_scores[0][0]
68 |     pred_sub = []
69 |     for i, rels in enumerate(data.sub_rels):
70 |         if pred_rel in rels:
71 |             pred_sub.append((data.cand_sub[i], len(rels)))
72 |     pred_sub = [sub[0] for sub in sorted(pred_sub, key = lambda sub:sub[1], reverse=True)]
73 |     return pred_rel, pred_rel_scores, pred_sub
74 |
75 |
76 | def predict(qa_pattern_file, tp):
77 |     # load batch data for predict
78 |     data_loader = CandidateRankingLoader(qa_pattern_file, word_vocab, rel_vocab, args.gpu)
79 |     print('load %s data, batch_num: %d\tbatch_size: %d' %(tp, data_loader.batch_num, 1))
80 |     if args.write_res:
81 |         results_file = open(os.path.join(args.results_path, '%s-pred_rel-wrong.txt' %tp), 'w')
82 |         results_all_file = open(os.path.join(args.results_path, '%s-results-all.txt' %tp), 'w')
83 |
84 |     model.eval()
85 |     total = 0
86 |     sub_correct = 0
87 |     rel_scores = []
88 |     n_correct = 0
89 |     n_rel_correct = 0
90 |     n_sub_recall = 0
91 |     n_single_correct = 0
92 |     for data_batch in data_loader.next_question():
93 |         data = data_batch[-1]
94 |         total += 1
95 |         if data.subject not in data.cand_sub:
96 |             continue
97 |         sub_correct += 1
98 |
99 |         pos_score, neg_score = model(data_batch[:-1])
100 |         neg_score = neg_score.data.squeeze().cpu().numpy()
101 |
102 |         if args.write_score:
103 |             rel_scores.append((data.cand_rel, data.relation, neg_score))
104 |
105 |         pred_rel, pred_rel_scores, pred_sub = rel_pruned(neg_score, data)
106 |
107 |         if pred_rel == data.relation:
108 |             n_rel_correct += 1
109 |             if
data.subject in pred_sub: 110 | n_sub_recall += 1 111 | if pred_sub[0] == data.subject: 112 | n_correct += 1 113 | if len(pred_sub) == 1: 114 | n_single_correct += 1 115 | 116 | if args.write_score: 117 | score_file = open(os.path.join(args.results_path, 'score-rel-%s.pkl' %tp), 'wb') 118 | pickle.dump(rel_scores, score_file) 119 | 120 | accuracy = 100. * n_correct / total 121 | rel_acc = 100. * n_rel_correct / sub_correct 122 | sub_recall = 100. * n_sub_recall / sub_correct 123 | single_acc = 100. * n_single_correct / sub_correct 124 | print("%s\taccuracy: %8.6f\tcorrect: %d\ttotal: %d" %(tp, accuracy, n_correct, total)) 125 | print('rel_acc: ', rel_acc, n_rel_correct, sub_correct) 126 | print('recall: ', sub_recall, n_sub_recall, sub_correct) 127 | print('single_acc: ', single_acc, n_single_correct, sub_correct) 128 | print("-" * 80) 129 | 130 | if args.predict: 131 | qa_pattern_file = '../entity_detection/results/QAData.label.%s.pkl' 132 | for tp in ('valid', 'test', 'train'): 133 | predict(qa_pattern_file % tp, tp) 134 | else: 135 | evaluate(args.valid_file, "valid") 136 | evaluate(args.test_file, "test") 137 | evaluate(args.train_file, 'train') 138 | -------------------------------------------------------------------------------- /relation_ranking/process.sh: -------------------------------------------------------------------------------- 1 | python train.py \ 2 | --epochs 30 \ 3 | --loss_margin 1 \ 4 | --dropout_prob 0.3 \ 5 | --dev_every 600 \ 6 | --d_rel_embed 256 \ 7 | --d_hidden 128 \ 8 | --n_layers 2 \ 9 | --channel_size 8 \ 10 | --conv_kernel_1 3 \ 11 | --conv_kernel_2 3 \ 12 | --pool_kernel_1 21 \ 13 | --pool_kernel_2 1 \ 14 | --gpu 6 \ 15 | --lr 0.0005 \ 16 | --rnn_type gru 17 | -------------------------------------------------------------------------------- /relation_ranking/seqRankingLoader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*- coding: utf-8 -*- 3 | 4 | # Author: QuYingqi 5 | # mail: cookiequ17@hotmail.com 6 | # Created Time: 2017-12-12 7 | import sys, os 8 | import pickle 9 | import numpy as np 10 | import random 11 | import torch 12 | from torch.autograd import Variable 13 | sys.path.append('../vocab') 14 | sys.path.append('../tools') 15 | import virtuoso 16 | from args import get_args 17 | args = get_args() 18 | 19 | def create_seq_ranking_data(qa_data, word_vocab, rel_sep_vocab, rel_vocab, save_path): 20 | seqs = [] 21 | seq_len = [] 22 | pos_rel1 = [] 23 | pos_rel2 = [] 24 | neg_rel1 = [] 25 | neg_rel2 = [] 26 | pos_rel = [] 27 | neg_rel = [] 28 | pos_rel_len = [] 29 | neg_rel_len = [] 30 | batch_index = -1 # the index of sequence batches 31 | seq_index = 0 # sequence index within each batch 32 | pad_index = word_vocab.lookup(word_vocab.pad_token) 33 | 34 | rel_max_len = args.rel_maxlen 35 | 36 | data_list = pickle.load(open(qa_data, 'rb')) 37 | 38 | def get_separated_rel_id(relation): 39 | rel = relation.split('.') 40 | rel1 = '.'.join(rel[:-1]) 41 | rel2 = rel[-1] 42 | rel_word = [] 43 | rel[0] = rel[0][3:] 44 | for i in rel: 45 | rel_word.extend(i.split('_')) 46 | 47 | rel1_id = rel_sep_vocab[0].convert_to_index([rel1])[0] 48 | rel2_id = rel_sep_vocab[1].convert_to_index([rel2])[0] 49 | rel_id = word_vocab.convert_to_index(rel_word) 50 | return rel1_id, rel2_id, rel_id 51 | 52 | for data in data_list: 53 | tokens = data.question_pattern.split() 54 | can_rels = [] 55 | if hasattr(data, 'cand_sub') and data.subject in data.cand_sub: 56 | can_rels = data.cand_rel 57 | else: 58 | can_subs 
= virtuoso.str_query_id(data.text_subject)
59 |             for sub in can_subs:
60 |                 can_rels.extend(virtuoso.id_query_out_rel(sub))
61 |         can_rels = list(set(can_rels))
62 |         if data.relation in can_rels:
63 |             can_rels.remove(data.relation)
64 |         for i in range(len(can_rels), args.neg_size):
65 |             tmp = random.randint(2, len(rel_vocab)-1)
66 |             while rel_vocab.index2word[tmp] in can_rels:  # compare the sampled relation name, not its index
67 |                 tmp = random.randint(2, len(rel_vocab)-1)
68 |             can_rels.append(rel_vocab.index2word[tmp])
69 |         can_rels = random.sample(can_rels, args.neg_size)
70 |
71 |         if seq_index % args.batch_size == 0:
72 |             seq_index = 0
73 |             batch_index += 1
74 |             seqs.append(torch.LongTensor(args.batch_size, len(tokens)).fill_(pad_index))
75 |             seq_len.append(torch.LongTensor(args.batch_size).fill_(1))
76 |             pos_rel1.append(torch.LongTensor(args.batch_size).fill_(pad_index))
77 |             pos_rel2.append(torch.LongTensor(args.batch_size).fill_(pad_index))
78 |             neg_rel1.append(torch.LongTensor(args.neg_size, args.batch_size).fill_(pad_index))
79 |             neg_rel2.append(torch.LongTensor(args.neg_size, args.batch_size).fill_(pad_index))
80 |             pos_rel.append(torch.LongTensor(args.batch_size, rel_max_len).fill_(pad_index))
81 |             pos_rel_len.append(torch.Tensor(args.batch_size).fill_(1))
82 |             neg_rel.append(torch.LongTensor(args.neg_size, args.batch_size, rel_max_len).fill_(pad_index))
83 |             neg_rel_len.append(torch.Tensor(args.neg_size, args.batch_size).fill_(1))
84 |             print('batch: %d' %batch_index)
85 |
86 |         seqs[batch_index][seq_index, 0:len(tokens)] = torch.LongTensor(word_vocab.convert_to_index(tokens))
87 |         seq_len[batch_index][seq_index] = len(tokens)
88 |
89 |         pos1, pos2, pos = get_separated_rel_id(data.relation)
90 |         pos_rel1[batch_index][seq_index] = pos1
91 |         pos_rel2[batch_index][seq_index] = pos2
92 |         pos_rel[batch_index][seq_index, 0:len(pos)] = torch.LongTensor(pos)
93 |         pos_rel_len[batch_index][seq_index] = len(pos)
94 |
95 |         for j, can_rel in enumerate(can_rels):
96 |             neg1, neg2, neg = get_separated_rel_id(can_rel)
97 |             if not neg1 or not neg2:
98 |                 continue
99 |             neg_rel1[batch_index][j,seq_index] = neg1
100 |             neg_rel2[batch_index][j,seq_index] = neg2
101 |             neg_rel[batch_index][j,seq_index, 0:len(neg)] = torch.LongTensor(neg)
102 |             neg_rel_len[batch_index][j,seq_index] = len(neg)
103 |
104 |         seq_index += 1
105 |
106 |     torch.save((seqs, seq_len, pos_rel1, pos_rel2, neg_rel1, neg_rel2, pos_rel, pos_rel_len, neg_rel, neg_rel_len), save_path)
107 |
108 | class SeqRankingLoader():
109 |     def __init__(self, infile, device=-1):
110 |         self.seqs, self.seq_len, self.pos_rel1, self.pos_rel2, self.neg_rel1, self.neg_rel2, self.pos_rel, self.pos_rel_len, self.neg_rel, self.neg_rel_len = torch.load(infile)
111 |         self.batch_size = self.seqs[0].size(0)
112 |         self.batch_num = len(self.seqs)
113 |
114 |         if device >= 0:
115 |             for i in range(self.batch_num):
116 |                 self.seqs[i] = self.seqs[i].cuda(device)
117 |                 self.pos_rel1[i] = self.pos_rel1[i].cuda(device)
118 |                 self.pos_rel2[i] = self.pos_rel2[i].cuda(device)
119 |                 self.neg_rel1[i] = self.neg_rel1[i].cuda(device)
120 |                 self.neg_rel2[i] = self.neg_rel2[i].cuda(device)
121 |                 self.pos_rel[i] = self.pos_rel[i].cuda(device)
122 |                 self.neg_rel[i] = self.neg_rel[i].cuda(device)
123 |
124 |     def next_batch(self, shuffle = True):
125 |         if shuffle:
126 |             indices = torch.randperm(self.batch_num)
127 |         else:
128 |             indices = range(self.batch_num)
129 |         for i in indices:
130 |             yield Variable(self.seqs[i]), self.seq_len[i], Variable(self.pos_rel1[i]), \
131 |                   Variable(self.pos_rel2[i]), Variable(self.neg_rel1[i]), Variable(self.neg_rel2[i]), \
132 |
Variable(self.pos_rel[i]), self.pos_rel_len[i], Variable(self.neg_rel[i]),\ 133 | self.neg_rel_len[i] 134 | 135 | class CandidateRankingLoader(): 136 | def __init__(self, qa_pattern_file, word_vocab, rel_sep_vocab, device=-1): 137 | self.qa_pattern = pickle.load(open(qa_pattern_file, 'rb')) 138 | self.batch_num = len(self.qa_pattern) 139 | self.word_vocab = word_vocab 140 | self.rel_sep_vocab = rel_sep_vocab 141 | self.pad_index = word_vocab.lookup(word_vocab.pad_token) 142 | self.device = device 143 | 144 | def get_separated_rel_id(self, relation): 145 | rel = relation.split('.') 146 | rel1 = '.'.join(rel[:-1]) 147 | rel2 = rel[-1] 148 | rel_word = [] 149 | rel[0] = rel[0][3:] 150 | for i in rel: 151 | rel_word.extend(i.split('_')) 152 | 153 | rel1_id = self.rel_sep_vocab[0].convert_to_index([rel1])[0] 154 | rel2_id = self.rel_sep_vocab[1].convert_to_index([rel2])[0] 155 | rel_id = self.word_vocab.convert_to_index(rel_word) 156 | return rel1_id, rel2_id, rel_id 157 | 158 | def next_question(self): 159 | for data in self.qa_pattern: 160 | if not hasattr(data, 'cand_rel'): 161 | self.batch_num -= 1 162 | continue 163 | 164 | tokens = data.question_pattern.split() 165 | seqs = torch.LongTensor(self.word_vocab.convert_to_index(tokens)).unsqueeze(0) 166 | seq_len = torch.LongTensor([len(tokens)]) 167 | 168 | pos1, pos2, pos = self.get_separated_rel_id(data.relation) 169 | pos_rel1 = torch.LongTensor([pos1]) 170 | pos_rel2 = torch.LongTensor([pos2]) 171 | pos_rel = torch.LongTensor(args.rel_maxlen).fill_(self.pad_index) 172 | pos_rel[0:len(pos)] = torch.LongTensor(pos) 173 | pos_rel = pos_rel.unsqueeze(0) 174 | pos_len = torch.LongTensor([len(pos)]) 175 | 176 | neg_rel1 = torch.LongTensor(len(data.cand_rel)) 177 | neg_rel2 = torch.LongTensor(len(data.cand_rel)) 178 | neg_rel = torch.LongTensor(len(data.cand_rel), args.rel_maxlen).fill_(self.pad_index) 179 | neg_len = torch.LongTensor(len(data.cand_rel)) 180 | for idx, rel in enumerate(data.cand_rel): 181 | neg1, neg2, neg = self.get_separated_rel_id(rel) 182 | neg_rel1[idx] = neg1 183 | neg_rel2[idx] = neg2 184 | neg_rel[idx, 0:len(neg)] = torch.LongTensor(neg) 185 | neg_len[idx] = len(neg) 186 | neg_rel1.unsqueeze_(1) 187 | neg_rel2.unsqueeze_(1) 188 | neg_rel.unsqueeze_(1) 189 | 190 | if self.device>=0: 191 | seqs, pos_rel1, pos_rel2, neg_rel1, neg_rel2, pos_rel, neg_rel = \ 192 | seqs.cuda(self.device), pos_rel1.cuda(self.device), pos_rel2.cuda(self.device), \ 193 | neg_rel1.cuda(self.device), neg_rel2.cuda(self.device), pos_rel.cuda(self.device), \ 194 | neg_rel.cuda(self.device) 195 | yield Variable(seqs), seq_len, Variable(pos_rel1), Variable(pos_rel2), Variable(neg_rel1), Variable(neg_rel2), Variable(pos_rel), pos_len, Variable(neg_rel), neg_len, data 196 | 197 | if __name__ == '__main__': 198 | word_vocab = torch.load(args.vocab_file) 199 | rel_sep_vocab = torch.load(args.rel_vocab_file) 200 | rel_vocab = torch.load('../vocab/vocab.rel.pt') 201 | 202 | qa_data_path = '../entity_detection/results/QAData.label.%s.pkl' 203 | if not os.path.exists('data'): 204 | os.mkdir('data') 205 | 206 | create_seq_ranking_data(qa_data_path % 'valid', word_vocab, rel_sep_vocab, rel_vocab, args.valid_file) 207 | create_seq_ranking_data(qa_data_path % 'test', word_vocab, rel_sep_vocab, rel_vocab, args.test_file) 208 | create_seq_ranking_data(qa_data_path % 'train', word_vocab, rel_sep_vocab, rel_vocab, args.train_file) 209 | -------------------------------------------------------------------------------- /relation_ranking/train.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*- coding: utf-8 -*- 3 | 4 | # Author: QuYingqi 5 | # mail: cookiequ17@hotmail.com 6 | # Created Time: 2017-11-09 7 | import torch 8 | import torch.optim as optim 9 | import torch.nn as nn 10 | import time 11 | import os, sys, glob 12 | import numpy as np 13 | 14 | from args import get_args 15 | from model import RelationRanking 16 | from seqRankingLoader import SeqRankingLoader 17 | 18 | # please set the configuration in the file : args.py 19 | args = get_args() 20 | # set the random seed for reproducibility 21 | torch.manual_seed(args.seed) 22 | if not args.cuda: 23 | args.gpu = -1 24 | if torch.cuda.is_available() and args.cuda: 25 | print("Note: You are using GPU for training") 26 | torch.cuda.set_device(args.gpu) 27 | torch.cuda.manual_seed(args.seed) 28 | if torch.cuda.is_available() and not args.cuda: 29 | print("Warning: You have Cuda but do not use it. You are using CPU for training") 30 | 31 | # load word vocab for questions, relation vocab for relations 32 | word_vocab = torch.load(args.vocab_file) 33 | print('load word vocab, size: %s' % len(word_vocab)) 34 | rel_vocab = torch.load(args.rel_vocab_file) 35 | print('load relation vocab, size: %s' %len(rel_vocab)) 36 | 37 | # load data 38 | train_loader = SeqRankingLoader(args.train_file, args.gpu) 39 | print('load train data, batch_num: %d\tbatch_size: %d' 40 | %(train_loader.batch_num, train_loader.batch_size)) 41 | valid_loader = SeqRankingLoader(args.valid_file, args.gpu) 42 | print('load valid data, batch_num: %d\tbatch_size: %d' 43 | %(valid_loader.batch_num, valid_loader.batch_size)) 44 | 45 | os.makedirs(args.save_path, exist_ok=True) 46 | 47 | # define models 48 | config = args 49 | config.n_cells = config.n_layers 50 | 51 | if config.birnn: 52 | config.n_cells *= 2 53 | print(config) 54 | with open(os.path.join(config.save_path, 'param.log'), 'w') as f: 55 | f.write(str(config)) 56 | 57 | if args.resume_snapshot: 58 | model = torch.load(args.resume_snapshot, map_location=lambda storage, location: storage) 59 | else: 60 | model = RelationRanking(word_vocab, rel_vocab, config) 61 | if args.word_vectors: 62 | if os.path.isfile(args.vector_cache): 63 | pretrained = torch.load(args.vector_cache) 64 | model.word_embed.word_lookup_table.weight.data.copy_(pretrained) 65 | else: 66 | pretrained = model.word_embed.load_pretrained_vectors(args.word_vectors, binary=False, 67 | normalize=args.word_normalize) 68 | torch.save(pretrained, args.vector_cache) 69 | print('load pretrained word vectors from %s, pretrained size: %s' %(args.word_vectors, 70 | pretrained.size())) 71 | if args.cuda: 72 | model.cuda() 73 | print("Shift model to GPU") 74 | 75 | # show model parameters 76 | for name, param in model.named_parameters(): 77 | print(name, param.size()) 78 | 79 | criterion = nn.MarginRankingLoss(args.loss_margin) # Max margin ranking loss function 80 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 81 | 82 | # train the model 83 | iterations = 0 84 | start = time.time() 85 | best_dev_acc = 0 86 | best_dev_F = 0 87 | num_iters_in_epoch = train_loader.batch_num 88 | patience = args.patience * num_iters_in_epoch # for early stopping 89 | iters_not_improved = 0 # this parameter is used for stopping early 90 | early_stop = False 91 | header = ' Time Epoch Iteration Progress (%Epoch) Loss Accuracy Dev/Accuracy' 92 | dev_log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} 
{:>7.0f}%,{:>8.6f},{:12.4f},{:12.4f}'.split(',')) 93 | log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:12.4f},{}'.split(',')) 94 | best_snapshot_prefix = os.path.join(args.save_path, 'best_snapshot') 95 | print(header) 96 | 97 | for epoch in range(1, args.epochs+1): 98 | if early_stop: 99 | print("Early stopping. Epoch: {}, Best Dev. Acc: {}".format(epoch, best_dev_acc)) 100 | break 101 | 102 | n_correct, n_total = 0, 0 103 | 104 | for batch_idx, batch in enumerate(train_loader.next_batch()): 105 | iterations += 1 106 | model.train(); 107 | optimizer.zero_grad() 108 | 109 | pos_score, neg_score = model(batch) 110 | 111 | n_correct += (torch.sum(torch.gt(pos_score, neg_score), 0).data == neg_score.size(0)).sum() 112 | n_total += pos_score.size(1) 113 | train_acc = 100. * n_correct / n_total 114 | 115 | ones = torch.autograd.Variable(torch.ones(pos_score.size(0)*pos_score.size(1))) 116 | if args.cuda: 117 | ones = ones.cuda() 118 | loss = criterion(pos_score.contiguous().view(-1,1).squeeze(1), neg_score.contiguous().view(-1,1).squeeze(1), ones) 119 | loss.backward() 120 | 121 | # clip the gradient 122 | torch.nn.utils.clip_grad_norm(model.parameters(), args.clip_gradient) 123 | optimizer.step() 124 | 125 | # checkpoint model periodically 126 | if iterations % args.save_every == 0: 127 | snapshot_prefix = os.path.join(args.save_path, 'snapshot') 128 | snapshot_path = snapshot_prefix + \ 129 | '_iter_{}_acc_{:.4f}_loss_{:.6f}_model.pt'.format(iterations, train_acc, loss.data[0]) 130 | torch.save(model, snapshot_path) 131 | for f in glob.glob(snapshot_prefix + '*'): 132 | if f != snapshot_path: 133 | os.remove(f) 134 | 135 | # evaluate performance on validation set periodically 136 | if iterations % args.dev_every == 0: 137 | model.eval() 138 | n_dev_correct = 0 139 | valid_total = 0 140 | 141 | gold_list = [] 142 | pred_list = [] 143 | 144 | for valid_batch_idx, valid_batch in enumerate(valid_loader.next_batch(False)): 145 | val_ps, val_ns = model(valid_batch) 146 | val_neg_size, val_batch_size = val_ps.size() 147 | 148 | n_dev_correct += (torch.sum(torch.gt(val_ps, val_ns), 0).data == val_neg_size).sum() 149 | valid_total += val_batch_size 150 | 151 | dev_acc = 100. * n_dev_correct / valid_total 152 | print(dev_log_template.format(time.time() - start, epoch, iterations, 153 | 1 + batch_idx, train_loader.batch_num, 154 | 100. * (1 + batch_idx) / train_loader.batch_num, 155 | loss.data[0], train_acc, dev_acc)) 156 | # print("{} Precision: {:10.6f}% Recall: {:10.6f}% F1 Score: {:10.6f}%".format("Dev", 100. * P, 100. * R, 100. * F)) 157 | # update model 158 | if dev_acc > best_dev_acc: 159 | best_dev_acc = dev_acc 160 | iters_not_improved = 0 161 | snapshot_path = best_snapshot_prefix + \ 162 | '_iter_{}_devf1_{}_model.pt'.format(iterations, best_dev_acc) 163 | 164 | # save model, delete previous 'best_snapshot' files 165 | torch.save(model, snapshot_path) 166 | for f in glob.glob(best_snapshot_prefix + '*'): 167 | if f != snapshot_path: 168 | os.remove(f) 169 | 170 | else: 171 | iters_not_improved += 1 172 | if iters_not_improved > patience: 173 | early_stop = True 174 | break 175 | 176 | # print progress message 177 | elif iterations % args.log_every == 0: 178 | print(log_template.format(time.time()-start, epoch, iterations, 1+batch_idx, 179 | train_loader.batch_num, 100. 
* (1+batch_idx)/train_loader.batch_num, 180 | loss.data[0], train_acc, ' '*12)) 181 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*- coding: utf-8 -*- 3 | 4 | # Author: QuYingqi 5 | # mail: cookiequ17@hotmail.com 6 | # Created Time: 2017-11-06 7 | -------------------------------------------------------------------------------- /tools/embedding.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Mostly from https://github.com/OpenNMT/OpenNMT-py 4 | # Most Code in load/save word2vec format refer to Gensim 5 | import torch 6 | import torch.nn as nn 7 | from utils import load_word2vec_format, aeq 8 | 9 | class Embeddings(nn.Module): 10 | 11 | def __init__(self, 12 | word_vec_size, 13 | dicts, 14 | feat_merge='concat', 15 | feat_vec_exponent=0.7, 16 | feature_dicts=None, 17 | feature_dims=None): 18 | """ 19 | :param word_vec_size: Word Embedding Size 20 | :param dicts: Word Dict 21 | :param feat_merge: Merge action for the features embeddings. 22 | :param feat_vec_exponent: 23 | When features embedding sizes are not set and using -feat_merge concat, 24 | their dimension will be set to N^feat_vec_exponent where N is the number 25 | of values the feature takes 26 | :param feature_dicts: 27 | """ 28 | super(Embeddings, self).__init__() 29 | 30 | self.word_dict = dicts 31 | self.word_vec_size = word_vec_size 32 | self.feat_exp = feat_vec_exponent 33 | self.feat_merge = feat_merge 34 | 35 | # vocab_sizes: sequence of vocab sizes for words and each feature 36 | vocab_sizes = [self.word_dict.size()] 37 | 38 | # emb_sizes 39 | emb_sizes = [self.word_vec_size] 40 | if feature_dicts is not None and len(feature_dicts) > 0: 41 | vocab_sizes.extend(feat_dict.size() for feat_dict in feature_dicts) 42 | if self.feat_merge == 'concat': 43 | # Derive embedding sizes from each feature's vocab size 44 | emb_sizes.extend([int(feature_dim) for feature_dim in feature_dims]) 45 | elif self.feat_merge == 'sum': 46 | # All embeddings to be summed must be the same size 47 | emb_sizes.extend(feature_dims) 48 | else: 49 | # TODO MLP 50 | raise NotImplementedError 51 | 52 | # Embedding Lookup Tables 53 | # [word_embedd, ... 54 | # other embedding if has] 55 | self.emb_luts = nn.ModuleList([ 56 | nn.Embedding(vocab, dim, padding_idx=self.word_dict.lookup(self.word_dict.pad_token)) 57 | for vocab, dim in zip(vocab_sizes, emb_sizes)]) 58 | 59 | self.init_model() 60 | 61 | self.output_size = self.embedding_size() 62 | 63 | def embedding_size(self): 64 | """ 65 | Returns sum of all feature dimensions if the merge action is concat. 66 | Otherwise, returns word vector size. 
67 | """ 68 | if self.feat_merge == 'concat': 69 | return sum(emb_lut.embedding_dim 70 | for emb_lut in self.emb_luts.children()) 71 | else: 72 | return self.word_lookup_table.embedding_dim 73 | 74 | @property 75 | def word_lookup_table(self): 76 | return self.emb_luts[0] 77 | 78 | def init_model(self): 79 | for emb_table in self.emb_luts: 80 | emb_table.weight.data.normal_(0, 0.1) 81 | 82 | def load_pretrained_vectors(self, emb_file, binary=True, normalize=False): 83 | if emb_file is not None: 84 | pretrained, vec_size, vocab = load_word2vec_format(emb_file, self.word_dict.word2index, 85 | binary=binary, normalize=normalize) 86 | 87 | # Init Out-of-PreTrain Wordembedding using Min,Max Uniform 88 | scale = torch.std(pretrained) 89 | # random_range = (torch.min(pretrained), torch.max(pretrained)) 90 | random_range = (-scale, scale) 91 | random_init_count = 0 92 | for word in self.word_dict: 93 | 94 | if word not in vocab: 95 | random_init_count += 1 96 | nn.init.uniform(pretrained[self.word_dict.lookup(word)], 97 | random_range[0], random_range[1]) 98 | 99 | self.word_lookup_table.weight.data.copy_(pretrained) 100 | print("Init %s words in uniform [%s, %s]" % (random_init_count, random_range[0], random_range[1])) 101 | return pretrained 102 | 103 | def merge(self, features): 104 | if self.feat_merge == 'concat': 105 | return torch.cat(features, 2) 106 | elif self.feat_merge == 'sum': 107 | return sum(features) 108 | else: 109 | return self.mlp(torch.cat(features, 2)) 110 | 111 | def forward(self, inp): 112 | """ 113 | Return the embeddings for words, and features if there are any. 114 | Args: 115 | inp (LongTensor): batch x len x nfeat 116 | Return: 117 | emb (Tensor): batch x len x self.embedding_size 118 | """ 119 | if inp.dim() == 2: 120 | # batch x len 121 | emb = self.word_lookup_table(inp) 122 | return emb 123 | 124 | in_batch, in_length, nfeat = inp.size() 125 | aeq(nfeat, len(self.emb_luts)) 126 | 127 | if len(self.emb_luts) == 1: 128 | emb = self.word_lookup_table(inp.squeeze(2)) 129 | else: 130 | feat_inputs = (feat.squeeze(2) 131 | for feat in inp.split(1, dim=2)) 132 | features = [lut(feat) 133 | for lut, feat in zip(self.emb_luts, feat_inputs)] 134 | emb = self.merge(features) 135 | 136 | out_batch, out_length, emb_size = emb.size() 137 | aeq(in_batch, out_batch) 138 | aeq(in_length, out_length) 139 | aeq(emb_size, self.embedding_size()) 140 | 141 | return emb 142 | -------------------------------------------------------------------------------- /tools/qa_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn import preprocessing 3 | 4 | def fb2www(in_data): 5 | if type(in_data) == type(' '): 6 | out_data = in_data.replace('.', '/').replace('fb:', 'www.freebase.com/') 7 | elif type(in_data) == type([]): 8 | out_data = [data.replace('.', '/').replace('fb:', 'www.freebase.com/') for data in in_data] 9 | return out_data 10 | 11 | class QAData(object): 12 | """docstring for QAData""" 13 | def __init__(self, data_tuple): 14 | super(QAData, self).__init__() 15 | self.question = data_tuple[0] 16 | self.subject = data_tuple[1] 17 | self.relation = data_tuple[2] 18 | self.object = data_tuple[3] 19 | self.num_text_token = int(data_tuple[4]) 20 | 21 | def add_candidate(self, sub, rels, types = None): 22 | if not hasattr(self, 'cand_sub'): 23 | self.cand_sub = [] 24 | if not hasattr(self, 'cand_rel'): 25 | self.cand_rel = [] 26 | if not hasattr(self, 'sub_rels'): 27 | self.sub_rels = [] 28 | 
self.cand_sub.append(sub) 29 | self.sub_rels.append(rels) 30 | self.cand_rel.extend(rels) 31 | if types: 32 | if not hasattr(self, 'sub_types'): 33 | self.sub_types = [] 34 | self.sub_types.append(types) 35 | 36 | def add_sub_types(self, types): 37 | if not hasattr(self, 'sub_types'): 38 | self.sub_types = [] 39 | self.sub_types.append(types) 40 | 41 | def remove_duplicate(self): 42 | self.cand_rel = list(set(self.cand_rel)) 43 | 44 | def make_score_mat(self): 45 | # make candidate unique rels 46 | self.num_sub = len(self.cand_sub) 47 | self.num_rel = len(self.cand_rel) 48 | self.rel_dict = {self.cand_rel[i]:i for i in range(self.num_rel)} 49 | 50 | # establish score matrix 51 | self.score_mat = np.zeros((self.num_sub, self.num_rel)) 52 | for i in range(self.num_sub): 53 | for rel in self.sub_rels[i]: 54 | self.score_mat[i, self.rel_dict[rel]] = 1 55 | 56 | def fill_rel_score(self, scores): 57 | self.score_mat = self.score_mat * scores 58 | 59 | def fill_ent_score(self, scores): 60 | self.ent_score = preprocessing.scale(scores) 61 | 62 | def top_sub_rel(self): 63 | sub_score = np.sum(self.score_mat, 1) 64 | top_subscore = np.max(sub_score) 65 | top_subids = [] 66 | for subid in np.argsort(sub_score)[::-1]: 67 | if sub_score[subid] < top_subscore: 68 | break 69 | top_subids.append(subid) 70 | 71 | top_relid = np.argmax(self.score_mat[top_subids[0]]) 72 | 73 | return [self.cand_sub[subid] for subid in top_subids], self.cand_rel[top_relid] 74 | -------------------------------------------------------------------------------- /tools/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import sys 4 | from builtins import range 5 | import numpy as np 6 | import torch 7 | REAL = np.float32 8 | if sys.version_info[0] >= 3: 9 | unicode = str 10 | 11 | 12 | def to_unicode(text, encoding='utf8', errors='strict'): 13 | """Convert a string (bytestring in `encoding` or unicode), to unicode. 14 | :param text: 15 | :param encoding: 16 | :param errors: errors can be 'strict', 'replace' or 'ignore' and defaults to 'strict'. 17 | """ 18 | if isinstance(text, unicode): 19 | return text 20 | return unicode(text, encoding, errors=errors) 21 | 22 | 23 | def any2utf8(text, encoding='utf8', errors='strict'): 24 | """Convert a string (unicode or bytestring in `encoding`), to bytestring in utf8.""" 25 | if isinstance(text, unicode): 26 | return text.encode('utf8') 27 | # do bytestring -> unicode -> utf8 full circle, to ensure valid utf8 28 | return unicode(text, encoding, errors=errors).encode('utf8') 29 | 30 | 31 | def aeq(*args): 32 | base = args[0] 33 | for a in args[1:]: 34 | assert a == base, str(args) 35 | 36 | 37 | 38 | def load_word2vec_format(filename, word_idx, binary=False, normalize=False, 39 | encoding='utf8', unicode_errors='ignore'): 40 | """ 41 | refer to gensim 42 | load Word Embeddings 43 | If you trained the C model using non-utf8 encoding for words, specify that 44 | encoding in `encoding`. 45 | :param filename : 46 | :param word_idx : 47 | :param binary : a boolean indicating whether the data is in binary word2vec format. 48 | :param normalize: 49 | :param encoding : 50 | :param unicode_errors: errors can be 'strict', 'replace' or 'ignore' and defaults to 'strict'. 
51 | """ 52 | vocab = set() 53 | print("loading word embedding from %s" % filename) 54 | with open(filename, 'rb') as fin: 55 | # header = to_unicode(fin.readline(), encoding=encoding) 56 | # vocab_size, vector_size = map(int, header.split()) # throws for invalid file format 57 | vocab_size = 1917494 58 | vector_size = 300 59 | word_matrix = torch.zeros(len(word_idx), vector_size) 60 | 61 | def add_word(_word, _weights): 62 | if _word not in word_idx: 63 | return 64 | vocab.add(_word) 65 | word_matrix[word_idx[_word]] = _weights 66 | 67 | if binary: 68 | binary_len = np.dtype(np.float32).itemsize * vector_size 69 | for _ in range(vocab_size): 70 | # mixed text and binary: read text first, then binary 71 | word = [] 72 | while True: 73 | ch = fin.read(1) 74 | if ch == b' ': 75 | break 76 | if ch != b'\n': # ignore newlines in front of words (some binary files have) 77 | word.append(ch) 78 | word = to_unicode(b''.join(word), encoding=encoding, errors=unicode_errors) 79 | weights = torch.from_numpy(np.fromstring(fin.read(binary_len), dtype=REAL)) 80 | add_word(word, weights) 81 | else: 82 | for line_no, line in enumerate(fin): 83 | parts = to_unicode(line.rstrip(), encoding=encoding, errors=unicode_errors).split(" ") 84 | if len(parts) != vector_size + 1: 85 | raise ValueError("invalid vector on line %s (is this really the text format?)" % line_no) 86 | word, weights = parts[0], list(map(float, parts[1:])) 87 | weights = torch.Tensor(weights) 88 | add_word(word, weights) 89 | if word_idx is not None: 90 | assert (len(word_idx), vector_size) == word_matrix.size() 91 | if normalize: 92 | # each row normalize to 1 93 | word_matrix = torch.renorm(word_matrix, 2, 0, 1) 94 | print("loaded %d words pre-trained from %s with %d" % (len(vocab), filename, vector_size)) 95 | return word_matrix, vector_size, vocab 96 | 97 | 98 | def clip_weight_norm(model, max_norm, norm_type=2, except_params=None): 99 | """Clips gradient norm of an iterable of parameters. 100 | 101 | The norm is computed over all gradients together, as if they were 102 | concatenated into a single vector. Gradients are modified in-place. 103 | 104 | Arguments: 105 | parameters (Iterable[Variable]): an iterable of Variables that will have 106 | gradients normalized 107 | max_norm (float or int): max norm of the gradients 108 | norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for 109 | infinity norm. 110 | 111 | Returns: 112 | Total norm of the parameters (viewed as a single vector). 
113 |     """
114 |     for name, param in model.named_parameters():
115 |         if except_params is not None:
116 |             for except_param in except_params:
117 |                 if except_param in name:
118 |                     # print "Pass", name
119 |                     pass
120 |
121 |         if len(param.size()) == 2:
122 |
123 |             if name == 'out.linear.weight':
124 |                 row_norm = torch.norm(param.data, norm_type, 1)
125 |                 desired_norm = torch.clamp(row_norm, 0, np.sqrt(max_norm))
126 |                 scale = desired_norm / (row_norm + 1e-7)
127 |                 param.data = scale[:, None] * param.data
128 |                 # print "Row Norm", torch.norm(param.data, norm_type, 1)
129 |             else:
130 |                 col_norm = torch.norm(param.data, norm_type, 0)
131 |                 desired_norm = torch.clamp(col_norm, 0, np.sqrt(max_norm))
132 |                 scale = desired_norm / (col_norm + 1e-7)
133 |                 param.data *= scale
134 |                 # print "Col Norm", torch.norm(param.data, norm_type, 0)
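load_word2vec_format above fills one row per entry of word_idx and leaves the remaining rows zero-initialized; a minimal sketch of a call (toy vocabulary; assumes the text-format GloVe file referenced in args.py has been downloaded):

    from utils import load_word2vec_format

    word_idx = {'what': 0, 'is': 1, 'the': 2}
    matrix, dim, found = load_word2vec_format('../vocab/glove.42B.300d.txt',
                                              word_idx, binary=False)
    print(matrix.size())    # torch.Size([3, 300]) - one row per word in word_idx
    print(dim, len(found))  # 300, and how many of the three words had a pretrained vector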
--------------------------------------------------------------------------------
/tools/virtuoso.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | import sys, json
3 | from urllib import parse, request
4 |
5 | # Setting global variables
6 | data_source = 'fb2m:'
7 | query_url = 'http://localhost:8890/sparql/'
8 |
9 | # HTTP URL is constructed accordingly with JSON query results format in mind.
10 | def sparql_query(query, URL, format='application/json'):
11 |
12 |     params={
13 |         'default-graph': '',
14 |         'should-sponge': 'soft',
15 |         'query': query.encode('utf8'),
16 |         'debug': 'on',
17 |         'timeout': '',
18 |         'format': format,
19 |         'save': 'display',
20 |         'fname': ''
21 |     }
22 |
23 |     encoded_query = parse.urlencode(params).encode('utf-8')
24 |     http_response = request.urlopen(URL, encoded_query).read()
25 |
26 |     try:
27 |         json_response = json.loads(http_response.decode('utf-8'))
28 |         return json_response
29 |     except:
30 |         print('json load error', file=sys.stderr)
31 |         print(http_response, file=sys.stderr)
32 |         return None
33 |
34 | # Using freebase mid to query its types
35 | def id_query_type(node_id):
36 |     query = '''
37 |         SELECT ?type FROM <%s> WHERE {<%s> ?type}
38 |     ''' % (data_source, node_id)
39 |     json_response = sparql_query(query, query_url)
40 |
41 |     try:
42 |         type_list = [item['type']['value'] for item in json_response['results']['bindings']]
43 |         return list(set(type_list))
44 |     except:
45 |         return []
46 |
47 | # Using freebase mid to query its original cased name
48 | def id_query_en_name(node_id):
49 |     query = '''
50 |         SELECT ?name FROM <%s> WHERE {<%s> ?name}
51 |     ''' % (data_source, node_id)
52 |     json_response = sparql_query(query, query_url)
53 |
54 |     try:
55 |         name_list = [item['name']['value'] for item in json_response['results']['bindings']]
56 |         return list(set(name_list))
57 |     except:
58 |         return []
59 |
60 | # Using freebase mid to query its original cased alias
61 | def id_query_en_alias(node_id):
62 |     query = '''
63 |         SELECT ?alias FROM <%s> WHERE {<%s> ?alias}
64 |     ''' % (data_source, node_id)
65 |     json_response = sparql_query(query, query_url)
66 |
67 |     try:
68 |         alias_list = [item['alias']['value'] for item in json_response['results']['bindings']]
69 |         return list(set(alias_list))
70 |     except:
71 |         return []
72 |
73 | # Using freebase mid to query its processed & tokenized name
74 | def id_query_name(node_id):
75 |     query = '''
76 |         SELECT ?name FROM <%s> WHERE {<%s> ?name}
77 |     ''' % (data_source, node_id)
78 |     json_response = sparql_query(query, query_url)
79 |
80 |     try:
81 |         name_list = [item['name']['value'] for item in json_response['results']['bindings']]
82 |         return list(set(name_list))
83 |     except:
84 |         return []
85 |
86 | # Using freebase mid to query its processed & tokenized alias
87 | def id_query_alias(node_id):
88 |     query = '''
89 |         SELECT ?alias FROM <%s> WHERE {<%s> ?alias}
90 |     ''' % (data_source, node_id)
91 |     json_response = sparql_query(query, query_url)
92 |
93 |     try:
94 |         alias_list = [item['alias']['value'] for item in json_response['results']['bindings']]
95 |         return list(set(alias_list))
96 |     except:
97 |         return []
98 |
99 | # Using freebase mid to query its processed & tokenized name & alias
100 | def id_query_str(node_id):
101 |     query = '''
102 |         SELECT ?str FROM <%s> WHERE { {<%s> ?str} UNION {<%s> ?str} }
103 |     ''' % (data_source, node_id, node_id)
104 |     json_response = sparql_query(query, query_url)
105 |
106 |     try:
107 |         name_list = [item['str']['value'] for item in json_response['results']['bindings']]
108 |         return list(set(name_list))
109 |     except:
110 |         return []
111 |
112 | # Using freebase mid to query all relations coming out of the entity
113 | def id_query_out_rel(node_id, unique = True):
114 |     query = '''
115 |         SELECT ?relation FROM <%s> WHERE {<%s> ?relation ?object}
116 |     ''' % (data_source, node_id)
117 |     json_response = sparql_query(query, query_url)
118 |
119 |     try:
120 |         relations = [str(item['relation']['value']) for item in json_response['results']['bindings']]
121 |         return list(set(relations))
122 |     except:
123 |         return []
124 |
125 | # Using freebase mid to query all relations coming into the entity
126 | def id_query_in_rel(node_id, unique = True):
127 |     query = '''
128 |         SELECT ?relation FROM <%s> WHERE {?subject ?relation <%s>}
129 |     ''' % (data_source, node_id)
130 |     json_response = sparql_query(query, query_url)
131 |
132 |     try:
133 |         relations = [str(item['relation']['value']) for item in json_response['results']['bindings']]
134 |         return list(set(relations))
135 |     except:
136 |         return []
137 |
138 | # Using the name of an entity to query its freebase mid
139 | def name_query_id(name):
140 |     query = '''
141 |         SELECT ?node_id FROM <%s> WHERE {?node_id "%s"}
142 |     ''' % (data_source, name)
143 |     json_response = sparql_query(query, query_url)
144 |
145 |     try:
146 |         node_id_list = [str(item['node_id']['value']) for item in json_response['results']['bindings']]
147 |         return list(set(node_id_list))
148 |     except:
149 |         return []
150 |
151 | # Using the alias of an entity to query its freebase mid
152 | def alias_query_id(alias):
153 |     query = '''
154 |         SELECT ?node_id FROM <%s> WHERE {?node_id "%s"}
155 |     ''' % (data_source, alias)
156 |     json_response = sparql_query(query, query_url)
157 |
158 |     try:
159 |         node_id_list = [str(item['node_id']['value']) for item in json_response['results']['bindings']]
160 |         return list(set(node_id_list))
161 |     except:
162 |         return []
163 |
164 | # Using the alias/name of an entity to query its freebase mid
165 | def str_query_id(string):
166 |     query = '''
167 |         SELECT ?node_id FROM <%s> WHERE { {?node_id "%s"} UNION {?node_id "%s"} }
168 |     ''' % (data_source, string, string)
169 |     json_response = sparql_query(query, query_url)
170 |
171 |     try:
172 |         node_id_list = [str(item['node_id']['value']) for item in json_response['results']['bindings']]
173 |         return list(set(node_id_list))
174 |     except:
175 |         return []
176 |
177 | # Using freebase mid to query all subjects pointing into the entity
178 | def id_query_in_entity(node_id, unique = True):
179 |     query = '''
180 |         SELECT ?subject FROM <%s> WHERE {?subject ?relation <%s>}
181 |     ''' % (data_source, node_id)
182 |     json_response = sparql_query(query, query_url)
183 |
184 |     try:
185 |         subjects = [str(item['subject']['value']) for item in json_response['results']['bindings']]
186 |         return list(set(subjects))
187 |     except:
188 |         return []
189 |
190 | # Using freebase mid to query all objects coming out of the entity
191 | def id_query_out_entity(node_id, unique = True):
192 |     query = '''
193 |         SELECT ?object FROM <%s> WHERE {<%s> ?relation ?object}
194 |     ''' % (data_source, node_id)
195 |     json_response = sparql_query(query, query_url)
196 |
197 |     try:
198 |         objects = [str(item['object']['value']) for item in json_response['results']['bindings']]
199 |         return list(set(objects))
200 |     except:
201 |         return []
202 |
203 | # Using the subject and relation to query the corresponding object
204 | def query_object(subject, relation):
205 |     query = '''
206 |         SELECT ?object FROM <%s> WHERE {<%s> <%s> ?object}
207 |     ''' % (data_source, subject, relation)
208 |     json_response = sparql_query(query, query_url)
209 |
210 |     try:
211 |         return [str(item['object']['value']) for item in json_response['results']['bindings']]
212 |     except:
213 |         return []
214 |
215 | # Using the object and relation to query the corresponding subject
216 | def query_subject(obj, relation):
217 |     query = '''
218 |         SELECT ?subject FROM <%s> WHERE {?subject <%s> <%s>}
219 |     ''' % (data_source, relation, obj)
220 |     json_response = sparql_query(query, query_url)
221 |
222 |     try:
223 |         return [str(item['subject']['value']) for item in json_response['results']['bindings']]
224 |     except:
225 |         return []
226 |
227 | # Using the subject and object to query the corresponding relation
228 | def query_relation(sub, obj):
229 |     query = '''
230 |         SELECT ?relation FROM <%s> WHERE {<%s> ?relation <%s>}
231 |     ''' % (data_source, sub, obj)
232 |     json_response = sparql_query(query, query_url)
233 |
234 |     try:
235 |         objects = [str(item['relation']['value']) for item in json_response['results']['bindings']]
236 |         return list(set(objects))
237 |     except:
238 |         return []
239 |
240 | def relation_query_subject(relation):
241 |     query = '''
242 |         SELECT ?subject FROM <%s> WHERE {?subject <%s> ?object}
243 |     '''% (data_source, relation)
244 |     json_response = sparql_query(query, query_url)
245 |
246 |     try:
247 |         return [str(item['subject']['value']) for item in json_response['results']['bindings']]
248 |     except:
249 |         return []
250 |
251 | # Check whether a node is a CVT node
252 | def check_cvt(node_id):
253 |     query = '''
254 |         SELECT ?tag FROM <%s> WHERE {<%s> ?tag}
255 |     ''' % (data_source, node_id)
256 |     json_response = sparql_query(query, query_url)
257 |     ret = [str(item['tag']['value']) for item in json_response['results']['bindings']]
258 |
259 |     if len(ret) == 1 and ret[0] == 'true':
260 |         return True
261 |     else:
262 |         return False
263 |
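These helpers all return plain Python lists and fall back to an empty list when the endpoint or the JSON parse fails. Candidate generation in seqRankingLoader.py chains them roughly like this (a minimal sketch with a hypothetical surface form; it assumes the local Virtuoso endpoint configured above is running with the FB2M graph loaded):

    import virtuoso

    # mids whose name or alias matches the detected subject string ...
    for sub in virtuoso.str_query_id('new york'):
        # ... and every outgoing relation of each mid becomes a candidate
        print(sub, virtuoso.id_query_out_rel(sub))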
12 | sys.path.append('../tools')
13 | from qa_data import QAData   # qa_data must be importable so pickle can restore the QAData.*.pkl objects
14 | 
15 | def load_word_dictionary(filename, word_dict=None):
16 |     if word_dict is None:
17 |         word_dict = Dictionary()
18 |         word_dict.add_unk_token()
19 |         word_dict.add_pad_token()
20 |     with open(filename) as f:
21 |         for line in f:
22 |             if not line: break
23 |             line = line.strip()
24 |             if not line: continue
25 |             word_dict.add(line)
26 |     return word_dict
27 | 
28 | def load_rel_separated_dictionary(filename):
29 |     rel1_dict = Dictionary()
30 |     rel1_dict.add_unk_token()
31 |     rel1_dict.add_pad_token()
32 |     rel2_dict = Dictionary()
33 |     rel2_dict.add_unk_token()
34 |     rel2_dict.add_pad_token()
35 |     with open(filename) as f:
36 |         for line in f:
37 |             if not line: break
38 |             line = line.strip()
39 |             if not line: continue
40 |             parts = line.split('.')
41 |             rel1 = '.'.join(parts[:-1])   # everything before the last '.'
42 |             rel2 = parts[-1]              # the final path component
43 |             rel1_dict.add(rel1)
44 |             rel2_dict.add(rel2)
45 |     return rel1_dict, rel2_dict
46 | 
47 | def create_word_rel_dict(r_file, *q_files):
48 |     word_dict = Dictionary()
49 |     word_dict.add_unk_token()
50 |     word_dict.add_pad_token()
51 |     word_dict.add_start_token()
52 | 
53 |     for q_file in q_files:
54 |         qa_data = pickle.load(open(q_file, 'rb'))
55 |         for data in qa_data:
56 |             q = data.question
57 |             tokens = q.split(' ')
58 |             for token in tokens:
59 |                 word_dict.add(token)
60 |     print('word vocab size after questions: %d' % len(word_dict))
61 | 
62 |     rels = pickle.load(open(r_file, 'rb'))
63 |     for rel in rels:
64 |         rel_word = []
65 |         w = rel[3:].split('.')   # skip the 3-character namespace prefix (e.g. 'fb:')
66 |         for i in w:
67 |             rel_word.extend(i.split('_'))
68 |         for word in rel_word:
69 |             word_dict.add(word)
70 |     print('word vocab size after relation words: %d' % len(word_dict))
71 |     return word_dict
72 | 
73 | if __name__ == '__main__':
74 | 
75 |     rel_vocab = load_word_dictionary('../freebase_data/FB2M.rel.txt')
76 |     torch.save(rel_vocab, 'vocab.rel.pt')
77 | 
78 |     ent_vocab = load_word_dictionary('../freebase_data/FB2M.ent.txt')
79 |     torch.save(ent_vocab, 'vocab.ent.pt')
80 | 
81 |     rel1_vocab, rel2_vocab = load_rel_separated_dictionary('../freebase_data/FB2M.rel.txt')
82 |     torch.save((rel1_vocab, rel2_vocab), 'vocab.rel.sep.pt')
83 | 
84 |     word_rel_vocab = create_word_rel_dict('../freebase_data/FB2M.rel.pkl', '../data/QAData.test.pkl',
85 |                                           '../data/QAData.train.pkl', '../data/QAData.valid.pkl')
86 |     torch.save(word_rel_vocab, 'vocab.word&rel.pt')
87 | 
88 | 
--------------------------------------------------------------------------------
/vocab/dictionary.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | from six import iteritems
4 | import codecs
5 | 
6 | class Dictionary(object):
7 |     """
8 |     A class for managing a word dictionary, with pruning by count or by top-k frequency.
9 |     """
10 |     def __init__(self, lower=True):
11 |         # Word -> Index, index starts from 0
12 |         self.word2index = dict()
13 |         # Index -> Word, index starts from 0
14 |         self.index2word = dict()
15 |         # Word -> Count
16 |         self.word_count = dict()
17 |         self.lower = lower
18 | 
19 |         # Special entries will not be pruned.
20 |         self.special = set()
21 | 
22 |     def __add__(self, d):
23 |         """
24 |         Merge two Dictionaries, summing the counts of shared words.
25 |         :param d:
26 |         :return:
27 |         """
28 |         assert type(d) == Dictionary
29 |         assert self.lower == d.lower
30 | 
31 |         word_set = set(self.word2index) | set(d.word2index)
32 |         new_d = Dictionary(self.lower)
33 |         new_d.special = self.special | d.special
34 | 
35 |         for w in word_set:
36 |             new_d.add(w, count=self.lookup_count(w))
37 |             new_d.add(w, count=d.lookup_count(w))
38 | 
39 |         return new_d
40 | 
41 |     def __getitem__(self, word):
42 |         return self.word2index[word]
43 | 
44 |     def __contains__(self, word):
45 |         return word in self.word2index
46 | 
47 |     def __len__(self):
48 |         return len(self.word2index)
49 | 
50 |     def __iter__(self):
51 |         for word in self.word2index:
52 |             yield word
53 | 
54 |     def size(self):
55 |         return len(self)
56 | 
57 |     def lookup(self, key, default=None):
58 |         key = self.lower_(key)
59 |         try:
60 |             return self.word2index[key]
61 |         except KeyError:
62 |             return default
63 | 
64 |     def lower_(self, key):
65 |         if isinstance(key, int):
66 |             return key
67 |         return key.lower() if self.lower else key
68 | 
69 |     def add(self, key, idx=None, count=1):
70 |         """
71 |         Add a word to the Dictionary.
72 |         :param key:
73 |         :param idx: Use `idx` as its index if given.
74 |         :param count: Use `count` as its count if given, default is 1.
75 |         """
76 |         key = self.lower_(key)
77 |         if idx is not None:
78 |             self.index2word[idx] = key
79 |             self.word2index[key] = idx
80 |         else:
81 |             if key not in self.word2index:
82 |                 idx = len(self.word2index)
83 |                 self.index2word[idx] = key
84 |                 self.word2index[key] = idx
85 | 
86 |         if key not in self.word_count:
87 |             self.word_count[key] = count
88 |         else:
89 |             self.word_count[key] += count
90 | 
91 |     def add_special(self, key, idx=None, count=1):
92 |         self.add(key, idx, count)
93 |         self.special.add(key)
94 | 
95 |     def add_specials(self, keys, idxs):
96 |         for key, idx in zip(keys, idxs):
97 |             self.add_special(key, idx=idx)
98 | 
99 |     def add_unk_token(self, unk_token='<unk>'):
100 |         self.add_special(unk_token)
101 |         self.unk_token = unk_token
102 | 
103 |     def add_pad_token(self, pad_token='<pad>'):
104 |         self.add_special(pad_token)
105 |         self.pad_token = pad_token
106 | 
107 |     def add_start_token(self, start_token='<s>'):
108 |         self.add_special(start_token)
109 |         self.start_token = start_token
110 | 
111 |     def add_end_token(self, end_token='</s>'):
112 |         self.add_special(end_token)
113 |         self.end_token = end_token
114 | 
115 |     def lookup_count(self, key):
116 |         key = self.lower_(key)
117 |         try:
118 |             return self.word_count[key]
119 |         except KeyError:
120 |             return 0
121 | 
122 |     def sort(self, reverse=True):
123 |         """
124 |         Sort the dictionary by count.
125 |         :param reverse: True (default) sorts high -> low;
126 |                         False sorts low -> high.
127 | 
128 |         """
129 |         count_word = list()
130 |         indices = list()
131 |         for w in self.word2index:
132 |             if w in self.special:
133 |                 continue
134 |             count_word.append((self.word_count[w], w))
135 |             indices.append(self.word2index[w])
136 | 
137 |         count_word.sort(reverse=reverse)
138 |         indices.sort()   # always ascending: the first word in sort order gets the smallest free index
139 | 
140 |         for index, (_, word) in zip(indices, count_word):
141 |             self.word2index[word] = index
142 |             self.index2word[index] = word
143 | 
144 |     def clear_dictionary(self, keep_special=True):
145 |         special_count_index = None
146 |         if keep_special:
147 |             special_count_index = [(word, self.word_count[word], self.word2index[word]) for word in self.special]
148 |         else:
149 |             self.special = set()
150 |         self.word_count = dict()
151 |         self.word2index = dict()
152 |         self.index2word = dict()
153 |         if special_count_index:
154 |             for word, count, index in special_count_index:
155 |                 self.add_special(key=word, count=count, idx=index)
156 | 
157 |     def cut_by_top(self, top_k=30000):
158 |         """
159 |         Cut the dictionary down to the top-k words by count.
160 |         :param top_k:
161 |         """
162 |         if len(self.word2index) <= top_k:
163 |             print("Word number (%s) is smaller than top K (%s)" % (len(self.word2index), top_k))
164 |             return
165 | 
166 |         word_count = list()
167 |         for word, count in iteritems(self.word_count):
168 |             word_count.append((count, word))
169 |         word_count.sort(reverse=True)
170 | 
171 |         self.clear_dictionary(keep_special=True)
172 | 
173 |         added_top_num = 0
174 |         for count, word in word_count:
175 |             if added_top_num >= top_k:
176 |                 break
177 |             if word not in self.special:
178 |                 self.add(key=word, count=count)
179 |                 added_top_num += 1
180 | 
181 |         print("After cut, dictionary size is %d" % len(self))
182 | 
183 |     def cut_by_count(self, min_count=1, max_count=None):
184 |         """
185 |         Cut the dictionary by count thresholds.
186 |         :param min_count:
187 |         :param max_count:
188 |         """
189 |         word_count = list()
190 |         for word, count in iteritems(self.word_count):
191 |             word_count.append((word, count))
192 | 
193 |         self.clear_dictionary(keep_special=True)
194 | 
195 |         for word, count in word_count:
196 |             if min_count is not None and count < min_count:
197 |                 continue
198 |             if max_count is not None and count > max_count:
199 |                 continue
200 |             self.add(word, count=count)
201 | 
202 |         print("After cut, dictionary size is %d" % len(self))
203 | 
204 |     def write_to_file(self, filename):
205 |         with codecs.open(filename, 'w', 'utf-8') as out:
206 |             for word, index in iteritems(self.word2index):
207 |                 write_str = "%s %s\n" % (word, index)
208 |                 out.write(write_str)   # the codecs stream handles encoding; writing bytes here would fail
209 | 
210 |     def convert_to_index(self, words, bos_word=None, eos_word=None):
211 |         """
212 |         Convert `words` to indices.
213 |         :param words:
214 |         Unknown words map to the index of `self.unk_token`.
215 |         :param bos_word: Optionally insert `bos_word` at the beginning
216 |         :param eos_word: and `eos_word` at the end.
217 |         :return:
218 |         """
219 |         vec = []
220 | 
221 |         if bos_word is not None:
222 |             vec += [self.lookup(bos_word)]
223 | 
224 |         unk = self.lookup(self.unk_token)
225 |         vec += [self.lookup(word, default=unk) for word in words]
226 | 
227 |         if eos_word is not None:
228 |             vec += [self.lookup(eos_word)]
229 | 
230 |         return vec
231 | 
232 |     def convert_to_word(self, ids):
233 |         """
234 |         Convert `ids` to words.
235 |         :param ids:
236 |         :return:
237 |         """
238 |         return [self.index2word[index] for index in ids]
239 | 
240 |     def contains(self, word):
241 |         key = self.lower_(word)
242 |         if key in self.word2index:
243 |             return True
244 |         else:
245 |             return False
246 | 
--------------------------------------------------------------------------------
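
Usage note: the helpers at the end of tools/virtuoso.py all follow one pattern — fill a single-triple SPARQL template, send it through sparql_query to the endpoint named by the module-level query_url, and flatten the JSON bindings into a plain list of strings. The minimal sketch below shows how they compose into a single-hop fact lookup; it assumes a running Virtuoso instance loaded with the FB2M graph, and the mid and relation shown are hypothetical placeholders, not values from the repository.

    from tools import virtuoso

    subject = 'fb:m.0abc12'                                # hypothetical Freebase mid
    relation = 'fb:people.person.place_of_birth'           # hypothetical relation
    answers = virtuoso.query_object(subject, relation)     # objects of (subject, relation, ?o)
    for obj in answers:
        if virtuoso.check_cvt(obj):                        # skip intermediate CVT nodes
            continue
        print(obj, virtuoso.query_relation(subject, obj))  # relations linking the pair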
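
The Dictionary class is the object that create_vocab.py builds and serializes with torch.save, and that script's __main__ block is the canonical usage. As a smaller self-contained sketch of the lifecycle — assuming it is run from the repository root, with token spellings following the defaults in dictionary.py:

    from vocab.dictionary import Dictionary

    d = Dictionary(lower=True)
    d.add_unk_token()                            # '<unk>' and '<pad>' are special entries,
    d.add_pad_token()                            # so pruning never removes them
    for tok in 'the cat sat on the mat'.split():
        d.add(tok)

    d.cut_by_top(top_k=4)                        # keep the 4 most frequent non-special words
    ids = d.convert_to_index(['the', 'dog'])     # 'dog' falls back to the '<unk>' index
    print(ids, d.convert_to_word(ids))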