├── LICENSE ├── README.md ├── data_gen ├── gen_p_e_m │ ├── gen_p_e_m_from_wiki.lua │ ├── gen_p_e_m_from_yago.lua │ ├── merge_crosswikis_wiki.lua │ └── unicode_map.lua ├── gen_test_train_data │ ├── gen_ace_msnbc_aquaint_csv.lua │ ├── gen_aida_test.lua │ ├── gen_aida_train.lua │ └── gen_all.lua ├── gen_wiki_data │ ├── gen_ent_wiki_w_repr.lua │ └── gen_wiki_hyp_train_data.lua ├── indexes │ ├── wiki_disambiguation_pages_index.lua │ ├── wiki_redirects_index.lua │ └── yago_crosswikis_wiki.lua └── parse_wiki_dump │ └── parse_wiki_dump_tools.lua ├── ed ├── args.lua ├── ed.lua ├── loss.lua ├── minibatch │ ├── build_minibatch.lua │ └── data_loader.lua ├── models │ ├── SetConstantDiag.lua │ ├── linear_layers.lua │ ├── model.lua │ ├── model_global.lua │ └── model_local.lua ├── test │ ├── check_coref.lua │ ├── coref_persons.lua │ ├── ent_freq_stats_test.lua │ ├── ent_p_e_m_stats_test.lua │ ├── test.lua │ └── test_one_loaded_model.lua └── train.lua ├── entities ├── ent_name2id_freq │ ├── e_freq_gen.lua │ ├── e_freq_index.lua │ └── ent_name_id.lua ├── learn_e2v │ ├── 4EX_wiki_words.lua │ ├── batch_dataset_a.lua │ ├── e2v_a.lua │ ├── learn_a.lua │ ├── minibatch_a.lua │ └── model_a.lua ├── pretrained_e2v │ ├── check_ents.lua │ ├── e2v.lua │ └── e2v_txt_reader.lua └── relatedness │ ├── filter_wiki_canonical_words_RLTD.lua │ ├── filter_wiki_hyperlink_contexts_RLTD.lua │ └── relatedness.lua ├── our_system_annotations.txt ├── utils ├── logger.lua ├── optim │ ├── adadelta_mem.lua │ ├── adagrad_mem.lua │ └── rmsprop_mem.lua └── utils.lua └── words ├── load_w_freq_and_vecs.lua ├── stop_words.lua ├── w2v ├── glove_reader.lua ├── w2v.lua └── word2vec_reader.lua └── w_freq ├── w_freq_gen.lua └── w_freq_index.lua /README.md: -------------------------------------------------------------------------------- 1 | # Source code for "Deep Joint Entity Disambiguation with Local Neural Attention" 2 | 3 | [O-E. Ganea and T. Hofmann, full paper @ EMNLP 2017](https://arxiv.org/abs/1704.04920) 4 | 5 | Slides and poster can be accessed [here](http://people.inf.ethz.ch/ganeao/). 6 | 7 | ## Pre-trained entity embeddings 8 | 9 | Entity embeddings trained with our method using Word2Vec 300 dimensional pre-trained word vectors (GoogleNews-vectors-negative300.bin). They have norm 1 and are restricted only to entities appearing in the training, validation and test sets described in our paper. Available [here](https://polybox.ethz.ch/index.php/s/sH2JSB2c1OSj7yv). 10 | 11 | ## Full set of annotations made by one of our global models 12 | 13 | See file our_system_annotations.txt . Best to visualize together with its color scheme in a bash terminal. Contains the full set of annotations for the following datasets: 14 | 15 | ``` 16 | $ cat our_system_annotations.txt | grep 'Micro ' 17 | ==> AQUAINT AQUAINT ; EPOCH = 307: Micro recall = 88.03% ; Micro F1 = 89.51% 18 | ==> MSNBC MSNBC ; EPOCH = 307: Micro recall = 93.29% ; Micro F1 = 93.65% 19 | ==> ACE04 ACE04 ; EPOCH = 307: Micro recall = 84.05% ; Micro F1 = 86.92% 20 | ==> aida-B aida-B ; EPOCH = 307: Micro recall = 92.08% ; Micro F1 = 92.08% 21 | ==> aida-A aida-A ; EPOCH = 307: Micro recall = 91.00% ; Micro F1 = 91.01% 22 | ``` 23 | 24 | Global model was trained on AIDA-train with pre-trained entity embeddings trained on Wikipedia. See details of how to run our code below. 
25 |
26 | Detailed statistics per dataset, as in Table 6 of our paper, can be accessed with:
27 |
28 | ```
29 | $ cat our_system_annotations.txt | grep -A20 'Micro '
30 | ```
31 |
32 |
33 | ## How to run the system and reproduce our results
34 |
35 | 1) Install [Torch](http://torch.ch/)
36 |
37 |
38 | 2) Install the Torch libraries cudnn, cutorch, [tds](https://github.com/torch/tds), gnuplot and xlua:
39 |
40 | ```luarocks install lib_name```
41 |
42 | Check that each of these libraries can be imported in a Torch terminal.
43 |
44 |
45 | 3) Create a $DATA_PATH directory (assumed to end in '/' in the next steps). Create a directory $DATA_PATH/generated/ that will contain all files generated in the next steps.
46 |
47 |
48 | 4) Download the data files needed for training and testing from [this link](https://drive.google.com/uc?id=0Bx8d3azIm_ZcbHMtVmRVc1o5TWM&export=download).
49 | Download basic_data.zip, unzip it and place the basic_data directory in $DATA_PATH/. All generated files will be built from the files in this basic_data/ directory.
50 |
51 |
52 | 5) Download the pre-trained Word2Vec vectors GoogleNews-vectors-negative300.bin.gz from https://code.google.com/archive/p/word2vec/.
53 | Unzip it and place the bin file in the folder $DATA_PATH/basic_data/wordEmbeddings/Word2Vec.
54 |
55 |
56 | Now we start creating the additional data files needed in our pipeline:
57 |
58 | 6) Create wikipedia_p_e_m.txt:
59 |
60 | ```th data_gen/gen_p_e_m/gen_p_e_m_from_wiki.lua -root_data_dir $DATA_PATH```
61 |
62 |
63 | 7) Merge wikipedia_p_e_m.txt and crosswikis_p_e_m.txt:
64 |
65 | ```th data_gen/gen_p_e_m/merge_crosswikis_wiki.lua -root_data_dir $DATA_PATH```
66 |
67 |
68 | 8) Create yago_p_e_m.txt:
69 |
70 | ```th data_gen/gen_p_e_m/gen_p_e_m_from_yago.lua -root_data_dir $DATA_PATH```
71 |
72 |
73 | 9) Create a file ent_wiki_freq.txt with entity frequencies:
74 |
75 | ```th entities/ent_name2id_freq/e_freq_gen.lua -root_data_dir $DATA_PATH```
76 |
77 |
78 | 10) Generate all entity disambiguation datasets in the CSV format needed by our training stage:
79 |
80 | ```
81 | mkdir $DATA_PATH/generated/test_train_data/
82 | th data_gen/gen_test_train_data/gen_all.lua -root_data_dir $DATA_PATH
83 | ```
84 |
85 | Verify the statistics of these files as explained in the header comments of gen_ace_msnbc_aquaint_csv.lua and gen_aida_test.lua.
86 |
87 |
88 | 11) Create training data for learning entity embeddings:
89 |
90 | i) From Wiki canonical pages:
91 |
92 | ```th data_gen/gen_wiki_data/gen_ent_wiki_w_repr.lua -root_data_dir $DATA_PATH```
93 |
94 | ii) From context windows surrounding Wiki hyperlinks:
95 |
96 | ```th data_gen/gen_wiki_data/gen_wiki_hyp_train_data.lua -root_data_dir $DATA_PATH```
97 |
98 |
99 | 12) Compute the unigram frequency of each word in the Wikipedia corpus:
100 |
101 | ```th words/w_freq/w_freq_gen.lua -root_data_dir $DATA_PATH```
102 |
103 |
104 | 13) Compute the restricted training data for learning entity embeddings, using only candidate entities from the relatedness datasets and from all ED sets:
105 | i) From Wiki canonical pages:
106 |
107 | ```th entities/relatedness/filter_wiki_canonical_words_RLTD.lua -root_data_dir $DATA_PATH```
108 |
109 | ii) From context windows surrounding Wiki hyperlinks:
110 |
111 | ```th entities/relatedness/filter_wiki_hyperlink_contexts_RLTD.lua -root_data_dir $DATA_PATH```
112 |
113 | All files in the $DATA_PATH/generated/ folder whose names contain the substring "_RLTD" are restricted to this set of entities (the set should contain 276030 entities).
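To double-check this entity count, the saved set can be loaded in a Torch shell. A minimal sketch, assuming all_candidate_ents_ed_rltd_datasets_RLTD.t7 stores a map keyed by entity wikiid (the exact structure may differ, so adapt the loop accordingly):

```
th> rltd = torch.load('generated/all_candidate_ents_ed_rltd_datasets_RLTD.t7')
th> n = 0; for _ in pairs(rltd) do n = n + 1 end; print(n)   -- should print roughly 276030
```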
114 | 115 | Your $DATA_PATH/generated/ folder should now contain the files : 116 | 117 | ``` 118 | $DATA_PATH/generated $ ls -lah ./ 119 | total 147G 120 | 9.5M all_candidate_ents_ed_rltd_datasets_RLTD.t7 121 | 775M crosswikis_wikipedia_p_e_m.txt 122 | 5.0M empty_page_ents.txt 123 | 520M ent_name_id_map.t7 124 | 95M ent_wiki_freq.txt 125 | 7.3M relatedness_test.t7 126 | 8.9M relatedness_validate.t7 127 | 220 test_train_data 128 | 1.5G wiki_canonical_words_RLTD.txt 129 | 8.4G wiki_canonical_words.txt 130 | 88G wiki_hyperlink_contexts.csv 131 | 48G wiki_hyperlink_contexts_RLTD.csv 132 | 329M wikipedia_p_e_m.txt 133 | 11M word_wiki_freq.txt 134 | 749M yago_p_e_m.txt 135 | 136 | $DATA_PATH//generated/test_train_data $ ls -lah ./ 137 | total 124M 138 | 14M aida_testA.csv 139 | 13M aida_testB.csv 140 | 50M aida_train.csv 141 | 723K wned-ace2004.csv 142 | 1.6M wned-aquaint.csv 143 | 31M wned-clueweb.csv 144 | 1.6M wned-msnbc.csv 145 | 15M wned-wikipedia.csv 146 | ``` 147 | 148 | 14) Now we train entity embeddings for the restricted set of entities (written in all_candidate_ents_ed_rltd_datasets_RLTD.t7). This is the step described in Section 3 of our paper. 149 | 150 | To check the full list of parameters run: 151 | 152 | ```th entities/learn_e2v/learn_a.lua -help``` 153 | 154 | Optimal parameters (see entities/learn_e2v/learn_a.lua): 155 | 156 | ``` 157 | optimization = 'ADAGRAD' 158 | lr = 0.3 159 | batch_size = 500 160 | word_vecs = 'w2v' 161 | num_words_per_ent = 20 162 | num_neg_words = 5 163 | unig_power = 0.6 164 | entities = 'RLTD' 165 | loss = 'maxm' 166 | data = 'wiki-canonical-hyperlinks' 167 | num_passes_wiki_words = 200 168 | hyp_ctxt_len = 10 169 | ``` 170 | 171 | To run the embedding training on one GPU: 172 | 173 | ``` 174 | mkdir $DATA_PATH/generated/ent_vecs 175 | CUDA_VISIBLE_DEVICES=0 th entities/learn_e2v/learn_a.lua -root_data_dir $DATA_PATH |& tee log_train_entity_vecs 176 | ``` 177 | 178 | Warning: This code is not sufficiently optimized to run at maximum speed and GPU usage. Sorry for the inconvenience. It only uses the main thread to load data and perform word embedding lookup. It can be made to run much faster. 179 | 180 | During training, you will see (in the log file log_train_entity_vecs) the validation score on the entity relatedness dataset of the current set of entity embeddings. After around 24 hours (this code is not optimized!), you will see no improvement and can thus stop the training script. Pick the set of saved entity vectors from the folder generated/ent_vecs/ corresponding to the best validation score on the entity relatedness dataset which is the sum of all validation metrics (the TOTAL VALIDATION column). In our paper, we reported in Table 1 the results on the test set corresponding to this best validation score. 181 | 182 | You should get some numbers similar to the following (may vary a little bit due to random initialization): 183 | 184 | ``` 185 | Entity Relatedness quality measure: 186 | measure = NDCG1 NDCG5 NDCG10 MAP TOTAL VALIDATION 187 | our (vald) = 0.681 0.639 0.671 0.619 2.610 188 | our (test) = 0.650 0.609 0.641 0.579 189 | Yamada'16 = 0.59 0.56 0.59 0.52 190 | WikiMW = 0.54 0.52 0.55 0.48 191 | ==> saving model to $DATA_PATH/generated/ent_vecs/ent_vecs__ep_69.t7 192 | ``` 193 | 194 | We will call the name of this file with entity embeddings as $ENTITY_VECS. In our case, it is 'ent_vecs__ep_69.t7' 195 | 196 | This code uses a simple initialization of entity embeddings based on the average of entities' title words (excluding stop words). 
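A minimal sketch of this initialization (not the actual learn_a.lua code; get_w2v_vec and is_stop_word below are hypothetical stand-ins for the repo's word-vector lookup and stop-word check):

```
-- Sketch only: average the word2vec vectors of the entity title words, then L2-normalize.
function init_ent_vec(ent_wikiid, dim)
  local v = torch.zeros(dim)
  local n = 0
  for _, w in pairs(split_in_words(get_ent_name_from_wikiid(ent_wikiid))) do
    if not is_stop_word(w) then       -- is_stop_word: hypothetical stop-word check
      local wv = get_w2v_vec(w)       -- get_w2v_vec: hypothetical word2vec lookup (nil for unknown words)
      if wv then
        v:add(wv)
        n = n + 1
      end
    end
  end
  if n > 0 then v:div(n) end
  if v:norm() > 0 then v:div(v:norm()) end  -- unit norm, matching the released embeddings
  return v
end
```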
This helps speed up training and avoid getting stuck in local minima. We found that a random initialization may result in a slight quality decrease for ED only (up to 1%), and also requires a longer training time until the same quality on entity relatedness is reached (~60 hours).
197 |
198 | 15) Run the training for the global/local ED neural network. The arguments file is ed/args.lua. To list all arguments:
199 |
200 | ```th ed/ed.lua -help```
201 |
202 | Command to run the training:
203 |
204 | ```
205 | mkdir $DATA_PATH/generated/ed_models/
206 | mkdir $DATA_PATH/generated/ed_models/training_plots/
207 | CUDA_VISIBLE_DEVICES=0 th ed/ed.lua -root_data_dir $DATA_PATH -ent_vecs_filename $ENTITY_VECS -model 'global' |& tee log_train_ed
208 | ```
209 |
210 | Let it train for at least 48 hours (or 400 epochs as defined in our code), until the validation accuracy no longer improves or starts dropping. As we wrote in the paper, we stop learning if
211 | the validation F1 does not increase after 500 full epochs over the AIDA train dataset. Validation F1 can be monitored with the command:
212 |
213 | ```cat log_train_ed | grep -A20 'Micro F1' | grep -A20 'aida-A'```
214 |
215 | The best ED models will be saved in the folder generated/ed_models. This only happens once the model reaches > 90% F1 on the validation set (see test.lua).
216 |
217 | Statistics, weights and scores will be written to the log_train_ed file. Plots of micro F1 scores on all validation and test sets will be written to the folder $DATA_PATH/generated/ed_models/training_plots/.
218 |
219 | Results variability: The results reported in Tables 3 and 4 of our ED paper are averaged over different training runs of the ED neural architecture that all use the same set of entity embeddings. However, we found that the variance of ED results across different trainings of the entity embeddings can be a little higher, up to 0.5%.
220 |
221 | 16) After training has terminated, one can re-load and test the best ED model with the command:
222 |
223 | ```
224 | CUDA_VISIBLE_DEVICES=0 th ed/test/test_one_loaded_model.lua -root_data_dir $DATA_PATH -model global -ent_vecs_filename $ENTITY_VECS -test_one_model_file $ED_MODEL_FILENAME
225 | ```
226 |
227 | where $ED_MODEL_FILENAME is a file in $DATA_PATH/generated/ed_models/.
228 |
229 |
230 | 17) Enjoy!
231 | -------------------------------------------------------------------------------- /data_gen/gen_p_e_m/gen_p_e_m_from_wiki.lua: --------------------------------------------------------------------------------
1 | -- Generate p(e|m) index from Wikipedia
2 | -- Run: th data_gen/gen_p_e_m/gen_p_e_m_from_wiki.lua -root_data_dir $DATA_PATH
3 |
4 | cmd = torch.CmdLine()
5 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.')
6 | cmd:text()
7 | opt = cmd:parse(arg or {})
8 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.')
9 |
10 |
11 | require 'torch'
12 | dofile 'utils/utils.lua'
13 | dofile 'data_gen/parse_wiki_dump/parse_wiki_dump_tools.lua'
14 | tds = tds or require 'tds'
15 |
16 | print('\nComputing Wikipedia p_e_m')
17 |
18 | it, _ = io.open(opt.root_data_dir .. 'basic_data/textWithAnchorsFromAllWikipedia2014Feb.txt')
19 | line = it:read()
20 |
21 | wiki_e_m_counts = tds.Hash()
22 |
23 | -- Find anchors, e.g.
anarchism 24 | local num_lines = 0 25 | local parsing_errors = 0 26 | local list_ent_errors = 0 27 | local diez_ent_errors = 0 28 | local disambiguation_ent_errors = 0 29 | local num_valid_hyperlinks = 0 30 | 31 | while (line) do 32 | num_lines = num_lines + 1 33 | if num_lines % 5000000 == 0 then 34 | print('Processed ' .. num_lines .. ' lines. Parsing errs = ' .. 35 | parsing_errors .. ' List ent errs = ' .. 36 | list_ent_errors .. ' diez errs = ' .. diez_ent_errors .. 37 | ' disambig errs = ' .. disambiguation_ent_errors .. 38 | ' . Num valid hyperlinks = ' .. num_valid_hyperlinks) 39 | end 40 | 41 | if not line:find(' b.freq end) 79 | 80 | local str = '' 81 | local total_freq = 0 82 | for _,el in pairs(tbl) do 83 | str = str .. el.ent_wikiid .. ',' .. el.freq 84 | str = str .. ',' .. get_ent_name_from_wikiid(el.ent_wikiid):gsub(' ', '_') .. '\t' 85 | total_freq = total_freq + el.freq 86 | end 87 | ouf:write(mention .. '\t' .. total_freq .. '\t' .. str .. '\n') 88 | end 89 | ouf:flush() 90 | io.close(ouf) 91 | 92 | print(' Done sorting and writing.') 93 | -------------------------------------------------------------------------------- /data_gen/gen_p_e_m/gen_p_e_m_from_yago.lua: -------------------------------------------------------------------------------- 1 | -- Generate p(e|m) index from Wikipedia 2 | -- Run: th data_gen/gen_p_e_m/gen_p_e_m_from_yago.lua -root_data_dir $DATA_PATH 3 | 4 | cmd = torch.CmdLine() 5 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.') 6 | cmd:text() 7 | opt = cmd:parse(arg or {}) 8 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.') 9 | 10 | 11 | require 'torch' 12 | dofile 'utils/utils.lua' 13 | dofile 'data_gen/gen_p_e_m/unicode_map.lua' 14 | if not get_redirected_ent_title then 15 | dofile 'data_gen/indexes/wiki_redirects_index.lua' 16 | end 17 | if not get_ent_name_from_wikiid then 18 | dofile 'entities/ent_name2id_freq/ent_name_id.lua' 19 | end 20 | 21 | tds = tds or require 'tds' 22 | 23 | print('\nComputing YAGO p_e_m') 24 | 25 | local it, _ = io.open(opt.root_data_dir .. 'basic_data/p_e_m_data/aida_means.tsv') 26 | local line = it:read() 27 | 28 | local num_lines = 0 29 | local wiki_e_m_counts = tds.Hash() 30 | 31 | while (line) do 32 | num_lines = num_lines + 1 33 | if num_lines % 5000000 == 0 then 34 | print('Processed ' .. num_lines .. ' lines.') 35 | end 36 | local parts = split(line, '\t') 37 | assert(table_len(parts) == 2) 38 | assert(parts[1]:sub(1,1) == '"') 39 | assert(parts[1]:sub(parts[1]:len(),parts[1]:len()) == '"') 40 | 41 | local mention = parts[1]:sub(2, parts[1]:len() - 1) 42 | local ent_name = parts[2] 43 | ent_name = string.gsub(ent_name, '&', '&') 44 | ent_name = string.gsub(ent_name, '"', '"') 45 | while ent_name:find('\\u') do 46 | local x = ent_name:find('\\u') 47 | local code = ent_name:sub(x, x + 5) 48 | assert(unicode2ascii[code], code) 49 | replace = unicode2ascii[code] 50 | if(replace == "%") then 51 | replace = "%%" 52 | end 53 | ent_name = string.gsub(ent_name, code, replace) 54 | end 55 | 56 | ent_name = preprocess_ent_name(ent_name) 57 | local ent_wikiid = get_ent_wikiid_from_name(ent_name, true) 58 | if ent_wikiid ~= unk_ent_wikiid then 59 | if not wiki_e_m_counts[mention] then 60 | wiki_e_m_counts[mention] = tds.Hash() 61 | end 62 | wiki_e_m_counts[mention][ent_wikiid] = 1 63 | end 64 | 65 | line = it:read() 66 | end 67 | 68 | 69 | print('Now sorting and writing ..') 70 | out_file = opt.root_data_dir .. 
'generated/yago_p_e_m.txt' 71 | ouf = assert(io.open(out_file, "w")) 72 | 73 | for mention, list in pairs(wiki_e_m_counts) do 74 | local str = '' 75 | local total_freq = 0 76 | for ent_wikiid, _ in pairs(list) do 77 | str = str .. ent_wikiid .. ',' .. get_ent_name_from_wikiid(ent_wikiid):gsub(' ', '_') .. '\t' 78 | total_freq = total_freq + 1 79 | end 80 | ouf:write(mention .. '\t' .. total_freq .. '\t' .. str .. '\n') 81 | end 82 | ouf:flush() 83 | io.close(ouf) 84 | 85 | print(' Done sorting and writing.') 86 | -------------------------------------------------------------------------------- /data_gen/gen_p_e_m/merge_crosswikis_wiki.lua: -------------------------------------------------------------------------------- 1 | -- Merge Wikipedia and Crosswikis p(e|m) indexes 2 | -- Run: th data_gen/gen_p_e_m/merge_crosswikis_wiki.lua -root_data_dir $DATA_PATH 3 | 4 | cmd = torch.CmdLine() 5 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.') 6 | cmd:text() 7 | opt = cmd:parse(arg or {}) 8 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.') 9 | 10 | 11 | require 'torch' 12 | dofile 'utils/utils.lua' 13 | dofile 'entities/ent_name2id_freq/ent_name_id.lua' 14 | 15 | print('\nMerging Wikipedia and Crosswikis p_e_m') 16 | 17 | tds = tds or require 'tds' 18 | merged_e_m_counts = tds.Hash() 19 | 20 | print('Process Wikipedia') 21 | it, _ = io.open(opt.root_data_dir .. 'generated/wikipedia_p_e_m.txt') 22 | line = it:read() 23 | 24 | while (line) do 25 | local parts = split(line, "\t") 26 | local mention = parts[1] 27 | 28 | if (not mention:find('Wikipedia')) and (not mention:find('wikipedia')) then 29 | if not merged_e_m_counts[mention] then 30 | merged_e_m_counts[mention] = tds.Hash() 31 | end 32 | 33 | local total_freq = tonumber(parts[2]) 34 | assert(total_freq) 35 | local num_ents = table_len(parts) 36 | for i = 3, num_ents do 37 | local ent_str = split(parts[i], ",") 38 | local ent_wikiid = tonumber(ent_str[1]) 39 | assert(ent_wikiid) 40 | local freq = tonumber(ent_str[2]) 41 | assert(freq) 42 | 43 | if not merged_e_m_counts[mention][ent_wikiid] then 44 | merged_e_m_counts[mention][ent_wikiid] = 0 45 | end 46 | merged_e_m_counts[mention][ent_wikiid] = merged_e_m_counts[mention][ent_wikiid] + freq 47 | end 48 | end 49 | line = it:read() 50 | end 51 | 52 | 53 | print('Process Crosswikis') 54 | 55 | it, _ = io.open(opt.root_data_dir .. 'basic_data/p_e_m_data/crosswikis_p_e_m.txt') 56 | line = it:read() 57 | 58 | while (line) do 59 | local parts = split(line, "\t") 60 | local mention = parts[1] 61 | 62 | if (not mention:find('Wikipedia')) and (not mention:find('wikipedia')) then 63 | if not merged_e_m_counts[mention] then 64 | merged_e_m_counts[mention] = tds.Hash() 65 | end 66 | 67 | local total_freq = tonumber(parts[2]) 68 | assert(total_freq) 69 | local num_ents = table_len(parts) 70 | for i = 3, num_ents do 71 | local ent_str = split(parts[i], ",") 72 | local ent_wikiid = tonumber(ent_str[1]) 73 | assert(ent_wikiid) 74 | local freq = tonumber(ent_str[2]) 75 | assert(freq) 76 | 77 | if not merged_e_m_counts[mention][ent_wikiid] then 78 | merged_e_m_counts[mention][ent_wikiid] = 0 79 | end 80 | merged_e_m_counts[mention][ent_wikiid] = merged_e_m_counts[mention][ent_wikiid] + freq 81 | end 82 | end 83 | line = it:read() 84 | end 85 | 86 | 87 | print('Now sorting and writing ..') 88 | out_file = opt.root_data_dir .. 
'generated/crosswikis_wikipedia_p_e_m.txt' 89 | ouf = assert(io.open(out_file, "w")) 90 | 91 | for mention, list in pairs(merged_e_m_counts) do 92 | if mention:len() >= 1 then 93 | local tbl = {} 94 | for ent_wikiid, freq in pairs(list) do 95 | table.insert(tbl, {ent_wikiid = ent_wikiid, freq = freq}) 96 | end 97 | table.sort(tbl, function(a,b) return a.freq > b.freq end) 98 | 99 | local str = '' 100 | local total_freq = 0 101 | local num_ents = 0 102 | for _,el in pairs(tbl) do 103 | if is_valid_ent(el.ent_wikiid) then 104 | str = str .. el.ent_wikiid .. ',' .. el.freq 105 | str = str .. ',' .. get_ent_name_from_wikiid(el.ent_wikiid):gsub(' ', '_') .. '\t' 106 | num_ents = num_ents + 1 107 | total_freq = total_freq + el.freq 108 | 109 | if num_ents >= 100 then -- At most 100 candidates 110 | break 111 | end 112 | end 113 | end 114 | ouf:write(mention .. '\t' .. total_freq .. '\t' .. str .. '\n') 115 | end 116 | end 117 | ouf:flush() 118 | io.close(ouf) 119 | 120 | print(' Done sorting and writing.') 121 | 122 | -------------------------------------------------------------------------------- /data_gen/gen_p_e_m/unicode_map.lua: -------------------------------------------------------------------------------- 1 | local unicode = {'\\u00bb', '\\u007d', '\\u00a1', '\\u0259', '\\u0641', '\\u0398', '\\u00fd', '\\u0940', '\\u00f9', '\\u02bc', '\\u00f6', '\\u00f8', '\\u0107', '\\u0648', '\\u0105', '\\u002c', '\\u6768', '\\u0160', '\\u015b', '\\u00c0', '\\u266f', '\\u0430', '\\u0141', '\\u1ea3', '\\u00df', '\\u2212', '\\u8650', '\\u012d', '\\u1e47', '\\u00c5', '\\u00ab', '\\u0226', '\\u0930', '\\u04a4', '\\u030d', '\\u0631', '\\u207f', '\\u00bf', '\\u2010', '\\u1e6e', '\\u00cd', '\\u00c6', '\\u30e8', '\\u1e63', '\\u6a5f', '\\u03c3', '\\u00d5', '\\u0644', '\\u2020', '\\u0104', '\\u010a', '\\u013c', '\\u0123', '\\u0159', '\\u1e6f', '\\u003a', '\\u06af', '\\u00fc', '\\u4e09', '\\u0028', '\\u03b2', '\\u2103', '\\u0191', '\\u03bc', '\\u00d1', '\\u207a', '\\u79d2', '\\u6536', '\\u1ed3', '\\u0329', '\\u0196', '\\u00fb', '\\u0435', '\\u01b0', '\\u007e', '\\u1e62', '\\u0181', '\\u1ea7', '\\u2011', '\\u03c9', '\\u201d', '\\u0165', '\\u0422', '\\u1e33', '\\u0144', '\\u00fa', '\\u1ed5', '\\u0632', '\\u0643', '\\u1ea1', '\\u011e', '\\u062a', '\\u00ee', '\\u00c2', '\\u016d', '\\u003d', '\\u2202', '\\u2605', '\\u0112', '\\u73cd', '\\u03a1', '\\u0182', '\\u00d2', '\\u0153', '\\u016f', '\\u00de', '\\u00a3', '\\u1e45', '\\u1ef1', '\\u4e45', '\\u06cc', '\\u1ea2', '\\u0152', '\\uff09', '\\u0219', '\\u0457', '\\u0283', '\\u1e5a', '\\u064e', '\\u0164', '\\u0116', '\\u018b', '\\u1ec3', '\\u00b4', '\\u002b', '\\u7248', '\\u0937', '\\u203a', '\\u4eba', '\\u002f', '\\u0136', '\\u01bc', '\\u017b', '\\u00a9', '\\u03a9', '\\u00f2', '\\u2026', '\\u00c7', '\\u0969', '\\u0198', '\\u011f', '\\u00e0', '\\u0126', '\\u018a', '\\u1edf', '\\u005e', '\\u03b4', '\\u0137', '\\u01f5', '\\u1e34', '\\u007b', '\\u00f3', '\\u01c0', '\\u00f1', '\\u1ef9', '\\u03c8', '\\ub8e8', '\\u0119', '\\u014f', '\\uff5e', '\\u016c', '\\u0358', '\\u529f', '\\u2606', '\\u00e4', '\\u012b', '\\u00c9', '\\u0173', '\\u092f', '\\u5957', '\\u1e49', '\\u1ec1', '\\u019f', '\\u02bb', '\\u0399', '\\u01c1', '\\u03d5', '\\u017a', '\\u1e0d', '\\u0148', '\\u01a4', '\\u00a2', '\\u011c', '\\u1e92', '\\u01b1', '\\u0443', '\\u00f4', '\\u2122', '\\u82e5', '\\u0967', '\\u1eaf', '\\u013e', '\\u1e46', '\\u03b1', '\\u884c', '\\u0328', '\\u0021', '\\u00aa', '\\u014d', '\\u002e', '\\u00cb', '\\u062f', '\\u0102', '\\u0155', '\\u00cf', '\\u0446', '\\u1e80', '\\u003b', '\\u6c38', 
'\\u0103', '\\u1e6c', '\\u203c', '\\u00dc', '\\u00b3', '\\u0145', '\\u0122', '\\u8fdb', '\\u015e', '\\u017c', '\\u043f', '\\u0442', '\\u0629', '\\u6176', '\\u1edd', '\\u2018', '\\u00ea', '\\u0060', '\\u0147', '\\u00e7', '\\u00dd', '\\u00d6', '\\u043a', '\\u00ec', '\\ufb01', '\\u0124', '\\u1e5f', '\\u0431', '\\u1e94', '\\u1ea8', '\\u01b3', '\\u018f', '\\u0627', '\\u1ef3', '\\u091f', '\\u03a6', '\\u674e', '\\u016b', '\\u039b', '\\u2032', '\\u002a', '\\u2033', '\\u00b7', '\\u00ce', '\\u2075', '\\u043c', '\\u2116', '\\u1e6d', '\\u00be', '\\u0171', '\\u0433', '\\u0635', '\\u0640', '\\u01a1', '\\u00a7', '\\u019d', '\\u301c', '\\u00f5', '\\u0187', '\\u1ef6', '\\u1e0c', '\\u1ec5', '\\u01b8', '\\u00f0', '\\u1e93', '\\u00eb', '\\u03bd', '\\u0108', '\\u010d', '\\u5229', '\\u2080', '\\u00c8', '\\u9ece', '\\u0917', '\\u85cf', '\\u1edb', '\\u1e25', '\\u045b', '\\u1ec9', '\\u1ecd', '\\u0633', '\\u2022', '\\u01e8', '\\u0197', '\\u2019', '\\u0421', '\\u1ecb', '\\u9910', '\\uc2a4', '\\u1e31', '\\u00c1', '\\u1ea9', '\\u1eeb', '\\u01f4', '\\u01c2', '\\u0146', '\\u0162', '\\ufb02', '\\u01ac', '\\u0025', '\\u015c', '\\u01ce', '\\u01c3', '\\u0179', '\\u01e5', '\\u58eb', '\\u745e', '\\uff08', '\\u016a', '\\u03a5', '\\u039c', '\\u00bd', '\\u0169', '\\u55f7', '\\u89d2', '\\u00e6', '\\u9752', '\\u005b', '\\u003f', '\\u041f', '\\u06a9', '\\u0027', '\\u1ebf', '\\u0646', '\\u0130', '\\u1eb1', '\\u0395', '\\u03b3', '\\u01d4', '\\u00ed', '\\u041a', '\\u0149', '\\u0143', '\\u010f', '\\u1e24', '\\u0121', '\\u1ecf', '\\u06c1', '\\u00d3', '\\u0029', '\\u02bf', '\\u010e', '\\u1e0e', '\\u01d2', '\\u00ff', '\\u00fe', '\\u03a0', '\\u1ebc', '\\u2153', '\\u00e1', '\\u00ca', '\\u012c', '\\u017d', '\\u0110', '\\u266d', '\\u0639', '\\u014e', '\\u1e43', '\\u0026', '\\u20ac', '\\u0024', '\\u011d', '\\u003e', '\\u0163', '\\u0939', '\\u221a', '\\u00e3', '\\u65f6', '\\u0118', '\\u0101', '\\u0628', '\\u221e', '\\u1ed1', '\\u0393', '\\u00c4', '\\u0161', '\\u00e9', '\\u0220', '\\u0115', '\\u002d', '\\u03c0', '\\u0177', '\\u00ba', '\\u0158', '\\u01a7', '\\u0117', '\\u043b', '\\u00d7', '\\u00e2', '\\u0175', '\\u0420', '\\u0391', '\\u00b2', '\\u014c', '\\u2013', '\\u00b9', '\\u1ef7', '\\u064a', '\\u0301', '\\u95a2', '\\u0113', '\\u013b', '\\u094d', '\\u03b5', '\\u1eef', '\\u2c6b', '\\u00da', '\\u00d8', '\\u0432', '\\u0109', '\\u00d9', '\\u00d4', '\\u011b', '\\u0303', '\\u0392', '\\u1eed', '\\u0444', '\\u026a', '\\u0218', '\\u00ef', '\\u1ed9', '\\u00b0', '\\u010c', '\\uac00', '\\u02be', '\\u2012', '\\u5baa', '\\u00e8', '\\u1ebd', '\\u30fb', '\\u0127', '\\u010b', '\\u0131', '\\u1ebb', '\\u0150', '\\u0327', '\\u0100', '\\u1ee7', '\\u1ed7', '\\u0129', '\\u00c3', '\\u003c', '\\u2260', '\\u0106', '\\u6625', '\\u0184', '\\u1eb5', '\\u4fdd', '\\u00b1', '\\u021b', '\\u014b', '\\uff0d', '\\u1e2a', '\\u00e5', '\\u017e', '\\u011a', '\\u1eab', '\\u200e', '\\u1e35', '\\u1e5b', '\\u2192', '\\u0040', '\\u1eb7', '\\u01b2', '\\u5b58', '\\u201c', '\\u015f', '\\u01e6', '\\u0111', '\\u738b', '\\u03a7', '\\u1ead', '\\u1ec7', '\\u0324', '\\u2665', '\\ub9c8', '\\u6bba', '\\u0151', '\\u2661', '\\u03ba', '\\ua784', '\\u2014', '\\u1ee9', '\\u0120', '\\u012a', '\\u7433', '\\u0134', '\\u039a', '\\u1ee3', '\\u1ea5', '\\u1ee5', '\\u0142', '\\u043e', '\\u01eb', '\\u0440', '\\u03a3', '\\u093e', '\\u00d0', '\\u092e', '\\u00b5', '\\u013d', '\\u1ecc', '\\u0394', '\\u00bc', '\\u01d0', '\\u015a', '\\u02b9', '\\u0645', '\\u043d', '\\u00cc'} 2 | 3 | local ascii = {'»', '}', '¡', 'ə', 'ف', 'Θ', 'ý', 'ी', 'ù', 'ʼ', 'ö', 'ø', 'ć', 'و', 'ą', ',', '杨', 'Š', 'ś', 'À', '♯', 'а', 'Ł', 'ả', 
'ß', '−', '虐', 'ĭ', 'ṇ', 'Å', '«', 'Ȧ', 'र', 'Ҥ', 'ʼ', 'ر', 'ⁿ', '¿', '‐', 'Ṯ', 'Í', 'Æ', 'ヨ', 'ṣ', '機', 'σ', 'Õ', 'ل', '†', 'Ą', 'Ċ', 'ļ', 'ģ', 'ř', 'ṯ', ':', 'گ', 'ü', '三', '(', 'β', '℃', 'Ƒ', 'μ', 'Ñ', '⁺', '秒', '收', 'ồ', '̩', 'Ɩ', 'û', 'е', 'ư', '~', 'Ṣ', 'Ɓ', 'ầ', '‑', 'ω', '”', 'ť', 'Т', 'ḳ', 'ń', 'ú', 'ổ', 'ز', 'ك', 'ạ', 'Ğ', 'ت', 'î', 'Â', 'ŭ', '=', '∂', '★', 'Ē', '珍', 'Ρ', 'Ƃ', 'Ò', 'œ', 'ů', 'Þ', '£', 'ṅ', 'ự', '久', 'ی', 'Ả', 'Œ', ')', 'ș', 'ї', 'ʃ', 'Ṛ', 'َ', 'Ť', 'Ė', 'Ƌ', 'ể', '´', '+', '版', 'ष', '›', '人', '/', 'Ķ', 'Ƽ', 'Ż', '©', 'Ω', 'ò', '…', 'Ç', '३', 'Ƙ', 'ğ', 'à', 'Ħ', 'Ɗ', 'ở', '^', 'δ', 'ķ', 'ǵ', 'Ḵ', '{', 'ó', 'ǀ', 'ñ', 'ỹ', 'ψ', '루', 'ę', 'ŏ', '~', 'Ŭ', '͘', '功', '☆', 'ä', 'ī', 'É', 'ų', 'य', '套', 'ṉ', 'ề', 'Ɵ', 'ʻ', 'Ι', 'ǁ', 'ϕ', 'ź', 'ḍ', 'ň', 'Ƥ', '¢', 'Ĝ', 'Ẓ', 'Ʊ', 'у', 'ô', '™', '若', '१', 'ắ', 'ľ', 'Ṇ', 'α', '行', '̨', '!', 'ª', 'ō', '.', 'Ë', 'د', 'Ă', 'ŕ', 'Ï', 'ц', 'Ẁ', ';', '永', 'ă', 'Ṭ', '‼', 'Ü', '³', 'Ņ', 'Ģ', '进', 'Ş', 'ż', 'п', 'т', 'ة', '慶', 'ờ', '‘', 'ê', '`', 'Ň', 'ç', 'Ý', 'Ö', 'к', 'ì', 'fi', 'Ĥ', 'ṟ', 'б', 'Ẕ', 'Ẩ', 'Ƴ', 'Ə', 'ا', 'ỳ', 'ट', 'Φ', '李', 'ū', 'Λ', '′', '*', '″', '·', 'Î', '⁵', 'м', '№', 'ṭ', '¾', 'ű', 'г', 'ص', 'ـ', 'ơ', '§', 'Ɲ', '〜', 'õ', 'Ƈ', 'Ỷ', 'Ḍ', 'ễ', 'Ƹ', 'ð', 'ẓ', 'ë', 'ν', 'Ĉ', 'č', '利', '₀', 'È', '黎', 'ग', '藏', 'ớ', 'ḥ', 'ћ', 'ỉ', 'ọ', 'س', '•', 'Ǩ', 'Ɨ', '’', 'С', 'ị', '餐', '스', 'ḱ', 'Á', 'ẩ', 'ừ', 'Ǵ', 'ǂ', 'ņ', 'Ţ', 'fl', 'Ƭ', '%', 'Ŝ', 'ǎ', 'ǃ', 'Ź', 'ǥ', '士', '瑞', '(', 'Ū', 'Υ', 'Μ', '½', 'ũ', '嗷', '角', 'æ', '青', '[', '?', 'П', 'ک', '\'', 'ế', 'ن', 'İ', 'ằ', 'Ε', 'γ', 'ǔ', 'í', 'К', 'ʼn', 'Ń', 'ď', 'Ḥ', 'ġ', 'ỏ', 'ہ', 'Ó', ')', 'ʿ', 'Ď', 'Ḏ', 'ǒ', 'ÿ', 'þ', 'Π', 'Ẽ', '⅓', 'á', 'Ê', 'Ĭ', 'Ž', 'Đ', '♭', 'ع', 'Ŏ', 'ṃ', '&', '€', '$', 'ĝ', '>', 'ţ', 'ह', '√', 'ã', '时', 'Ę', 'ā', 'ب', '∞', 'ố', 'Γ', 'Ä', 'š', 'é', 'Ƞ', 'ĕ', '-', 'π', 'ŷ', 'º', 'Ř', 'Ƨ', 'ė', 'л', '×', 'â', 'ŵ', 'Р', 'Α', '²', 'Ō', '–', '¹', 'ỷ', 'ي', '́', '関', 'ē', 'Ļ', '्', 'ε', 'ữ', 'Ⱬ', 'Ú', 'Ø', 'в', 'ĉ', 'Ù', 'Ô', 'ě', '̃', 'Β', 'ử', 'ф', 'ɪ', 'Ș', 'ï', 'ộ', '°', 'Č', '가', 'ʾ', '‒', '宪', 'è', 'ẽ', '・', 'ħ', 'ċ', 'ı', 'ẻ', 'Ő', '̧', 'Ā', 'ủ', 'ỗ', 'ĩ', 'Ã', '<', '≠', 'Ć', '春', 'Ƅ', 'ẵ', '保', '±', 'ț', 'ŋ', '-', 'Ḫ', 'å', 'ž', 'Ě', 'ẫ', '‎', 'ḵ', 'ṛ', '→', '@', 'ặ', 'Ʋ', '存', '“', 'ş', 'Ǧ', 'đ', '王', 'Χ', 'ậ', 'ệ', '̤', '♥', '마', '殺', 'ő', '♡', 'κ', 'Ꞅ', '—', 'ứ', 'Ġ', 'Ī', '琳', 'Ĵ', 'Κ', 'ợ', 'ấ', 'ụ', 'ł', 'о', 'ǫ', 'р', 'Σ', 'ा', 'Ð', 'म', 'µ', 'Ľ', 'Ọ', 'Δ', '¼', 'ǐ', 'Ś', 'ʹ', 'م', 'н', 'Ì'} 4 | 5 | dofile 'utils/utils.lua' 6 | assert(table_len(unicode) == table_len(ascii)) 7 | 8 | unicode2ascii = {} 9 | for i,letter in pairs(ascii) do 10 | unicode2ascii[unicode[i]] = ascii[i] 11 | unicode2ascii['\\u0022'] = '"' 12 | unicode2ascii['\\u0023'] = '#' 13 | unicode2ascii['\\u005c'] = '\\' 14 | unicode2ascii['\\u00a0'] = '' 15 | end 16 | 17 | -------------------------------------------------------------------------------- /data_gen/gen_test_train_data/gen_ace_msnbc_aquaint_csv.lua: -------------------------------------------------------------------------------- 1 | -- Generate test data from the ACE/MSNBC/AQUAINT datasets by keeping the context and 2 | -- entity candidates for each annotated mention 3 | 4 | -- Format: 5 | -- doc_name \t doc_name \t mention \t left_ctxt \t right_ctxt \t CANDIDATES \t [ent_wikiid,p_e_m,ent_name]+ \t GT: \t pos,ent_wikiid,p_e_m,ent_name 6 | 7 | -- Stats: 8 | --cat wned-ace2004.csv | wc -l 9 | --257 10 | --cat wned-ace2004.csv | grep -P 'GT:\t-1' | wc -l 11 | --20 12 | --cat wned-ace2004.csv | grep -P 'GT:\t1,' | wc -l 13 
| --217 14 | 15 | --cat wned-aquaint.csv | wc -l 16 | --727 17 | --cat wned-aquaint.csv | grep -P 'GT:\t-1' | wc -l 18 | --33 19 | --cat wned-aquaint.csv | grep -P 'GT:\t1,' | wc -l 20 | --604 21 | 22 | --cat wned-msnbc.csv | wc -l 23 | --656 24 | --cat wned-msnbc.csv | grep -P 'GT:\t-1' | wc -l 25 | --22 26 | --cat wned-msnbc.csv | grep -P 'GT:\t1,' | wc -l 27 | --496 28 | 29 | if not ent_p_e_m_index then 30 | require 'torch' 31 | dofile 'data_gen/indexes/wiki_redirects_index.lua' 32 | dofile 'data_gen/indexes/yago_crosswikis_wiki.lua' 33 | dofile 'utils/utils.lua' 34 | end 35 | 36 | tds = tds or require 'tds' 37 | 38 | local function gen_test_ace(dataset) 39 | 40 | print('\nGenerating test data from ' .. dataset .. ' set ') 41 | 42 | local path = opt.root_data_dir .. 'basic_data/test_datasets/wned-datasets/' .. dataset .. '/' 43 | 44 | out_file = opt.root_data_dir .. 'generated/test_train_data/wned-' .. dataset .. '.csv' 45 | ouf = assert(io.open(out_file, "w")) 46 | 47 | annotations, _ = io.open(path .. dataset .. '.xml') 48 | 49 | local num_nonexistent_ent_id = 0 50 | local num_correct_ents = 0 51 | 52 | local cur_doc_text = '' 53 | local cur_doc_name = '' 54 | 55 | local line = annotations:read() 56 | while (line) do 57 | if (not line:find('document docName=\"')) then 58 | if line:find('') then 59 | line = annotations:read() 60 | local x,y = line:find('') 61 | local z,t = line:find('') 62 | local cur_mention = line:sub(y + 1, z - 1) 63 | cur_mention = string.gsub(cur_mention, '&', '&') 64 | 65 | line = annotations:read() 66 | x,y = line:find('') 67 | z,t = line:find('') 68 | cur_ent_title = '' 69 | if not line:find('') then 70 | cur_ent_title = line:sub(y + 1, z - 1) 71 | end 72 | 73 | line = annotations:read() 74 | x,y = line:find('') 75 | z,t = line:find('') 76 | local offset = 1 + tonumber(line:sub(y + 1, z - 1)) 77 | 78 | line = annotations:read() 79 | x,y = line:find('') 80 | z,t = line:find('') 81 | local length = tonumber(line:sub(y + 1, z - 1)) 82 | length = cur_mention:len() 83 | 84 | line = annotations:read() 85 | if line:find('') then 86 | line = annotations:read() 87 | end 88 | 89 | assert(line:find('')) 90 | 91 | offset = math.max(1, offset - 10) 92 | while (cur_doc_text:sub(offset, offset + length - 1) ~= cur_mention) do 93 | -- print(cur_mention .. ' ---> ' .. cur_doc_text:sub(offset, offset + length - 1)) 94 | offset = offset + 1 95 | end 96 | 97 | cur_mention = preprocess_mention(cur_mention) 98 | 99 | if cur_ent_title ~= 'NIL' and cur_ent_title ~= '' and cur_ent_title:len() > 0 then 100 | local cur_ent_wikiid = get_ent_wikiid_from_name(cur_ent_title) 101 | if cur_ent_wikiid == unk_ent_wikiid then 102 | num_nonexistent_ent_id = num_nonexistent_ent_id + 1 103 | print(green(cur_ent_title)) 104 | else 105 | num_correct_ents = num_correct_ents + 1 106 | end 107 | 108 | assert(cur_mention:len() > 0) 109 | local str = cur_doc_name .. '\t' .. cur_doc_name .. '\t' .. cur_mention .. '\t' 110 | 111 | local left_words = split_in_words(cur_doc_text:sub(1, offset - 1)) 112 | local num_left_words = table_len(left_words) 113 | local left_ctxt = {} 114 | for i = math.max(1, num_left_words - 100 + 1), num_left_words do 115 | table.insert(left_ctxt, left_words[i]) 116 | end 117 | if table_len(left_ctxt) == 0 then 118 | table.insert(left_ctxt, 'EMPTYCTXT') 119 | end 120 | str = str .. table.concat(left_ctxt, ' ') .. 
'\t' 121 | 122 | local right_words = split_in_words(cur_doc_text:sub(offset + length)) 123 | local num_right_words = table_len(right_words) 124 | local right_ctxt = {} 125 | for i = 1, math.min(num_right_words, 100) do 126 | table.insert(right_ctxt, right_words[i]) 127 | end 128 | if table_len(right_ctxt) == 0 then 129 | table.insert(right_ctxt, 'EMPTYCTXT') 130 | end 131 | str = str .. table.concat(right_ctxt, ' ') .. '\tCANDIDATES\t' 132 | 133 | 134 | -- Entity candidates from p(e|m) dictionary 135 | if ent_p_e_m_index[cur_mention] and #(ent_p_e_m_index[cur_mention]) > 0 then 136 | 137 | local sorted_cand = {} 138 | for ent_wikiid,p in pairs(ent_p_e_m_index[cur_mention]) do 139 | table.insert(sorted_cand, {ent_wikiid = ent_wikiid, p = p}) 140 | end 141 | table.sort(sorted_cand, function(a,b) return a.p > b.p end) 142 | 143 | local candidates = {} 144 | local gt_pos = -1 145 | for pos,e in pairs(sorted_cand) do 146 | if pos <= 100 then 147 | table.insert(candidates, e.ent_wikiid .. ',' .. string.format("%.3f", e.p) .. ',' .. get_ent_name_from_wikiid(e.ent_wikiid)) 148 | if e.ent_wikiid == cur_ent_wikiid then 149 | gt_pos = pos 150 | end 151 | else 152 | break 153 | end 154 | end 155 | str = str .. table.concat(candidates, '\t') .. '\tGT:\t' 156 | 157 | if gt_pos > 0 then 158 | ouf:write(str .. gt_pos .. ',' .. candidates[gt_pos] .. '\n') 159 | else 160 | if cur_ent_wikiid ~= unk_ent_wikiid then 161 | ouf:write(str .. '-1,' .. cur_ent_wikiid .. ',' .. cur_ent_title .. '\n') 162 | else 163 | ouf:write(str .. '-1\n') 164 | end 165 | end 166 | else 167 | if cur_ent_wikiid ~= unk_ent_wikiid then 168 | ouf:write(str .. 'EMPTYCAND\tGT:\t-1,' .. cur_ent_wikiid .. ',' .. cur_ent_title .. '\n') 169 | else 170 | ouf:write(str .. 'EMPTYCAND\tGT:\t-1\n') 171 | end 172 | end 173 | 174 | end 175 | end 176 | else 177 | local x,y = line:find('document docName=\"') 178 | local z,t = line:find('\">') 179 | cur_doc_name = line:sub(y + 1, z - 1) 180 | cur_doc_name = string.gsub(cur_doc_name, '&', '&') 181 | 182 | local it,_ = io.open(path .. 'RawText/' .. cur_doc_name) 183 | cur_doc_text = '' 184 | local cur_line = it:read() 185 | while cur_line do 186 | cur_doc_text = cur_doc_text .. cur_line .. ' ' 187 | cur_line = it:read() 188 | end 189 | cur_doc_text = string.gsub(cur_doc_text, '&', '&') 190 | end 191 | line = annotations:read() 192 | end 193 | 194 | ouf:flush() 195 | io.close(ouf) 196 | 197 | print('Done ' .. dataset .. '.') 198 | print('num_nonexistent_ent_id = ' .. num_nonexistent_ent_id .. '; num_correct_ents = ' .. 
num_correct_ents) 199 | end 200 | 201 | 202 | gen_test_ace('wikipedia') 203 | gen_test_ace('clueweb') 204 | gen_test_ace('ace2004') 205 | gen_test_ace('msnbc') 206 | gen_test_ace('aquaint') 207 | -------------------------------------------------------------------------------- /data_gen/gen_test_train_data/gen_aida_test.lua: -------------------------------------------------------------------------------- 1 | -- Generate test data from the AIDA dataset by keeping the context and 2 | -- entity candidates for each annotated mention 3 | 4 | -- Format: 5 | -- doc_name \t doc_name \t mention \t left_ctxt \t right_ctxt \t CANDIDATES \t [ent_wikiid,p_e_m,ent_name]+ \t GT: \t pos,ent_wikiid,p_e_m,ent_name 6 | 7 | -- Stats: 8 | --cat aida_testA.csv | wc -l 9 | --4791 10 | --cat aida_testA.csv | grep -P 'GT:\t-1' | wc -l 11 | --43 12 | --cat aida_testA.csv | grep -P 'GT:\t1,' | wc -l 13 | --3401 14 | 15 | --cat aida_testB.csv | wc -l 16 | --4485 17 | --cat aida_testB.csv | grep -P 'GT:\t-1' | wc -l 18 | --19 19 | --cat aida_testB.csv | grep -P 'GT:\t1,' | wc -l 20 | --3084 21 | 22 | if not ent_p_e_m_index then 23 | require 'torch' 24 | dofile 'data_gen/indexes/wiki_redirects_index.lua' 25 | dofile 'data_gen/indexes/yago_crosswikis_wiki.lua' 26 | dofile 'utils/utils.lua' 27 | end 28 | 29 | tds = tds or require 'tds' 30 | 31 | print('\nGenerating test data from AIDA set ') 32 | 33 | it, _ = io.open(opt.root_data_dir .. 'basic_data/test_datasets/AIDA/testa_testb_aggregate_original') 34 | 35 | out_file_A = opt.root_data_dir .. 'generated/test_train_data/aida_testA.csv' 36 | out_file_B = opt.root_data_dir .. 'generated/test_train_data/aida_testB.csv' 37 | 38 | ouf_A = assert(io.open(out_file_A, "w")) 39 | ouf_B = assert(io.open(out_file_B, "w")) 40 | 41 | local ouf = ouf_A 42 | 43 | local num_nme = 0 44 | local num_nonexistent_ent_title = 0 45 | local num_nonexistent_ent_id = 0 46 | local num_nonexistent_both = 0 47 | local num_correct_ents = 0 48 | local num_total_ents = 0 49 | 50 | local cur_words_num = 0 51 | local cur_words = {} 52 | local cur_mentions = {} 53 | local cur_mentions_num = 0 54 | 55 | local cur_doc_name = '' 56 | 57 | local function write_results() 58 | -- Write results: 59 | if cur_doc_name ~= '' then 60 | local header = cur_doc_name .. '\t' .. cur_doc_name .. '\t' 61 | for _, hyp in pairs(cur_mentions) do 62 | assert(hyp.mention:len() > 0, line) 63 | local mention = hyp.mention 64 | local str = header .. hyp.mention .. '\t' 65 | 66 | local left_ctxt = {} 67 | for i = math.max(0, hyp.start_off - 100), hyp.start_off - 1 do 68 | table.insert(left_ctxt, cur_words[i]) 69 | end 70 | if table_len(left_ctxt) == 0 then 71 | table.insert(left_ctxt, 'EMPTYCTXT') 72 | end 73 | str = str .. table.concat(left_ctxt, ' ') .. '\t' 74 | 75 | local right_ctxt = {} 76 | for i = hyp.end_off + 1, math.min(cur_words_num, hyp.end_off + 100) do 77 | table.insert(right_ctxt, cur_words[i]) 78 | end 79 | if table_len(right_ctxt) == 0 then 80 | table.insert(right_ctxt, 'EMPTYCTXT') 81 | end 82 | str = str .. table.concat(right_ctxt, ' ') .. 
'\tCANDIDATES\t' 83 | 84 | -- Entity candidates from p(e|m) dictionary 85 | if ent_p_e_m_index[mention] and #(ent_p_e_m_index[mention]) > 0 then 86 | 87 | local sorted_cand = {} 88 | for ent_wikiid,p in pairs(ent_p_e_m_index[mention]) do 89 | table.insert(sorted_cand, {ent_wikiid = ent_wikiid, p = p}) 90 | end 91 | table.sort(sorted_cand, function(a,b) return a.p > b.p end) 92 | 93 | local candidates = {} 94 | local gt_pos = -1 95 | for pos,e in pairs(sorted_cand) do 96 | if pos <= 100 then 97 | table.insert(candidates, e.ent_wikiid .. ',' .. string.format("%.3f", e.p) .. ',' .. get_ent_name_from_wikiid(e.ent_wikiid)) 98 | if e.ent_wikiid == hyp.ent_wikiid then 99 | gt_pos = pos 100 | end 101 | else 102 | break 103 | end 104 | end 105 | str = str .. table.concat(candidates, '\t') .. '\tGT:\t' 106 | 107 | if gt_pos > 0 then 108 | ouf:write(str .. gt_pos .. ',' .. candidates[gt_pos] .. '\n') 109 | else 110 | if hyp.ent_wikiid ~= unk_ent_wikiid then 111 | ouf:write(str .. '-1,' .. hyp.ent_wikiid .. ',' .. get_ent_name_from_wikiid(hyp.ent_wikiid) .. '\n') 112 | else 113 | ouf:write(str .. '-1\n') 114 | end 115 | end 116 | else 117 | if hyp.ent_wikiid ~= unk_ent_wikiid then 118 | ouf:write(str .. 'EMPTYCAND\tGT:\t-1,' .. hyp.ent_wikiid .. ',' .. get_ent_name_from_wikiid(hyp.ent_wikiid) .. '\n') 119 | else 120 | ouf:write(str .. 'EMPTYCAND\tGT:\t-1\n') 121 | end 122 | end 123 | end 124 | end 125 | end 126 | 127 | 128 | 129 | local line = it:read() 130 | while (line) do 131 | if (not line:find('-DOCSTART-')) then 132 | local parts = split(line, '\t') 133 | local num_parts = table_len(parts) 134 | assert(num_parts == 0 or num_parts == 1 or num_parts == 4 or num_parts == 7 or num_parts == 6, line) 135 | if num_parts > 0 then 136 | if num_parts == 4 and parts[2] == 'B' then 137 | num_nme = num_nme + 1 138 | end 139 | 140 | if (num_parts == 7 or num_parts == 6) and parts[2] == 'B' then 141 | 142 | -- Find current mention. A few hacks here. 
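-- Note: each mention line of the AIDA aggregate file is tab-separated roughly as
-- token, B/I, mention surface, YAGO2 entity, Wikipedia URL, Wikipedia id (plus a Freebase mid when present),
-- which is why parts[3], parts[5] and parts[6] are used below.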
143 | local cur_mention = preprocess_mention(parts[3]) 144 | 145 | local x,y = parts[5]:find('/wiki/') 146 | local cur_ent_title = parts[5]:sub(y + 1) 147 | local cur_ent_wikiid = tonumber(parts[6]) 148 | local index_ent_title = get_ent_name_from_wikiid(cur_ent_wikiid) 149 | local index_ent_wikiid = get_ent_wikiid_from_name(cur_ent_title) 150 | 151 | local final_ent_wikiid = index_ent_wikiid 152 | if final_ent_wikiid == unk_ent_wikiid then 153 | final_ent_wikiid = cur_ent_wikiid 154 | end 155 | 156 | if (index_ent_title == cur_ent_title and cur_ent_wikiid == index_ent_wikiid) then 157 | num_correct_ents = num_correct_ents + 1 158 | elseif (index_ent_title ~= cur_ent_title and cur_ent_wikiid ~= index_ent_wikiid) then 159 | num_nonexistent_both = num_nonexistent_both + 1 160 | elseif index_ent_title ~= cur_ent_title then 161 | assert(cur_ent_wikiid == index_ent_wikiid) 162 | num_nonexistent_ent_title = num_nonexistent_ent_title + 1 163 | else 164 | assert(index_ent_title == cur_ent_title) 165 | assert(cur_ent_wikiid ~= index_ent_wikiid) 166 | num_nonexistent_ent_id = num_nonexistent_ent_id + 1 167 | end 168 | 169 | num_total_ents = num_total_ents + 1 -- Keep even incorrect links 170 | 171 | cur_mentions_num = cur_mentions_num + 1 172 | cur_mentions[cur_mentions_num] = {} 173 | cur_mentions[cur_mentions_num].mention = cur_mention 174 | cur_mentions[cur_mentions_num].ent_wikiid = final_ent_wikiid 175 | cur_mentions[cur_mentions_num].start_off = cur_words_num + 1 176 | cur_mentions[cur_mentions_num].end_off = cur_words_num + table_len(split(parts[3], ' ')) 177 | end 178 | 179 | local words_on_this_line = split_in_words(parts[1]) 180 | for _,w in pairs(words_on_this_line) do 181 | table.insert(cur_words, modify_uppercase_phrase(w)) 182 | cur_words_num = cur_words_num + 1 183 | end 184 | end 185 | 186 | else 187 | assert(line:find('-DOCSTART-')) 188 | write_results() 189 | 190 | if cur_doc_name:find('testa') and line:find('testb') then 191 | ouf = ouf_B 192 | print('Done validation testA : ') 193 | print('num_nme = ' .. num_nme .. '; num_nonexistent_ent_title = ' .. num_nonexistent_ent_title) 194 | print('num_nonexistent_ent_id = ' .. num_nonexistent_ent_id .. '; num_nonexistent_both = ' .. num_nonexistent_both) 195 | print('num_correct_ents = ' .. num_correct_ents .. '; num_total_ents = ' .. num_total_ents) 196 | end 197 | 198 | local words = split_in_words(line) 199 | for _,w in pairs(words) do 200 | if w:find('testa') or w:find('testb') then 201 | cur_doc_name = w 202 | break 203 | end 204 | end 205 | cur_words = {} 206 | cur_words_num = 0 207 | cur_mentions = {} 208 | cur_mentions_num = 0 209 | end 210 | 211 | line = it:read() 212 | end 213 | 214 | write_results() 215 | 216 | ouf_A:flush() 217 | io.close(ouf_A) 218 | ouf_B:flush() 219 | io.close(ouf_B) 220 | 221 | 222 | print(' Done AIDA.') 223 | print('num_nme = ' .. num_nme .. '; num_nonexistent_ent_title = ' .. num_nonexistent_ent_title) 224 | print('num_nonexistent_ent_id = ' .. num_nonexistent_ent_id .. '; num_nonexistent_both = ' .. num_nonexistent_both) 225 | print('num_correct_ents = ' .. num_correct_ents .. '; num_total_ents = ' .. 
num_total_ents) 226 | -------------------------------------------------------------------------------- /data_gen/gen_test_train_data/gen_aida_train.lua: -------------------------------------------------------------------------------- 1 | -- Generate train data from the AIDA dataset by keeping the context and 2 | -- entity candidates for each annotated mention 3 | 4 | -- Format: 5 | -- doc_name \t doc_name \t mention \t left_ctxt \t right_ctxt \t CANDIDATES \t [ent_wikiid,p_e_m,ent_name]+ \t GT: \t pos,ent_wikiid,p_e_m,ent_name 6 | 7 | if not ent_p_e_m_index then 8 | require 'torch' 9 | dofile 'data_gen/indexes/wiki_redirects_index.lua' 10 | dofile 'data_gen/indexes/yago_crosswikis_wiki.lua' 11 | dofile 'utils/utils.lua' 12 | tds = tds or require 'tds' 13 | end 14 | 15 | print('\nGenerating train data from AIDA set ') 16 | 17 | it, _ = io.open(opt.root_data_dir .. 'basic_data/test_datasets/AIDA/aida_train.txt') 18 | 19 | out_file = opt.root_data_dir .. 'generated/test_train_data/aida_train.csv' 20 | ouf = assert(io.open(out_file, "w")) 21 | 22 | local num_nme = 0 23 | local num_nonexistent_ent_title = 0 24 | local num_nonexistent_ent_id = 0 25 | local num_nonexistent_both = 0 26 | local num_correct_ents = 0 27 | local num_total_ents = 0 28 | 29 | local cur_words_num = 0 30 | local cur_words = {} 31 | local cur_mentions = {} 32 | local cur_mentions_num = 0 33 | 34 | local cur_doc_name = '' 35 | 36 | local function write_results() 37 | -- Write results: 38 | if cur_doc_name ~= '' then 39 | local header = cur_doc_name .. '\t' .. cur_doc_name .. '\t' 40 | for _, hyp in pairs(cur_mentions) do 41 | assert(hyp.mention:len() > 0, line) 42 | local str = header .. hyp.mention .. '\t' 43 | 44 | local left_ctxt = {} 45 | for i = math.max(0, hyp.start_off - 100), hyp.start_off - 1 do 46 | table.insert(left_ctxt, cur_words[i]) 47 | end 48 | if table_len(left_ctxt) == 0 then 49 | table.insert(left_ctxt, 'EMPTYCTXT') 50 | end 51 | str = str .. table.concat(left_ctxt, ' ') .. '\t' 52 | 53 | local right_ctxt = {} 54 | for i = hyp.end_off + 1, math.min(cur_words_num, hyp.end_off + 100) do 55 | table.insert(right_ctxt, cur_words[i]) 56 | end 57 | if table_len(right_ctxt) == 0 then 58 | table.insert(right_ctxt, 'EMPTYCTXT') 59 | end 60 | str = str .. table.concat(right_ctxt, ' ') .. '\tCANDIDATES\t' 61 | 62 | -- Entity candidates from p(e|m) dictionary 63 | if ent_p_e_m_index[hyp.mention] and #(ent_p_e_m_index[hyp.mention]) > 0 then 64 | 65 | local sorted_cand = {} 66 | for ent_wikiid,p in pairs(ent_p_e_m_index[hyp.mention]) do 67 | table.insert(sorted_cand, {ent_wikiid = ent_wikiid, p = p}) 68 | end 69 | table.sort(sorted_cand, function(a,b) return a.p > b.p end) 70 | 71 | local candidates = {} 72 | local gt_pos = -1 73 | for pos,e in pairs(sorted_cand) do 74 | if pos <= 100 then 75 | table.insert(candidates, e.ent_wikiid .. ',' .. string.format("%.3f", e.p) .. ',' .. get_ent_name_from_wikiid(e.ent_wikiid)) 76 | if e.ent_wikiid == hyp.ent_wikiid then 77 | gt_pos = pos 78 | end 79 | else 80 | break 81 | end 82 | end 83 | str = str .. table.concat(candidates, '\t') .. '\tGT:\t' 84 | 85 | if gt_pos > 0 then 86 | ouf:write(str .. gt_pos .. ',' .. candidates[gt_pos] .. 
'\n') 87 | end 88 | end 89 | end 90 | end 91 | end 92 | 93 | 94 | 95 | local line = it:read() 96 | while (line) do 97 | if (not line:find('-DOCSTART-')) then 98 | local parts = split(line, '\t') 99 | local num_parts = table_len(parts) 100 | assert(num_parts == 0 or num_parts == 1 or num_parts == 4 or num_parts == 7 or num_parts == 6, line) 101 | if num_parts > 0 then 102 | if num_parts == 4 and parts[2] == 'B' then 103 | num_nme = num_nme + 1 104 | end 105 | 106 | if (num_parts == 7 or num_parts == 6) and parts[2] == 'B' then 107 | 108 | -- Find current mention. A few hacks here. 109 | 110 | local cur_mention = preprocess_mention(parts[3]) 111 | 112 | local x,y = parts[5]:find('/wiki/') 113 | local cur_ent_title = parts[5]:sub(y + 1) 114 | local cur_ent_wikiid = tonumber(parts[6]) 115 | local index_ent_title = get_ent_name_from_wikiid(cur_ent_wikiid) 116 | local index_ent_wikiid = get_ent_wikiid_from_name(cur_ent_title) 117 | 118 | local final_ent_wikiid = index_ent_wikiid 119 | if final_ent_wikiid == unk_ent_wikiid then 120 | final_ent_wikiid = cur_ent_wikiid 121 | end 122 | 123 | if (index_ent_title == cur_ent_title and cur_ent_wikiid == index_ent_wikiid) then 124 | num_correct_ents = num_correct_ents + 1 125 | elseif (index_ent_title ~= cur_ent_title and cur_ent_wikiid ~= index_ent_wikiid) then 126 | num_nonexistent_both = num_nonexistent_both + 1 127 | elseif index_ent_title ~= cur_ent_title then 128 | assert(cur_ent_wikiid == index_ent_wikiid) 129 | num_nonexistent_ent_title = num_nonexistent_ent_title + 1 130 | else 131 | assert(index_ent_title == cur_ent_title) 132 | assert(cur_ent_wikiid ~= index_ent_wikiid) 133 | num_nonexistent_ent_id = num_nonexistent_ent_id + 1 134 | end 135 | 136 | num_total_ents = num_total_ents + 1 -- Keep even incorrect links 137 | 138 | cur_mentions_num = cur_mentions_num + 1 139 | cur_mentions[cur_mentions_num] = {} 140 | cur_mentions[cur_mentions_num].mention = cur_mention 141 | cur_mentions[cur_mentions_num].ent_wikiid = final_ent_wikiid 142 | cur_mentions[cur_mentions_num].start_off = cur_words_num + 1 143 | cur_mentions[cur_mentions_num].end_off = cur_words_num + table_len(split(parts[3], ' ')) 144 | end 145 | 146 | local words_on_this_line = split_in_words(parts[1]) 147 | for _,w in pairs(words_on_this_line) do 148 | table.insert(cur_words, modify_uppercase_phrase(w)) 149 | cur_words_num = cur_words_num + 1 150 | end 151 | end 152 | 153 | else 154 | assert(line:find('-DOCSTART-')) 155 | write_results() 156 | 157 | cur_doc_name = line:sub(13) 158 | 159 | cur_words = {} 160 | cur_words_num = 0 161 | cur_mentions = {} 162 | cur_mentions_num = 0 163 | end 164 | 165 | line = it:read() 166 | end 167 | 168 | write_results() 169 | 170 | ouf:flush() 171 | io.close(ouf) 172 | 173 | 174 | print(' Done AIDA.') 175 | print('num_nme = ' .. num_nme .. '; num_nonexistent_ent_title = ' .. num_nonexistent_ent_title) 176 | print('num_nonexistent_ent_id = ' .. num_nonexistent_ent_id .. '; num_nonexistent_both = ' .. num_nonexistent_both) 177 | print('num_correct_ents = ' .. num_correct_ents .. '; num_total_ents = ' .. num_total_ents) 178 | -------------------------------------------------------------------------------- /data_gen/gen_test_train_data/gen_all.lua: -------------------------------------------------------------------------------- 1 | -- Generates all training and test data for entity disambiguation. 
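-- Run: th data_gen/gen_test_train_data/gen_all.lua -root_data_dir $DATA_PATH
-- Output CSVs are written to $DATA_PATH/generated/test_train_data/ (see step 10 of the README).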
2 | 3 | if not ent_p_e_m_index then 4 | require 'torch' 5 | dofile 'data_gen/indexes/wiki_redirects_index.lua' 6 | dofile 'data_gen/indexes/yago_crosswikis_wiki.lua' 7 | dofile 'utils/utils.lua' 8 | tds = tds or require 'tds' 9 | end 10 | 11 | dofile 'data_gen/gen_test_train_data/gen_aida_test.lua' 12 | dofile 'data_gen/gen_test_train_data/gen_aida_train.lua' 13 | dofile 'data_gen/gen_test_train_data/gen_ace_msnbc_aquaint_csv.lua' -------------------------------------------------------------------------------- /data_gen/gen_wiki_data/gen_ent_wiki_w_repr.lua: -------------------------------------------------------------------------------- 1 | if not opt then 2 | cmd = torch.CmdLine() 3 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.') 4 | cmd:text() 5 | opt = cmd:parse(arg or {}) 6 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.') 7 | end 8 | 9 | 10 | require 'torch' 11 | dofile 'utils/utils.lua' 12 | dofile 'data_gen/parse_wiki_dump/parse_wiki_dump_tools.lua' 13 | dofile 'entities/ent_name2id_freq/e_freq_index.lua' 14 | tds = tds or require 'tds' 15 | 16 | print('\nExtracting text only from Wiki dump. Output is wiki_canonical_words.txt containing on each line an Wiki entity with the list of all words in its canonical Wiki page.') 17 | 18 | it, _ = io.open(opt.root_data_dir .. 'basic_data/textWithAnchorsFromAllWikipedia2014Feb.txt') 19 | 20 | out_file = opt.root_data_dir .. 'generated/wiki_canonical_words.txt' 21 | ouf = assert(io.open(out_file, "w")) 22 | 23 | line = it:read() 24 | 25 | -- Find anchors, e.g. anarchism 26 | local num_lines = 0 27 | local num_valid_ents = 0 28 | local num_error_ents = 0 -- Probably list or disambiguation pages. 29 | 30 | local empty_valid_ents = get_map_all_valid_ents() 31 | 32 | local cur_words = '' 33 | local cur_ent_wikiid = -1 34 | 35 | while (line) do 36 | num_lines = num_lines + 1 37 | if num_lines % 5000000 == 0 then 38 | print('Processed ' .. num_lines .. ' lines. Num valid ents = ' .. num_valid_ents .. '. Num errs = ' .. num_error_ents) 39 | end 40 | 41 | if (not line:find(' 0 and cur_words ~= '') then 48 | if cur_ent_wikiid ~= unk_ent_wikiid and is_valid_ent(cur_ent_wikiid) then 49 | ouf:write(cur_ent_wikiid .. '\t' .. get_ent_name_from_wikiid(cur_ent_wikiid) .. '\t' .. cur_words .. '\n') 50 | empty_valid_ents[cur_ent_wikiid] = nil 51 | num_valid_ents = num_valid_ents + 1 52 | else 53 | num_error_ents = num_error_ents + 1 54 | end 55 | end 56 | 57 | cur_ent_wikiid = extract_page_entity_title(line) 58 | cur_words = '' 59 | end 60 | 61 | line = it:read() 62 | end 63 | ouf:flush() 64 | io.close(ouf) 65 | 66 | -- Num valid ents = 4126137. Num errs = 332944 67 | print(' Done extracting text only from Wiki dump. Num valid ents = ' .. num_valid_ents .. '. Num errs = ' .. num_error_ents) 68 | 69 | 70 | print('Create file with all entities with empty Wikipedia pages.') 71 | local empty_ents = {} 72 | for ent_wikiid, _ in pairs(empty_valid_ents) do 73 | table.insert(empty_ents, {ent_wikiid = ent_wikiid, f = get_ent_freq(ent_wikiid)}) 74 | end 75 | table.sort(empty_ents, function(a,b) return a.f > b.f end) 76 | 77 | local ouf2 = assert(io.open(opt.root_data_dir .. 'generated/empty_page_ents.txt', "w")) 78 | for _,x in pairs(empty_ents) do 79 | ouf2:write(x.ent_wikiid .. '\t' .. get_ent_name_from_wikiid(x.ent_wikiid) .. '\t' .. x.f .. 
'\n') 80 | end 81 | ouf2:flush() 82 | io.close(ouf2) 83 | print(' Done') -------------------------------------------------------------------------------- /data_gen/gen_wiki_data/gen_wiki_hyp_train_data.lua: -------------------------------------------------------------------------------- 1 | -- Generate training data from Wikipedia hyperlinks by keeping the context and 2 | -- entity candidates for each hyperlink 3 | 4 | -- Format: 5 | -- ent_wikiid \t ent_name \t mention \t left_ctxt \t right_ctxt \t CANDIDATES \t [ent_wikiid,p_e_m,ent_name]+ \t GT: \t pos,ent_wikiid,p_e_m,ent_name 6 | 7 | if not opt then 8 | cmd = torch.CmdLine() 9 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.') 10 | cmd:text() 11 | opt = cmd:parse(arg or {}) 12 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.') 13 | end 14 | 15 | require 'torch' 16 | dofile 'data_gen/parse_wiki_dump/parse_wiki_dump_tools.lua' 17 | dofile 'data_gen/indexes/yago_crosswikis_wiki.lua' 18 | tds = tds or require 'tds' 19 | 20 | print('\nGenerating training data from Wiki dump') 21 | 22 | it, _ = io.open(opt.root_data_dir .. 'basic_data/textWithAnchorsFromAllWikipedia2014Feb.txt') 23 | 24 | out_file = opt.root_data_dir .. 'generated/wiki_hyperlink_contexts.csv' 25 | ouf = assert(io.open(out_file, "w")) 26 | 27 | -- Find anchors, e.g. anarchism 28 | local num_lines = 0 29 | local num_valid_hyp = 0 30 | 31 | local cur_words_num = 0 32 | local cur_words = {} 33 | local cur_mentions = {} 34 | local cur_mentions_num = 0 35 | local cur_ent_wikiid = -1 36 | 37 | local line = it:read() 38 | while (line) do 39 | num_lines = num_lines + 1 40 | if num_lines % 1000000 == 0 then 41 | print('Processed ' .. num_lines .. ' lines. Num valid hyp = ' .. num_valid_hyp) 42 | end 43 | 44 | -- If it's a line from the Wiki page, add its text words and its hyperlinks 45 | if (not line:find(' 0 then 88 | assert(hyp.mention:len() > 0, line) 89 | local str = header .. hyp.mention .. '\t' 90 | 91 | local left_ctxt = {} 92 | for i = math.max(0, hyp.start_off - 100), hyp.start_off - 1 do 93 | table.insert(left_ctxt, cur_words[i]) 94 | end 95 | if table_len(left_ctxt) == 0 then 96 | table.insert(left_ctxt, 'EMPTYCTXT') 97 | end 98 | str = str .. table.concat(left_ctxt, ' ') .. '\t' 99 | 100 | local right_ctxt = {} 101 | for i = hyp.end_off + 1, math.min(cur_words_num, hyp.end_off + 100) do 102 | table.insert(right_ctxt, cur_words[i]) 103 | end 104 | if table_len(right_ctxt) == 0 then 105 | table.insert(right_ctxt, 'EMPTYCTXT') 106 | end 107 | str = str .. table.concat(right_ctxt, ' ') .. '\tCANDIDATES\t' 108 | 109 | -- Entity candidates from p(e|m) dictionary 110 | local unsorted_cand = {} 111 | for ent_wikiid,p in pairs(ent_p_e_m_index[hyp.mention]) do 112 | table.insert(unsorted_cand, {ent_wikiid = ent_wikiid, p = p}) 113 | end 114 | table.sort(unsorted_cand, function(a,b) return a.p > b.p end) 115 | 116 | local candidates = {} 117 | local gt_pos = -1 118 | for pos,e in pairs(unsorted_cand) do 119 | if pos <= 32 then 120 | table.insert(candidates, e.ent_wikiid .. ',' .. string.format("%.3f", e.p) .. ',' .. get_ent_name_from_wikiid(e.ent_wikiid)) 121 | if e.ent_wikiid == hyp.ent_wikiid then 122 | gt_pos = pos 123 | end 124 | else 125 | break 126 | end 127 | end 128 | str = str .. table.concat(candidates, '\t') .. '\tGT:\t' 129 | 130 | if gt_pos > 0 then 131 | num_valid_hyp = num_valid_hyp + 1 132 | ouf:write(str .. gt_pos .. ',' .. candidates[gt_pos] .. 
'\n') 133 | end 134 | end 135 | end 136 | end 137 | 138 | cur_ent_wikiid = extract_page_entity_title(line) 139 | cur_words = {} 140 | cur_words_num = 0 141 | cur_mentions = {} 142 | cur_mentions_num = 0 143 | end 144 | 145 | line = it:read() 146 | end 147 | ouf:flush() 148 | io.close(ouf) 149 | 150 | print(' Done generating training data from Wiki dump. Num valid hyp = ' .. num_valid_hyp) 151 | -------------------------------------------------------------------------------- /data_gen/indexes/wiki_disambiguation_pages_index.lua: -------------------------------------------------------------------------------- 1 | -- Loads the link disambiguation index from Wikipedia 2 | 3 | if not opt then 4 | cmd = torch.CmdLine() 5 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.') 6 | cmd:text() 7 | opt = cmd:parse(arg or {}) 8 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.') 9 | end 10 | 11 | 12 | require 'torch' 13 | dofile 'utils/utils.lua' 14 | tds = tds or require 'tds' 15 | 16 | print('==> Loading disambiguation index') 17 | it, _ = io.open(opt.root_data_dir .. 'basic_data/wiki_disambiguation_pages.txt') 18 | line = it:read() 19 | 20 | wiki_disambiguation_index = tds.Hash() 21 | while (line) do 22 | parts = split(line, "\t") 23 | assert(tonumber(parts[1])) 24 | wiki_disambiguation_index[tonumber(parts[1])] = 1 25 | line = it:read() 26 | end 27 | 28 | assert(wiki_disambiguation_index[579]) 29 | assert(wiki_disambiguation_index[41535072]) 30 | 31 | print(' Done loading disambiguation index') 32 | -------------------------------------------------------------------------------- /data_gen/indexes/wiki_redirects_index.lua: -------------------------------------------------------------------------------- 1 | -- Loads the link redirect index from Wikipedia 2 | 3 | if not opt then 4 | cmd = torch.CmdLine() 5 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.') 6 | cmd:text() 7 | opt = cmd:parse(arg or {}) 8 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.') 9 | end 10 | 11 | require 'torch' 12 | dofile 'utils/utils.lua' 13 | tds = tds or require 'tds' 14 | 15 | print('==> Loading redirects index') 16 | it, _ = io.open(opt.root_data_dir .. 'basic_data/wiki_redirects.txt') 17 | line = it:read() 18 | 19 | local wiki_redirects_index = tds.Hash() 20 | while (line) do 21 | parts = split(line, "\t") 22 | wiki_redirects_index[parts[1]] = parts[2] 23 | line = it:read() 24 | end 25 | 26 | assert(wiki_redirects_index['Coercive'] == 'Coercion') 27 | assert(wiki_redirects_index['Hosford, FL'] == 'Hosford, Florida') 28 | 29 | print(' Done loading redirects index') 30 | 31 | 32 | function get_redirected_ent_title(ent_name) 33 | if wiki_redirects_index[ent_name] then 34 | return wiki_redirects_index[ent_name] 35 | else 36 | return ent_name 37 | end 38 | end 39 | -------------------------------------------------------------------------------- /data_gen/indexes/yago_crosswikis_wiki.lua: -------------------------------------------------------------------------------- 1 | -- Loads the merged p(e|m) index. 
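-- Expected input format (inferred from the parsing code below; the concrete values are only
-- illustrative): each line of the two p(e|m) files is tab-separated as
--   mention \t total_count \t ent_wikiid,count[,ent_name] \t ent_wikiid,count[,ent_name] \t ...
-- e.g.  Paris \t 1000 \t 11111,950,Paris \t 22222,50,Paris_Hilton
-- Only the leading numeric fields of each candidate are read here. For the merged
-- crosswikis+wiki file p(e|m) = count / total_count; for the YAGO file every listed
-- candidate simply contributes 1 / total_count.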
2 | if not opt then 3 | cmd = torch.CmdLine() 4 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.') 5 | cmd:text() 6 | opt = cmd:parse(arg or {}) 7 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.') 8 | end 9 | 10 | 11 | require 'torch' 12 | tds = tds or require 'tds' 13 | 14 | dofile 'utils/utils.lua' 15 | dofile 'entities/ent_name2id_freq/ent_name_id.lua' 16 | 17 | ent_p_e_m_index = tds.Hash() 18 | 19 | mention_lower_to_one_upper = tds.Hash() 20 | 21 | mention_total_freq = tds.Hash() 22 | 23 | local crosswikis_textfilename = opt.root_data_dir .. 'generated/crosswikis_wikipedia_p_e_m.txt' 24 | print('==> Loading crosswikis_wikipedia from file ' .. crosswikis_textfilename) 25 | local it, _ = io.open(crosswikis_textfilename) 26 | local line = it:read() 27 | 28 | local num_lines = 0 29 | while (line) do 30 | num_lines = num_lines + 1 31 | if num_lines % 2000000 == 0 then 32 | print('Processed ' .. num_lines .. ' lines. ') 33 | end 34 | 35 | local parts = split(line , '\t') 36 | local mention = parts[1] 37 | 38 | local total = tonumber(parts[2]) 39 | assert(total) 40 | if total >= 1 then 41 | ent_p_e_m_index[mention] = tds.Hash() 42 | mention_lower_to_one_upper[mention:lower()] = mention 43 | mention_total_freq[mention] = total 44 | local num_parts = table_len(parts) 45 | for i = 3, num_parts do 46 | local ent_str = split(parts[i], ',') 47 | local ent_wikiid = tonumber(ent_str[1]) 48 | local freq = tonumber(ent_str[2]) 49 | assert(ent_wikiid) 50 | assert(freq) 51 | ent_p_e_m_index[mention][ent_wikiid] = freq / (total + 0.0) -- not sorted 52 | end 53 | end 54 | line = it:read() 55 | end 56 | 57 | local yago_textfilename = opt.root_data_dir .. 'generated/yago_p_e_m.txt' 58 | print('==> Loading yago index from file ' .. yago_textfilename) 59 | it, _ = io.open(yago_textfilename) 60 | line = it:read() 61 | 62 | num_lines = 0 63 | while (line) do 64 | num_lines = num_lines + 1 65 | if num_lines % 2000000 == 0 then 66 | print('Processed ' .. num_lines .. ' lines. ') 67 | end 68 | 69 | local parts = split(line , '\t') 70 | local mention = parts[1] 71 | 72 | local total = tonumber(parts[2]) 73 | assert(total) 74 | if total >= 1 then 75 | mention_lower_to_one_upper[mention:lower()] = mention 76 | if not mention_total_freq[mention] then 77 | mention_total_freq[mention] = total 78 | else 79 | mention_total_freq[mention] = total + mention_total_freq[mention] 80 | end 81 | 82 | local yago_ment_ent_idx = tds.Hash() 83 | local num_parts = table_len(parts) 84 | for i = 3, num_parts do 85 | local ent_str = split(parts[i], ',') 86 | local ent_wikiid = tonumber(ent_str[1]) 87 | local freq = 1 88 | assert(ent_wikiid) 89 | yago_ment_ent_idx[ent_wikiid] = freq / (total + 0.0) -- not sorted 90 | end 91 | 92 | if not ent_p_e_m_index[mention] then 93 | ent_p_e_m_index[mention] = yago_ment_ent_idx 94 | else 95 | for ent_wikiid,prob in pairs(yago_ment_ent_idx) do 96 | if not ent_p_e_m_index[mention][ent_wikiid] then 97 | ent_p_e_m_index[mention][ent_wikiid] = 0.0 98 | end 99 | ent_p_e_m_index[mention][ent_wikiid] = math.min(1.0, ent_p_e_m_index[mention][ent_wikiid] + prob) 100 | end 101 | 102 | end 103 | 104 | end 105 | line = it:read() 106 | end 107 | 108 | assert(ent_p_e_m_index['Dejan Koturovic'] and ent_p_e_m_index['Jose Luis Caminero']) 109 | 110 | -- Function used to preprocess a given mention such that it has higher 111 | -- chance to have at least one valid entry in the p(e|m) index. 
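-- Illustrative behaviour (a sketch; actual outcomes depend on the contents of the loaded index):
--   preprocess_mention('U.S.')   --> 'U.S.'    (the raw form is kept when it is the more frequent key)
--   preprocess_mention('madrid') --> 'Madrid'  (case-insensitive fallback, assuming only 'Madrid' is indexed)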
112 | function preprocess_mention(m) 113 | assert(ent_p_e_m_index and mention_total_freq) 114 | local cur_m = modify_uppercase_phrase(m) 115 | if (not ent_p_e_m_index[cur_m]) then 116 | cur_m = m 117 | end 118 | if (mention_total_freq[m] and (mention_total_freq[m] > mention_total_freq[cur_m])) then 119 | cur_m = m -- Cases like 'U.S.' are handed badly by modify_uppercase_phrase 120 | end 121 | -- If we cannot find the exact mention in our index, we try our luck to find it in a case insensitive index. 122 | if (not ent_p_e_m_index[cur_m]) and mention_lower_to_one_upper[cur_m:lower()] then 123 | cur_m = mention_lower_to_one_upper[cur_m:lower()] 124 | end 125 | return cur_m 126 | end 127 | 128 | 129 | print(' Done loading index') 130 | -------------------------------------------------------------------------------- /data_gen/parse_wiki_dump/parse_wiki_dump_tools.lua: -------------------------------------------------------------------------------- 1 | -- Utility functions to extract the text and hyperlinks from each page in the Wikipedia corpus. 2 | 3 | if not table_len then 4 | dofile 'utils/utils.lua' 5 | end 6 | if not get_redirected_ent_title then 7 | dofile 'data_gen/indexes/wiki_redirects_index.lua' 8 | end 9 | if not get_ent_name_from_wikiid then 10 | dofile 'entities/ent_name2id_freq/ent_name_id.lua' 11 | end 12 | 13 | 14 | function extract_text_and_hyp(line, mark_mentions) 15 | local list_hyp = {} -- (mention, entity) pairs 16 | local text = '' 17 | local list_ent_errors = 0 18 | local parsing_errors = 0 19 | local disambiguation_ent_errors = 0 20 | local diez_ent_errors = 0 21 | 22 | local end_end_hyp = 0 23 | local begin_end_hyp = 0 24 | local begin_start_hyp, end_start_hyp = line:find('', end_start_hyp + 1) 32 | if next_quotes then 33 | local ent_name = line:sub(end_start_hyp + 1, next_quotes - 1) 34 | begin_end_hyp, end_end_hyp = line:find('', end_quotes + 1) 35 | if begin_end_hyp then 36 | local mention = line:sub(end_quotes + 1, begin_end_hyp - 1) 37 | local mention_marker = false 38 | 39 | local good_mention = true 40 | good_mention = good_mention and (not mention:find('Wikipedia')) 41 | good_mention = good_mention and (not mention:find('wikipedia')) 42 | good_mention = good_mention and (mention:len() >= 1) 43 | 44 | if good_mention then 45 | local i = ent_name:find('wikt:') 46 | if i == 1 then 47 | ent_name = ent_name:sub(6) 48 | end 49 | ent_name = preprocess_ent_name(ent_name) 50 | 51 | i = ent_name:find('List of ') 52 | if (not i) or (i ~= 1) then 53 | if ent_name:find('#') then 54 | diez_ent_errors = diez_ent_errors + 1 55 | else 56 | local ent_wikiid = get_ent_wikiid_from_name(ent_name, true) 57 | if ent_wikiid == unk_ent_wikiid then 58 | disambiguation_ent_errors = disambiguation_ent_errors + 1 59 | else 60 | -- A valid (entity,mention) pair 61 | num_mentions = num_mentions + 1 62 | table.insert(list_hyp, {mention = mention, ent_wikiid = ent_wikiid, cnt = num_mentions}) 63 | if mark_mentions then 64 | mention_marker = true 65 | end 66 | end 67 | end 68 | else 69 | list_ent_errors = list_ent_errors + 1 70 | end 71 | end 72 | 73 | if (not mention_marker) then 74 | text = text .. ' ' .. mention .. ' ' 75 | else 76 | text = text .. ' MMSTART' .. num_mentions .. ' ' .. mention .. ' MMEND' .. num_mentions .. 
' ' 77 | end 78 | else 79 | parsing_errors = parsing_errors + 1 80 | begin_start_hyp = nil 81 | end 82 | else 83 | parsing_errors = parsing_errors + 1 84 | begin_start_hyp = nil 85 | end 86 | 87 | if begin_start_hyp then 88 | begin_start_hyp, end_start_hyp = line:find('Anarchism is a political philosophy that advocatesstateless societiesoften defined as self-governed voluntary institutions, but that several authors have defined as more specific institutions based on non-hierarchical free associations..Anarchism' 109 | 110 | local test_line_2 = 'CSF pressure, as measured by lumbar puncture (LP), is 10-18 ' 111 | local test_line_3 = 'Anarchism' 112 | 113 | list_hype, text = extract_text_and_hyp(test_line_1, false) 114 | print(list_hype) 115 | print(text) 116 | print() 117 | 118 | list_hype, text = extract_text_and_hyp(test_line_1, true) 119 | print(list_hype) 120 | print(text) 121 | print() 122 | 123 | list_hype, text = extract_text_and_hyp(test_line_2, true) 124 | print(list_hype) 125 | print(text) 126 | print() 127 | 128 | list_hype, text = extract_text_and_hyp(test_line_3, false) 129 | print(list_hype) 130 | print(text) 131 | print() 132 | print(' Done unit tests.') 133 | --------------------------------------------------------- 134 | 135 | 136 | function extract_page_entity_title(line) 137 | local startoff, endoff = line:find(' ' .. line:sub(startoff + 1, startquotes - 1)) 142 | local starttitlestartoff, starttitleendoff = line:find(' title="') 143 | local endtitleoff, _ = line:find('">') 144 | local ent_name = line:sub(starttitleendoff + 1, endtitleoff - 1) 145 | if (ent_wikiid ~= get_ent_wikiid_from_name(ent_name, true)) then 146 | -- Most probably this is a disambiguation or list page 147 | local new_ent_wikiid = get_ent_wikiid_from_name(ent_name, true) 148 | -- print(red('Error in Wiki dump: ' .. line .. ' ' .. ent_wikiid .. ' ' .. new_ent_wikiid)) 149 | return new_ent_wikiid 150 | end 151 | return ent_wikiid 152 | end 153 | 154 | 155 | local test_line_4 = '' 156 | 157 | print(extract_page_entity_title(test_line_4)) 158 | -------------------------------------------------------------------------------- /ed/args.lua: -------------------------------------------------------------------------------- 1 | -- We add params abbreviations to the abbv map if they are important hyperparameters 2 | -- used to differentiate between different methods. 3 | abbv = {} 4 | 5 | cmd = torch.CmdLine() 6 | cmd:text() 7 | cmd:text('Deep Joint Entity Disambiguation w/ Local Neural Attention') 8 | cmd:text('Command line options:') 9 | 10 | ---------------- runtime options: 11 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.') 12 | 13 | cmd:option('-unit_tests', false, 'Run unit tests or not') 14 | 15 | ---------------- CUDA: 16 | cmd:option('-type', 'cudacudnn', 'Type: cuda | float | cudacudnn') 17 | 18 | ---------------- train data: 19 | cmd:option('-store_train_data', 'RAM', 'Where to read the training data from: RAM (tensors) | DISK (text, parsed all the time)') 20 | 21 | ---------------- loss: 22 | cmd:option('-loss', 'max-margin', 'Loss: nll | max-margin') 23 | abbv['-loss'] = '' 24 | 25 | ---------------- optimization: 26 | cmd:option('-opt', 'ADAM', 'Optimization method: SGD | ADADELTA | ADAGRAD | ADAM') 27 | abbv['-opt'] = '' 28 | 29 | cmd:option('-lr', 1e-4, 'Learning rate. 
Will be divided by 10 after validation F1 >= 90%.') 30 | abbv['-lr'] = 'lr' 31 | 32 | cmd:option('-batch_size', 1, 'Batch size in terms of number of documents.') 33 | abbv['-batch_size'] = 'bs' 34 | 35 | ---------------- word vectors 36 | cmd:option('-word_vecs', 'w2v', 'Word vectors type: glove | w2v (Word2Vec)') 37 | abbv['-word_vecs'] = '' 38 | 39 | ---------------- entity vectors 40 | cmd:option('-entities', 'RLTD', 'Which entity vectors to use, either just those that appear as candidates in all datasets, or all. All is impractical when storing on GPU. RLTD | ALL') 41 | 42 | cmd:option('-ent_vecs_filename', 'ent_vecs__ep_228.t7', 'File name containing entity vectors generated with entities/learn_e2v/learn_a.lua.') 43 | 44 | ---------------- context 45 | cmd:option('-ctxt_window', 100, 'Number of context words at the left plus right of each mention') 46 | abbv['-ctxt_window'] = 'ctxtW' 47 | 48 | cmd:option('-R', 25, 'Hard attention threshold: top R context words are kept, the rest are discarded.') 49 | abbv['-R'] = 'R' 50 | 51 | ---------------- model 52 | cmd:option('-model', 'global', 'ED model: local | global') 53 | 54 | cmd:option('-nn_pem_interm_size', 100, 'Number of hidden units in the f function described in Section 4 - Local score combination.') 55 | abbv['-nn_pem_interm_size'] = 'nnPEMintermS' 56 | 57 | -------------- model regularization: 58 | cmd:option('-mat_reg_norm', 1, 'Maximum norm of columns of matrices of the f network.') 59 | abbv['-mat_reg_norm'] = 'matRegNorm' 60 | 61 | ---------------- global model parameters 62 | cmd:option('-lbp_iter', 10, 'Number iterations of LBP hard-coded in a NN. Referred as T in the paper.') 63 | abbv['-lbp_iter'] = 'lbpIt' 64 | 65 | cmd:option('-lbp_damp', 0.5, 'Damping factor for LBP') 66 | abbv['-lbp_damp'] = 'lbpDamp' 67 | 68 | ----------------- reranking of candidates 69 | cmd:option('-num_cand_before_rerank', 30, '') 70 | abbv['-num_cand_before_rerank'] = 'numERerank' 71 | 72 | cmd:option('-keep_p_e_m', 4, '') 73 | abbv['-keep_p_e_m'] = 'keepPEM' 74 | 75 | cmd:option('-keep_e_ctxt', 3, '') 76 | abbv['-keep_e_ctxt'] = 'keepEC' 77 | 78 | ----------------- coreference: 79 | cmd:option('-coref', true, 'Coreference heuristic to match persons names.') 80 | abbv['-coref'] = 'coref' 81 | 82 | ------------------ test one model with saved pretrained parameters 83 | cmd:option('-test_one_model_file', '', 'Saved pretrained model filename from folder $DATA_PATH/generated/ed_models/.') 84 | 85 | ------------------ banner: 86 | cmd:option('-banner_header', '', ' Banner header to be used for plotting') 87 | 88 | cmd:text() 89 | opt = cmd:parse(arg or {}) 90 | 91 | -- Whether to save the current ED model or not during training. 92 | -- It will become true after the model gets > 90% F1 score on validation set (see test.lua). 93 | opt.save = false 94 | 95 | -- Creates a nice banner from the command line arguments 96 | function get_banner(arg, abbv) 97 | local num_args = table_len(arg) 98 | local banner = opt.banner_header 99 | 100 | if opt.model == 'global' then 101 | banner = banner .. 'GLB' 102 | else 103 | banner = banner .. 'LCL' 104 | end 105 | 106 | for i = 1,num_args do 107 | if abbv[arg[i]] then 108 | banner = banner .. '|' .. abbv[arg[i]] .. '=' .. tostring(opt[arg[i]:sub(2)]) 109 | end 110 | end 111 | return banner 112 | end 113 | 114 | 115 | function serialize_params(arg) 116 | local num_args = table_len(arg) 117 | local str = opt.banner_header 118 | if opt.banner_header:len() > 0 then 119 | str = str .. 
'|' 120 | end 121 | 122 | str = str .. 'model=' .. opt.model 123 | 124 | for i = 1,num_args do 125 | if abbv[arg[i]] then 126 | str = str .. '|' .. arg[i]:sub(2) .. '=' .. tostring(opt[arg[i]:sub(2)]) 127 | end 128 | end 129 | return str 130 | end 131 | 132 | banner = get_banner(arg, abbv) 133 | 134 | params_serialized = serialize_params(arg) 135 | print('PARAMS SERIALIZED: ' .. params_serialized) 136 | print('BANNER : ' .. banner .. '\n') 137 | 138 | assert(params_serialized:len() < 255, 'Parameters string length should be < 255.') 139 | 140 | 141 | function extract_args_from_model_title(title) 142 | local x,y = title:find('model=') 143 | local parts = split(title:sub(x), '|') 144 | for _,part in pairs(parts) do 145 | if string.find(part, '=') then 146 | local components = split(part, '=') 147 | opt[components[1]] = components[2] 148 | if tonumber(components[2]) then 149 | opt[components[1]] = tonumber(components[2]) 150 | end 151 | if components[2] == 'true' then 152 | opt[components[1]] = true 153 | end 154 | if components[2] == 'false' then 155 | opt[components[1]] = false 156 | end 157 | end 158 | end 159 | end 160 | -------------------------------------------------------------------------------- /ed/ed.lua: -------------------------------------------------------------------------------- 1 | require 'optim' 2 | require 'torch' 3 | require 'gnuplot' 4 | require 'nn' 5 | require 'xlua' 6 | 7 | tds = tds or require 'tds' 8 | dofile 'utils/utils.lua' 9 | 10 | print('\n' .. green('==========> TRAINING of ED NEURAL MODELS <==========') .. '\n') 11 | 12 | dofile 'ed/args.lua' 13 | 14 | print('===> RUN TYPE (CPU/GPU): ' .. opt.type) 15 | 16 | torch.setdefaulttensortype('torch.FloatTensor') 17 | if string.find(opt.type, 'cuda') then 18 | print('==> switching to CUDA (GPU)') 19 | require 'cunn' 20 | require 'cutorch' 21 | require 'cudnn' 22 | cudnn.benchmark = true 23 | cudnn.fastest = true 24 | else 25 | print('==> running on CPU') 26 | end 27 | 28 | dofile 'utils/logger.lua' 29 | dofile 'entities/relatedness/relatedness.lua' 30 | dofile 'entities/ent_name2id_freq/ent_name_id.lua' 31 | dofile 'entities/ent_name2id_freq/e_freq_index.lua' 32 | dofile 'words/load_w_freq_and_vecs.lua' -- w ids 33 | dofile 'words/w2v/w2v.lua' 34 | dofile 'entities/pretrained_e2v/e2v.lua' 35 | dofile 'ed/minibatch/build_minibatch.lua' 36 | dofile 'ed/minibatch/data_loader.lua' 37 | dofile 'ed/models/model.lua' 38 | dofile 'ed/loss.lua' 39 | dofile 'ed/train.lua' 40 | dofile 'ed/test/test.lua' 41 | 42 | geom_unit_tests() -- Show some entity examples 43 | compute_relatedness_metrics(entity_similarity) -- UNCOMMENT 44 | 45 | train_and_test() 46 | -------------------------------------------------------------------------------- /ed/loss.lua: -------------------------------------------------------------------------------- 1 | if opt.loss == 'nll' then 2 | criterion = nn.CrossEntropyCriterion() 3 | else 4 | -- max-margin with margin parameter = 0.01 5 | criterion = nn.MultiMarginCriterion(1, torch.ones(max_num_cand), 0.01) 6 | end 7 | 8 | if string.find(opt.type, 'cuda') then 9 | criterion = criterion:cuda() 10 | end -------------------------------------------------------------------------------- /ed/minibatch/data_loader.lua: -------------------------------------------------------------------------------- 1 | -- Data loader for training of ED models. 2 | 3 | train_file = opt.root_data_dir .. 
'generated/test_train_data/aida_train.csv' 4 | it_train, _ = io.open(train_file) 5 | 6 | print('==> Loading training data with option ' .. opt.store_train_data) 7 | local function one_doc_to_minibatch(doc_lines) 8 | -- Create empty mini batch: 9 | local num_mentions = #doc_lines 10 | assert(num_mentions > 0) 11 | 12 | local inputs = empty_minibatch_with_ids(num_mentions) 13 | local targets = torch.zeros(num_mentions) 14 | 15 | -- Fill in each example: 16 | for i = 1, num_mentions do 17 | local target = process_one_line(doc_lines[i], inputs, i, true) 18 | targets[i] = target 19 | assert(target >= 1 and target == targets[i]) 20 | end 21 | 22 | return inputs, targets 23 | end 24 | 25 | if opt.store_train_data == 'RAM' then 26 | all_docs_inputs = tds.Hash() 27 | all_docs_targets = tds.Hash() 28 | doc2id = tds.Hash() 29 | id2doc = tds.Hash() 30 | 31 | local cur_doc_lines = tds.Hash() 32 | local prev_doc_id = nil 33 | 34 | local line = it_train:read() 35 | while line do 36 | local parts = split(line, '\t') 37 | local doc_name = parts[1] 38 | if not doc2id[doc_name] then 39 | if prev_doc_id then 40 | local inputs, targets = one_doc_to_minibatch(cur_doc_lines) 41 | all_docs_inputs[prev_doc_id] = minibatch_table2tds(inputs) 42 | all_docs_targets[prev_doc_id] = targets 43 | end 44 | local cur_docid = 1 + #doc2id 45 | id2doc[cur_docid] = doc_name 46 | doc2id[doc_name] = cur_docid 47 | cur_doc_lines = tds.Hash() 48 | prev_doc_id = cur_docid 49 | end 50 | cur_doc_lines[1 + #cur_doc_lines] = line 51 | line = it_train:read() 52 | end 53 | if prev_doc_id then 54 | local inputs, targets = one_doc_to_minibatch(cur_doc_lines) 55 | all_docs_inputs[prev_doc_id] = minibatch_table2tds(inputs) 56 | all_docs_targets[prev_doc_id] = targets 57 | end 58 | assert(#doc2id == #all_docs_inputs, #doc2id .. ' ' .. #all_docs_inputs) 59 | 60 | else 61 | all_doc_lines = tds.Hash() 62 | doc2id = tds.Hash() 63 | id2doc = tds.Hash() 64 | 65 | local line = it_train:read() 66 | while line do 67 | local parts = split(line, '\t') 68 | local doc_name = parts[1] 69 | if not doc2id[doc_name] then 70 | local cur_docid = 1 + #doc2id 71 | id2doc[cur_docid] = doc_name 72 | doc2id[doc_name] = cur_docid 73 | all_doc_lines[cur_docid] = tds.Hash() 74 | end 75 | all_doc_lines[doc2id[doc_name]][1 + #all_doc_lines[doc2id[doc_name]]] = line 76 | line = it_train:read() 77 | end 78 | assert(#doc2id == #all_doc_lines) 79 | end 80 | 81 | 82 | get_minibatch = function() 83 | -- Create empty mini batch: 84 | local inputs = nil 85 | local targets = nil 86 | 87 | if opt.store_train_data == 'RAM' then 88 | local random_docid = math.random(#id2doc) 89 | inputs = minibatch_tds2table(all_docs_inputs[random_docid]) 90 | targets = all_docs_targets[random_docid] 91 | else 92 | local doc_lines = all_doc_lines[math.random(#id2doc)] 93 | inputs, targets = one_doc_to_minibatch(doc_lines) 94 | end 95 | 96 | -- Move data to GPU: 97 | inputs, targets = minibatch_to_correct_type(inputs, targets, true) 98 | targets = correct_type(targets) 99 | 100 | return inputs, targets 101 | end 102 | 103 | print(' Done loading training data.') 104 | -------------------------------------------------------------------------------- /ed/models/SetConstantDiag.lua: -------------------------------------------------------------------------------- 1 | -- Torch layer that receives as input a squared matrix and sets its diagonal to a constant value. 
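-- Usage sketch (the layer expects a 3D tensor of square matrices, see the asserts in
-- updateOutput below; not executed here):
--   local layer = nn.SetConstantDiag(0)
--   local y = layer:forward(torch.randn(4, 5, 5))
--   -- each 5x5 slice of y now has 0 on its diagonal; off-diagonal entries are unchanged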
2 | 3 | local SetConstantDiag, parent = torch.class('nn.SetConstantDiag', 'nn.Module') 4 | 5 | function SetConstantDiag:__init(constant_scalar, ip) 6 | parent.__init(self) 7 | assert(type(constant_scalar) == 'number', 'input is not scalar!') 8 | self.constant_scalar = constant_scalar 9 | 10 | -- default for inplace is false 11 | self.inplace = ip or false 12 | if (ip and type(ip) ~= 'boolean') then 13 | error('in-place flag must be boolean') 14 | end 15 | if not opt then 16 | opt = {} 17 | opt.type = 'double' 18 | dofile 'utils/utils.lua' 19 | end 20 | end 21 | 22 | function SetConstantDiag:updateOutput(input) 23 | assert(input:dim() == 3) 24 | assert(input:size(2) == input:size(3)) 25 | local n = input:size(3) 26 | local prod_mat = torch.ones(n,n) - torch.eye(n) 27 | prod_mat = correct_type(prod_mat) 28 | local sum_mat = torch.eye(n):mul(self.constant_scalar) 29 | sum_mat = correct_type(sum_mat) 30 | if self.inplace then 31 | input:cmul(torch.repeatTensor(prod_mat, input:size(1), 1, 1)) 32 | input:add(torch.repeatTensor(sum_mat, input:size(1), 1, 1)) 33 | self.output:set(input) 34 | else 35 | self.output:resizeAs(input) 36 | self.output:copy(input) 37 | self.output:cmul(torch.repeatTensor(prod_mat, input:size(1), 1, 1)) 38 | self.output:add(torch.repeatTensor(sum_mat, input:size(1), 1, 1)) 39 | end 40 | return self.output 41 | end 42 | 43 | function SetConstantDiag:updateGradInput(input, gradOutput) 44 | local n = input:size(3) 45 | local prod_mat = torch.ones(n,n) - torch.eye(n) 46 | prod_mat = correct_type(prod_mat) 47 | if self.inplace then 48 | self.gradInput:set(gradOutput) 49 | else 50 | self.gradInput:resizeAs(gradOutput) 51 | self.gradInput:copy(gradOutput) 52 | end 53 | self.gradInput:cmul(torch.repeatTensor(prod_mat, input:size(1), 1, 1)) 54 | return self.gradInput 55 | end -------------------------------------------------------------------------------- /ed/models/linear_layers.lua: -------------------------------------------------------------------------------- 1 | -- Define all parametrized layers: linear layers (diagonal matrices) A,B,C + network f 2 | 3 | function new_linear_layer(out_dim) 4 | cmul = nn.CMul(out_dim) 5 | -- init weights with ones to speed up convergence 6 | cmul.weight = torch.ones(out_dim) 7 | return cmul 8 | end 9 | 10 | ---- Create shared weights 11 | A_linear = new_linear_layer(ent_vecs_size) 12 | 13 | -- Local ctxt bilinear weights 14 | B_linear = new_linear_layer(ent_vecs_size) 15 | 16 | -- Used only in the global model 17 | C_linear = new_linear_layer(ent_vecs_size) 18 | 19 | f_network = nn.Sequential() 20 | :add(nn.Linear(2,opt.nn_pem_interm_size)) 21 | :add(nn.ReLU()) 22 | :add(nn.Linear(opt.nn_pem_interm_size,1)) 23 | 24 | 25 | function regularize_f_network() 26 | if opt.mat_reg_norm < 10 then 27 | for i = 1,f_network:size() do 28 | if f_network:get(i).weight and (f_network:get(i).weight:norm() > opt.mat_reg_norm) then 29 | f_network:get(i).weight:mul(opt.mat_reg_norm / f_network:get(i).weight:norm()) 30 | end 31 | if f_network:get(i).bias and (f_network:get(i).bias:norm() > opt.mat_reg_norm) then 32 | f_network:get(i).bias:mul(opt.mat_reg_norm / f_network:get(i).bias:norm()) 33 | end 34 | end 35 | end 36 | end 37 | 38 | function pack_saveable_weights() 39 | local linears = nn.Sequential():add(A_linear):add(B_linear):add(C_linear):add(f_network) 40 | return linears:float() 41 | end 42 | 43 | function unpack_saveable_weights(saved_linears) 44 | A_linear = saved_linears:get(1) 45 | B_linear = saved_linears:get(2) 46 | C_linear = saved_linears:get(3) 
47 | f_network = saved_linears:get(4) 48 | end 49 | 50 | 51 | function print_net_weights() 52 | print('\nNetwork norms of parameter weights :') 53 | print('A (attention mat) = ' .. A_linear.weight:norm()) 54 | print('B (ctxt embedding) = ' .. B_linear.weight:norm()) 55 | print('C (pairwise mat) = ' .. C_linear.weight:norm()) 56 | 57 | if opt.mat_reg_norm < 10 then 58 | print('f_network norm = ' .. f_network:get(1).weight:norm() .. ' ' .. 59 | f_network:get(1).bias:norm() .. ' ' .. f_network:get(3).weight:norm() .. ' ' .. 60 | f_network:get(3).bias:norm()) 61 | else 62 | p,gp = f_network:getParameters() 63 | print('f_network norm = ' .. p:norm()) 64 | end 65 | end 66 | -------------------------------------------------------------------------------- /ed/models/model.lua: -------------------------------------------------------------------------------- 1 | dofile 'ed/models/SetConstantDiag.lua' 2 | dofile 'ed/models/linear_layers.lua' 3 | dofile 'ed/models/model_local.lua' 4 | dofile 'ed/models/model_global.lua' 5 | 6 | function get_model(num_mentions) 7 | local model_ctxt, additional_local_submodels = local_model(num_mentions, A_linear, B_linear) 8 | local model = model_ctxt 9 | 10 | if opt.model == 'global' then 11 | model = global_model(num_mentions, model_ctxt, C_linear, f_network) 12 | else 13 | assert(opt.model == 'local') 14 | end 15 | 16 | return model, additional_local_submodels 17 | end 18 | -------------------------------------------------------------------------------- /ed/models/model_global.lua: -------------------------------------------------------------------------------- 1 | -- Definition of the neural network used for global (joint) ED. Section 5 of our paper. 2 | -- It unrolls a fixed number of LBP iterations allowing training of the CRF potentials using backprop. 3 | -- To run a simple unit test that checks the forward and backward passes, just run : 4 | -- th ed/models/model_global.lua 5 | 6 | if not opt then -- unit tests 7 | dofile 'ed/models/model_local.lua' 8 | dofile 'ed/models/SetConstantDiag.lua' 9 | dofile 'ed/models/linear_layers.lua' 10 | 11 | opt.lbp_iter = 10 12 | opt.lbp_damp = 0.5 13 | opt.model = 'global' 14 | end 15 | 16 | 17 | function global_model(num_mentions, model_ctxt, param_C_linear, param_f_network) 18 | 19 | assert(num_mentions) 20 | assert(model_ctxt) 21 | assert(param_C_linear) 22 | assert(param_f_network) 23 | 24 | local unary_plus_pairwise = nn.Sequential() 25 | :add(nn.ParallelTable() 26 | :add(nn.Sequential() -- Pairwise scores s_{ij}(y_i, y_j) 27 | :add(nn.View(num_mentions * max_num_cand, ent_vecs_size)) -- e_vecs 28 | :add(nn.ConcatTable() 29 | :add(nn.Identity()) 30 | :add(param_C_linear) 31 | ) 32 | :add(nn.MM(false, true)) -- s_{ij}(y_i, y_j) is s[i][y_i][j][y_j] = 33 | :add(nn.View(num_mentions, max_num_cand, num_mentions, max_num_cand)) 34 | :add(nn.MulConstant(2.0 / num_mentions, true)) 35 | ) 36 | :add(nn.Sequential() -- Unary scores s_j(y_j) 37 | :add(nn.Replicate(num_mentions * max_num_cand, 1)) 38 | :add(nn.Reshape(num_mentions, max_num_cand, num_mentions, max_num_cand)) 39 | ) 40 | ) 41 | :add(nn.CAddTable()) -- q[i][y_i][j][y_j] = s_j(y_j) + s_{ij}(y_i, y_j): num_mentions x max_num_cand x num_mentions x max_num_cand 42 | 43 | 44 | -- Input is q[i][y_i][j] : num_mentions x max_num_cand x num_mentions 45 | local messages_one_round = nn.ConcatTable() 46 | :add(nn.SelectTable(1)) -- 1. 
unary_plus_pairwise : num_mentions, max_num_cand, num_mentions, max_num_cand 47 | :add(nn.Sequential() 48 | :add(nn.ConcatTable() 49 | :add(nn.Sequential() 50 | :add(nn.SelectTable(2)) -- old messages 51 | :add(nn.Exp()) 52 | :add(nn.MulConstant(1.0 - opt.lbp_damp, false)) 53 | ) 54 | :add(nn.Sequential() 55 | :add(nn.ParallelTable() 56 | :add(nn.Identity()) -- unary plus pairwise 57 | :add(nn.Sequential() -- old messages: num_mentions, max_num_cand, num_mentions 58 | :add(nn.Sum(3)) -- g[i][y_i] := \sum_{k != i} q[i][y_i][k] 59 | :add(nn.Replicate(num_mentions * max_num_cand, 1)) 60 | :add(nn.Reshape(num_mentions, max_num_cand, num_mentions, max_num_cand)) 61 | ) 62 | ) 63 | :add(nn.CAddTable()) -- s_{j}(y_j) + s_{ij}(y_i, y_j) + g[j][y_j] : num_mentions, max_num_cand, num_mentions, max_num_cand 64 | :add(nn.Max(4)) -- unnorm_q[i][y_i][j] : num_mentions x max_num_cand x num_mentions 65 | :add(nn.Transpose({2,3})) 66 | :add(nn.View(num_mentions * num_mentions, max_num_cand)) 67 | :add(nn.LogSoftMax()) -- normalization: \sum_{y_i} exp(q[i][y_i][j]) = 1 68 | :add(nn.View(num_mentions, num_mentions, max_num_cand)) 69 | :add(nn.Transpose({2,3})) 70 | :add(nn.Transpose({1,2})) 71 | :add(nn.SetConstantDiag(0, true)) -- we make q[i][y_i][i] = 0, \forall i and y_i 72 | :add(nn.Transpose({1,2})) 73 | :add(nn.Exp()) 74 | :add(nn.MulConstant(opt.lbp_damp, false)) 75 | ) 76 | ) 77 | :add(nn.CAddTable()) -- 2. messages for next round: num_mentions, max_num_cand, num_mentions 78 | :add(nn.Log()) 79 | ) 80 | 81 | 82 | local messages_all_rounds = nn.Sequential() 83 | messages_all_rounds:add(nn.Identity()) 84 | for i = 1, opt.lbp_iter do 85 | messages_all_rounds:add(messages_one_round:clone('weight','bias','gradWeight','gradBias')) 86 | end 87 | 88 | local model_gl = nn.Sequential() 89 | :add(nn.ConcatTable() 90 | :add(nn.Sequential() 91 | :add(nn.SelectTable(2)) 92 | :add(nn.SelectTable(2)) -- e_vecs : num_mentions, max_num_cand, ent_vecs_size 93 | ) 94 | :add(model_ctxt) -- unary scores : num_mentions, max_num_cand 95 | ) 96 | :add(nn.ConcatTable() 97 | :add(nn.Sequential() 98 | :add(unary_plus_pairwise) 99 | :add(nn.ConcatTable() 100 | :add(nn.Identity()) 101 | :add(nn.Sequential() 102 | :add(nn.Max(4)) 103 | :add(nn.MulConstant(0, false)) -- first_round_zero_messages 104 | ) 105 | ) 106 | :add(messages_all_rounds) 107 | :add(nn.SelectTable(2)) 108 | :add(nn.Sum(3)) -- \sum_{j} msgs[i][y_i]: num_mentions x max_num_cand 109 | ) 110 | :add(nn.SelectTable(2)) -- unary scores : num_mentions x max_num_cand 111 | ) 112 | :add(nn.CAddTable()) 113 | :add(nn.LogSoftMax()) -- belief[i][y_i] (lbp marginals in log scale) 114 | 115 | 116 | -- Combine lbp marginals with log p(e|m) using the simple f neural network 117 | local pem_layer = nn.SelectTable(3) 118 | model_gl = nn.Sequential() 119 | :add(nn.ConcatTable() 120 | :add(nn.Sequential() 121 | :add(model_gl) 122 | :add(nn.View(num_mentions * max_num_cand, 1)) 123 | ) 124 | :add(nn.Sequential() 125 | :add(pem_layer) 126 | :add(nn.View(num_mentions * max_num_cand, 1)) 127 | ) 128 | ) 129 | :add(nn.JoinTable(2)) 130 | :add(param_f_network) 131 | :add(nn.View(num_mentions, max_num_cand)) 132 | 133 | ------- Cuda conversions: 134 | if string.find(opt.type, 'cuda') then 135 | model_gl = model_gl:cuda() 136 | end 137 | 138 | return model_gl 139 | end 140 | 141 | 142 | --- Unit tests 143 | if unit_tests_now then 144 | print('\n Global network model unit tests:') 145 | local num_mentions = 13 146 | 147 | local inputs = {} 148 | -- ctxt_w_vecs 149 | inputs[1] = {} 150 | 
inputs[1][1] = torch.ones(num_mentions, opt.ctxt_window):int():mul(unk_w_id) 151 | inputs[1][2] = torch.randn(num_mentions, opt.ctxt_window, ent_vecs_size) 152 | -- e_vecs 153 | inputs[2] = {} 154 | inputs[2][1] = torch.ones(num_mentions, max_num_cand):int():mul(unk_ent_thid) 155 | inputs[2][2] = torch.randn(num_mentions, max_num_cand, ent_vecs_size) 156 | -- p(e|m) 157 | inputs[3] = torch.log(torch.rand(num_mentions, max_num_cand)) 158 | 159 | local model_ctxt, _ = local_model(num_mentions, A_linear, B_linear, opt) 160 | 161 | local model_gl = global_model(num_mentions, model_ctxt, C_linear, f_network) 162 | local outputs = model_gl:forward(inputs) 163 | print(outputs) 164 | print('MIN: ' .. torch.min(outputs) .. ' MAX: ' .. torch.max(outputs)) 165 | assert(outputs:size(1) == num_mentions and outputs:size(2) == max_num_cand) 166 | print('Global FWD success!') 167 | 168 | model_gl:backward(inputs, torch.randn(num_mentions, max_num_cand)) 169 | print('Global BKWD success!') 170 | 171 | parameters,gradParameters = model_gl:getParameters() 172 | print(parameters:size()) 173 | print(gradParameters:size()) 174 | end -------------------------------------------------------------------------------- /ed/models/model_local.lua: -------------------------------------------------------------------------------- 1 | -- Definition of the local neural network with attention used for local (independent per each mention) ED. 2 | -- Section 4 of our paper. 3 | -- To run a simple unit test that checks the forward and backward passes, just run : 4 | -- th ed/models/model_local.lua 5 | 6 | if not opt then -- unit tests 7 | unit_tests_now = true 8 | dofile 'utils/utils.lua' 9 | require 'nn' 10 | opt = {type = 'double', ctxt_window = 100, R = 25, model = 'local', nn_pem_interm_size = 100} 11 | 12 | word_vecs_size = 300 13 | ent_vecs_size = 300 14 | max_num_cand = 6 15 | unk_ent_wikiid = 1 16 | unk_ent_thid = 1 17 | unk_w_id = 1 18 | dofile 'ed/models/linear_layers.lua' 19 | word_lookup_table = nn.LookupTable(5, ent_vecs_size) 20 | ent_lookup_table = nn.LookupTable(5, ent_vecs_size) 21 | else 22 | word_lookup_table = nn.LookupTable(w2vutils.M:size(1), ent_vecs_size) 23 | word_lookup_table.weight = w2vutils.M 24 | 25 | ent_lookup_table = nn.LookupTable(e2vutils.lookup:size(1), ent_vecs_size) 26 | ent_lookup_table.weight = e2vutils.lookup 27 | 28 | if string.find(opt.type, 'cuda') then 29 | word_lookup_table = word_lookup_table:cuda() 30 | ent_lookup_table = ent_lookup_table:cuda() 31 | end 32 | end 33 | 34 | assert(word_vecs_size == 300 and ent_vecs_size == 300) 35 | 36 | 37 | ----------------- Define the model 38 | function local_model(num_mentions, param_A_linear, param_B_linear) 39 | 40 | assert(num_mentions) 41 | assert(param_A_linear) 42 | assert(param_B_linear) 43 | 44 | model = nn.Sequential() 45 | 46 | ctxt_embed_and_ent_lookup = nn.ConcatTable() 47 | :add(nn.Sequential() 48 | :add(nn.SelectTable(1)) 49 | :add(nn.SelectTable(2)) -- 1 : Context words W : num_mentions x opt.ctxt_window x ent_vecs_size 50 | ) 51 | :add(nn.Sequential() 52 | :add(nn.SelectTable(2)) 53 | :add(nn.SelectTable(2)) -- 2 : Candidate entity vectors E : num_mentions x max_num_cand x ent_vecs_size 54 | ) 55 | :add(nn.SelectTable(3)) -- 3 : log p(e|m) : num_mentions x max_num_cand 56 | 57 | model:add(ctxt_embed_and_ent_lookup) 58 | 59 | 60 | local mem_weights_p_2 = nn.ConcatTable() 61 | :add(nn.Identity()) -- 1 : {W, E, logp(e|m)} 62 | :add(nn.Sequential() 63 | :add(nn.ConcatTable() 64 | :add(nn.Sequential() 65 | :add(nn.SelectTable(2)) 66 
| :add(nn.View(num_mentions * max_num_cand, ent_vecs_size)) -- E : num_mentions x max_num_cand x ent_vecs_size 67 | :add(param_A_linear) 68 | :add(nn.View(num_mentions, max_num_cand, ent_vecs_size)) -- (A*) E 69 | ) 70 | :add(nn.SelectTable(1)) --- W : num_mentions x opt.ctxt_window x ent_vecs_size 71 | ) 72 | :add(nn.MM(false, true)) -- 2 : E^t * A * W : num_mentions x max_num_cand x opt.ctxt_window 73 | ) 74 | model:add(mem_weights_p_2) 75 | 76 | 77 | local mem_weights_p_3 = nn.ConcatTable() 78 | :add(nn.SelectTable(1)) -- 1 : {W, E, logp(e|m)} 79 | :add(nn.Sequential() 80 | :add(nn.SelectTable(2)) 81 | :add(nn.Max(2)) --- 2 : max(word-entity scores) : num_mentions x opt.ctxt_window 82 | ) 83 | 84 | model:add(mem_weights_p_3) 85 | 86 | 87 | local mem_weights_p_4 = nn.ConcatTable() 88 | :add(nn.SelectTable(1)) -- 1 : {W, E, logp(e|m)} 89 | :add(nn.Sequential() 90 | :add(nn.SelectTable(2)) 91 | -- keep only top K scored words 92 | :add(nn.ConcatTable() 93 | :add(nn.Identity()) -- all w-e scores 94 | :add(nn.Sequential() 95 | :add(nn.View(num_mentions, opt.ctxt_window, 1)) 96 | :add(nn.TemporalDynamicKMaxPooling(opt.R)) 97 | :add(nn.Min(2)) -- k-th largest w-e score 98 | :add(nn.View(num_mentions)) 99 | :add(nn.Replicate(opt.ctxt_window, 2)) 100 | ) 101 | ) 102 | :add(nn.ConcatTable() -- top k w-e scores (the rest are set to -infty) 103 | :add(nn.SelectTable(2)) -- k-th largest w-e score that we substract and then add again back after nn.Threshold 104 | :add(nn.Sequential() 105 | :add(nn.CSubTable()) 106 | :add(nn.Threshold(0, -50, true)) 107 | ) 108 | ) 109 | :add(nn.CAddTable()) 110 | :add(nn.SoftMax()) -- 2 : sigma (attention weights normalized): num_mentions x opt.ctxt_window 111 | :add(nn.View(num_mentions, opt.ctxt_window, 1)) 112 | ) 113 | 114 | model:add(mem_weights_p_4) 115 | 116 | 117 | local ctxt_full_embeddings = nn.ConcatTable() 118 | :add(nn.SelectTable(1)) -- 1 : {W, E, logp(e|m)} 119 | :add(nn.Sequential() 120 | :add(nn.ConcatTable() 121 | :add(nn.Sequential() 122 | :add(nn.SelectTable(1)) 123 | :add(nn.SelectTable(1)) -- W 124 | ) 125 | :add(nn.SelectTable(2)) -- sigma 126 | ) 127 | :add(nn.MM(true, false)) 128 | :add(nn.View(num_mentions, ent_vecs_size)) -- 2 : ctxt embedding = (W * B)^\top * sigma : num_mentions x ent_vecs_size 129 | ) 130 | 131 | model:add(ctxt_full_embeddings) 132 | 133 | 134 | local entity_context_sim_scores = nn.ConcatTable() 135 | :add(nn.SelectTable(1)) -- 1 : {W, E, logp(e|m)} 136 | :add(nn.Sequential() 137 | :add(nn.ConcatTable() 138 | :add(nn.Sequential() 139 | :add(nn.SelectTable(1)) 140 | :add(nn.SelectTable(2)) -- E 141 | ) 142 | :add(nn.Sequential() 143 | :add(nn.SelectTable(2)) 144 | :add(param_B_linear) 145 | :add(nn.View(num_mentions, ent_vecs_size, 1)) -- context vectors 146 | ) 147 | ) 148 | :add(nn.MM()) --> 2. 
context * E^T 149 | :add(nn.View(num_mentions, max_num_cand)) 150 | ) 151 | 152 | model:add(entity_context_sim_scores) 153 | 154 | if opt.model == 'local' then 155 | model = nn.Sequential() 156 | :add(nn.ConcatTable() 157 | :add(nn.Sequential() 158 | :add(model) 159 | :add(nn.SelectTable(2)) -- context - entity similarity scores 160 | :add(nn.View(num_mentions * max_num_cand, 1)) 161 | ) 162 | :add(nn.Sequential() 163 | :add(nn.SelectTable(3)) -- log p(e|m) scores 164 | :add(nn.View(num_mentions * max_num_cand, 1)) 165 | ) 166 | ) 167 | :add(nn.JoinTable(2)) 168 | :add(f_network) 169 | :add(nn.View(num_mentions, max_num_cand)) 170 | else 171 | model:add(nn.SelectTable(2)) -- context - entity similarity scores 172 | end 173 | 174 | 175 | ------------- Visualizing weights: 176 | 177 | -- sigma (attention weights normalized): num_mentions x opt.ctxt_window 178 | local model_debug_softmax_word_weights = nn.Sequential() 179 | :add(ctxt_embed_and_ent_lookup) 180 | :add(mem_weights_p_2) 181 | :add(mem_weights_p_3) 182 | :add(mem_weights_p_4) 183 | :add(nn.SelectTable(2)) 184 | :add(nn.View(num_mentions, opt.ctxt_window)) 185 | 186 | ------- Cuda conversions: 187 | if string.find(opt.type, 'cuda') then 188 | model = model:cuda() 189 | model_debug_softmax_word_weights = model_debug_softmax_word_weights:cuda() 190 | end 191 | 192 | local additional_local_submodels = { 193 | model_final_local = model, 194 | model_debug_softmax_word_weights = model_debug_softmax_word_weights, 195 | } 196 | 197 | return model, additional_local_submodels 198 | end 199 | 200 | 201 | --- Unit tests 202 | if unit_tests_now then 203 | print('Network model unit tests:') 204 | local num_mentions = 13 205 | 206 | local inputs = {} 207 | 208 | -- ctxt_w_vecs 209 | inputs[1] = {} 210 | inputs[1][1] = torch.ones(num_mentions, opt.ctxt_window):int():mul(unk_w_id) 211 | inputs[1][2] = torch.randn(num_mentions, opt.ctxt_window, ent_vecs_size) 212 | -- e_vecs 213 | inputs[2] = {} 214 | inputs[2][1] = torch.ones(num_mentions, max_num_cand):int():mul(unk_ent_thid) 215 | inputs[2][2] = torch.randn(num_mentions, max_num_cand, ent_vecs_size) 216 | -- p(e|m) 217 | inputs[3] = torch.zeros(num_mentions, max_num_cand) 218 | 219 | local model, additional_local_submodels = local_model(num_mentions, A_linear, B_linear, opt) 220 | 221 | print(additional_local_submodels.model_debug_softmax_word_weights:forward(inputs):size()) 222 | 223 | local outputs = model:forward(inputs) 224 | assert(outputs:size(1) == num_mentions and outputs:size(2) == max_num_cand) 225 | print('FWD success!') 226 | 227 | model:backward(inputs, torch.randn(num_mentions, max_num_cand)) 228 | print('BKWD success!') 229 | 230 | parameters,gradParameters = model:getParameters() 231 | print(parameters:size()) 232 | print(gradParameters:size()) 233 | end 234 | -------------------------------------------------------------------------------- /ed/test/check_coref.lua: -------------------------------------------------------------------------------- 1 | -- Runs our trivial coreference resolution method and outputs the new set of 2 | -- entity candidates. Used for debugging the coreference resolution method. 
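-- Example run (a sketch, from the repository root):
--   th ed/test/check_coref.lua -root_data_dir $DATA_PATH
-- It reads generated/test_train_data/aida_testB.csv and, for the 'aida-B' banner, prints each
-- mention whose candidate list gets replaced by that of a longer coreferent person mention.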
3 | 4 | if not opt then 5 | cmd = torch.CmdLine() 6 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.') 7 | cmd:text() 8 | opt = cmd:parse(arg or {}) 9 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.') 10 | end 11 | 12 | 13 | require 'torch' 14 | dofile 'utils/utils.lua' 15 | 16 | tds = tds or require 'tds' 17 | 18 | dofile 'entities/ent_name2id_freq/ent_name_id.lua' 19 | dofile 'ed/test/coref.lua' 20 | 21 | file = opt.root_data_dir .. 'generated/test_train_data/aida_testB.csv' 22 | 23 | opt = {} 24 | opt.coref = true 25 | 26 | it, _ = io.open(file) 27 | local all_doc_lines = tds.Hash() 28 | local line = it:read() 29 | while line do 30 | local parts = split(line, '\t') 31 | local doc_name = parts[1] 32 | if not all_doc_lines[doc_name] then 33 | all_doc_lines[doc_name] = tds.Hash() 34 | end 35 | all_doc_lines[doc_name][1 + #all_doc_lines[doc_name]] = line 36 | line = it:read() 37 | end 38 | -- Gather coreferent mentions to increase accuracy. 39 | build_coreference_dataset(all_doc_lines, 'aida-B') 40 | -------------------------------------------------------------------------------- /ed/test/coref_persons.lua: -------------------------------------------------------------------------------- 1 | -- Given a dataset, try to retrieve better entity candidates 2 | -- for ambiguous mentions of persons. For example, suppose a document 3 | -- contains a mention of a person called 'Peter Such' that can be easily solved with 4 | -- the current system. Now suppose that, in the same document, there 5 | -- exists a mention 'Such' referring to the same person. For this 6 | -- second highly ambiguous mention, retrieving the correct entity in 7 | -- top K candidates would be very hard. We adopt here a simple heuristical strategy of 8 | -- searching in the same document all potentially coreferent mentions that strictly contain 9 | -- the given mention as a substring. If such mentions exist and they refer to 10 | -- persons (contain at least one person candidate entity), then the ambiguous 11 | -- shorter mention gets as candidates the candidates of the longer mention. 12 | 13 | tds = tds or require 'tds' 14 | assert(get_ent_wikiid_from_name) 15 | print('==> Loading index of Wiki entities that represent persons.') 16 | 17 | local persons_ent_wikiids = tds.Hash() 18 | for line in io.lines(opt.root_data_dir .. 'basic_data/p_e_m_data/persons.txt') do 19 | local ent_wikiid = get_ent_wikiid_from_name(line, true) 20 | if ent_wikiid ~= unk_ent_wikiid then 21 | persons_ent_wikiids[ent_wikiid] = 1 22 | end 23 | end 24 | 25 | function is_person(ent_wikiid) 26 | return persons_ent_wikiids[ent_wikiid] 27 | end 28 | 29 | print(' Done loading persons index. Size = ' .. #persons_ent_wikiids) 30 | 31 | 32 | local function mention_refers_to_person(m, mention_ent_cand) 33 | local top_p = 0 34 | local top_ent = -1 35 | for e_wikiid, p_e_m in pairs(mention_ent_cand[m]) do 36 | if p_e_m > top_p then 37 | top_ent = e_wikiid 38 | top_p = p_e_m 39 | end 40 | end 41 | return is_person(top_ent) 42 | end 43 | 44 | 45 | function build_coreference_dataset(dataset_lines, banner) 46 | if (not opt.coref) then 47 | return dataset_lines 48 | else 49 | 50 | -- Create new entity candidates 51 | local coref_dataset_lines = tds.Hash() 52 | for doc_id, lines_map in pairs(dataset_lines) do 53 | 54 | coref_dataset_lines[doc_id] = tds.Hash() 55 | 56 | -- Collect entity candidates for each mention. 
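-- Each dataset line is tab-separated; after the 'CANDIDATES' marker every field has the
-- form '<ent_wikiid>,<p_e_m>,<ent_name>' (or the single token 'EMPTYCAND'), up to the 'GT:' marker.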
57 | local mention_ent_cand = {} 58 | for _,sample_line in pairs(lines_map) do 59 | local parts = split(sample_line, "\t") 60 | assert(doc_id == parts[1]) 61 | local mention = parts[3]:lower() 62 | if not mention_ent_cand[mention] then 63 | mention_ent_cand[mention] = {} 64 | assert(parts[6] == 'CANDIDATES') 65 | if parts[7] ~= 'EMPTYCAND' then 66 | local num_cand = 1 67 | while parts[6 + num_cand] ~= 'GT:' do 68 | local cand_parts = split(parts[6 + num_cand], ',') 69 | local cand_ent_wikiid = tonumber(cand_parts[1]) 70 | local cand_p_e_m = tonumber(cand_parts[2]) 71 | assert(cand_p_e_m >= 0, cand_p_e_m) 72 | assert(cand_ent_wikiid) 73 | mention_ent_cand[mention][cand_ent_wikiid] = cand_p_e_m 74 | num_cand = num_cand + 1 75 | end 76 | end 77 | end 78 | end 79 | 80 | -- Find coreferent mentions 81 | for _,sample_line in pairs(lines_map) do 82 | local parts = split(sample_line, "\t") 83 | assert(doc_id == parts[1]) 84 | local mention = parts[3]:lower() 85 | 86 | assert(mention_ent_cand[mention]) 87 | assert(parts[table_len(parts) - 1] == 'GT:') 88 | 89 | -- Grd trth infos 90 | local grd_trth_parts = split(parts[table_len(parts)], ',') 91 | local grd_trth_idx = tonumber(grd_trth_parts[1]) 92 | assert(grd_trth_idx == -1 or table_len(grd_trth_parts) >= 4, sample_line) 93 | local grd_trth_entwikiid = -1 94 | if table_len(grd_trth_parts) >= 3 then 95 | grd_trth_entwikiid = tonumber(grd_trth_parts[2]) 96 | end 97 | assert(grd_trth_entwikiid) 98 | 99 | -- Merge lists of entity candidates 100 | local added_list = {} 101 | local num_added_mentions = 0 102 | for m,_ in pairs(mention_ent_cand) do 103 | local stupid_lua_pattern = string.gsub(mention, '%.', '%%%.') 104 | stupid_lua_pattern = string.gsub(stupid_lua_pattern, '%-', '%%%-') 105 | if m ~= mention and (string.find(m, ' ' .. stupid_lua_pattern) or string.find(m, stupid_lua_pattern .. ' ')) and mention_refers_to_person(m, mention_ent_cand) then 106 | 107 | if banner == 'aida-B' then 108 | print(blue('coref mention = ' .. m .. ' replaces original mention = ' .. mention) .. 109 | ' ; DOC = ' .. doc_id) 110 | end 111 | 112 | num_added_mentions = num_added_mentions + 1 113 | for e_wikiid, p_e_m in pairs(mention_ent_cand[m]) do 114 | if not added_list[e_wikiid] then 115 | added_list[e_wikiid] = 0.0 116 | end 117 | added_list[e_wikiid] = added_list[e_wikiid] + p_e_m 118 | end 119 | end 120 | end 121 | 122 | -- Average: 123 | for e_wikiid, _ in pairs(added_list) do 124 | added_list[e_wikiid] = added_list[e_wikiid] / num_added_mentions 125 | end 126 | 127 | -- Merge the two lists 128 | local merged_list = mention_ent_cand[mention] 129 | if num_added_mentions > 0 then 130 | merged_list = added_list 131 | end 132 | 133 | local sorted_list = {} 134 | for ent_wikiid,p in pairs(merged_list) do 135 | table.insert(sorted_list, {ent_wikiid = ent_wikiid, p = p}) 136 | end 137 | table.sort(sorted_list, function(a,b) return a.p > b.p end) 138 | 139 | -- Write the new line 140 | local str = parts[1] .. '\t' .. parts[2] .. '\t' .. parts[3] .. '\t' .. parts[4] .. '\t' 141 | .. parts[5] .. '\t' .. parts[6] .. '\t' 142 | 143 | if table_len(sorted_list) == 0 then 144 | str = str .. 'EMPTYCAND\tGT:\t-1' 145 | if grd_trth_entwikiid ~= unk_ent_wikiid then 146 | str = str .. ',' .. grd_trth_entwikiid .. ',' .. 147 | get_ent_name_from_wikiid(grd_trth_entwikiid) 148 | end 149 | else 150 | local candidates = {} 151 | local gt_pos = -1 152 | for pos,e in pairs(sorted_list) do 153 | if pos <= 100 then 154 | table.insert(candidates, e.ent_wikiid .. ',' .. 
155 | string.format("%.3f", e.p) .. ',' .. get_ent_name_from_wikiid(e.ent_wikiid)) 156 | if e.ent_wikiid == grd_trth_entwikiid then 157 | gt_pos = pos 158 | end 159 | else 160 | break 161 | end 162 | end 163 | str = str .. table.concat(candidates, '\t') .. '\tGT:\t' 164 | 165 | if gt_pos > 0 then 166 | str = str .. gt_pos .. ',' .. candidates[gt_pos] 167 | else 168 | str = str .. '-1' 169 | if grd_trth_entwikiid ~= unk_ent_wikiid then 170 | str = str .. ',' .. grd_trth_entwikiid .. ',' .. 171 | get_ent_name_from_wikiid(grd_trth_entwikiid) 172 | end 173 | end 174 | end 175 | 176 | coref_dataset_lines[doc_id][1 + #coref_dataset_lines[doc_id]] = str 177 | end 178 | end 179 | 180 | assert(#dataset_lines == #coref_dataset_lines) 181 | for doc_id, lines_map in pairs(dataset_lines) do 182 | assert(table_len(dataset_lines[doc_id]) == table_len(coref_dataset_lines[doc_id])) 183 | end 184 | 185 | return coref_dataset_lines 186 | end 187 | end -------------------------------------------------------------------------------- /ed/test/ent_freq_stats_test.lua: -------------------------------------------------------------------------------- 1 | -- Statistics of annotated entities based on their frequency in Wikipedia corpus 2 | -- Table 6 (left) from our paper 3 | local function ent_freq_to_key(f) 4 | if f == 0 then 5 | return '0' 6 | elseif f == 1 then 7 | return '1' 8 | elseif f <= 5 then 9 | return '2-5' 10 | elseif f <= 10 then 11 | return '6-10' 12 | elseif f <= 20 then 13 | return '11-20' 14 | elseif f <= 50 then 15 | return '21-50' 16 | else 17 | return '50+' 18 | end 19 | end 20 | 21 | 22 | function new_ent_freq_map() 23 | local m = {} 24 | m['0'] = 0.0 25 | m['1'] = 0.0 26 | m['2-5'] = 0.0 27 | m['6-10'] = 0.0 28 | m['11-20'] = 0.0 29 | m['21-50'] = 0.0 30 | m['50+'] = 0.0 31 | return m 32 | end 33 | 34 | function add_freq_to_ent_freq_map(m, f) 35 | m[ent_freq_to_key(f)] = m[ent_freq_to_key(f)] + 1 36 | end 37 | 38 | function print_ent_freq_maps_stats(smallm, bigm) 39 | print(' ===> entity frequency stats :') 40 | for k,_ in pairs(smallm) do 41 | local perc = 0 42 | if bigm[k] > 0 then 43 | perc = 100.0 * smallm[k] / bigm[k] 44 | end 45 | assert(perc <= 100) 46 | print('freq = ' .. k .. ' : num = ' .. bigm[k] .. 47 | ' ; correctly classified = ' .. smallm[k] .. 48 | ' ; perc = ' .. 
string.format("%.2f", perc)) 49 | end 50 | print('') 51 | end 52 | 53 | 54 | -------------------------------------------------------------------------------- /ed/test/ent_p_e_m_stats_test.lua: -------------------------------------------------------------------------------- 1 | -- Statistics of annotated entities based on their p(e|m) prio 2 | -- Table 6 (right) from our paper 3 | 4 | local function ent_prior_to_key(f) 5 | if f <= 0.001 then 6 | return '<=0.001' 7 | elseif f <= 0.003 then 8 | return '0.001-0.003' 9 | elseif f <= 0.01 then 10 | return '0.003-0.01' 11 | elseif f <= 0.03 then 12 | return '0.01-0.03' 13 | elseif f <= 0.1 then 14 | return '0.03-0.1' 15 | elseif f <= 0.3 then 16 | return '0.1-0.3' 17 | else 18 | return '0.3+' 19 | end 20 | end 21 | 22 | 23 | function new_ent_prior_map() 24 | local m = {} 25 | m['<=0.001'] = 0.0 26 | m['0.001-0.003'] = 0.0 27 | m['0.003-0.01'] = 0.0 28 | m['0.01-0.03'] = 0.0 29 | m['0.03-0.1'] = 0.0 30 | m['0.1-0.3'] = 0.0 31 | m['0.3+'] = 0.0 32 | return m 33 | end 34 | 35 | function add_prior_to_ent_prior_map(m, f) 36 | m[ent_prior_to_key(f)] = m[ent_prior_to_key(f)] + 1 37 | end 38 | 39 | function print_ent_prior_maps_stats(smallm, bigm) 40 | print(' ===> entity p(e|m) stats :') 41 | for k,_ in pairs(smallm) do 42 | local perc = 0 43 | if bigm[k] > 0 then 44 | perc = 100.0 * smallm[k] / bigm[k] 45 | end 46 | assert(perc <= 100) 47 | print('p(e|m) = ' .. k .. ' : num = ' .. bigm[k] .. 48 | ' ; correctly classified = ' .. smallm[k] .. 49 | ' ; perc = ' .. string.format("%.2f", perc)) 50 | end 51 | end 52 | 53 | 54 | -------------------------------------------------------------------------------- /ed/test/test_one_loaded_model.lua: -------------------------------------------------------------------------------- 1 | -- Test one single ED model trained using ed/ed.lua 2 | 3 | -- Run: CUDA_VISIBLE_DEVICES=0 th ed/test/test_one_loaded_model.lua -root_data_dir $DATA_PATH -model global -ent_vecs_filename $ENTITY_VECS -test_one_model_file $ED_MODEL_FILENAME 4 | require 'optim' 5 | require 'torch' 6 | require 'gnuplot' 7 | require 'nn' 8 | require 'xlua' 9 | tds = tds or require 'tds' 10 | 11 | dofile 'utils/utils.lua' 12 | print('\n' .. green('==========> Test a pre-trained ED neural model <==========') .. '\n') 13 | 14 | dofile 'ed/args.lua' 15 | 16 | print('===> RUN TYPE: ' .. opt.type) 17 | torch.setdefaulttensortype('torch.FloatTensor') 18 | if string.find(opt.type, 'cuda') then 19 | print('==> switching to CUDA (GPU)') 20 | require 'cunn' 21 | require 'cutorch' 22 | require 'cudnn' 23 | cudnn.benchmark = true 24 | cudnn.fastest = true 25 | else 26 | print('==> running on CPU') 27 | end 28 | 29 | extract_args_from_model_title(opt.test_one_model_file) 30 | 31 | dofile 'utils/logger.lua' 32 | dofile 'entities/relatedness/relatedness.lua' 33 | dofile 'entities/ent_name2id_freq/ent_name_id.lua' 34 | dofile 'entities/ent_name2id_freq/e_freq_index.lua' 35 | dofile 'words/load_w_freq_and_vecs.lua' 36 | dofile 'words/w2v/w2v.lua' 37 | dofile 'entities/pretrained_e2v/e2v.lua' 38 | dofile 'ed/minibatch/build_minibatch.lua' 39 | dofile 'ed/models/model.lua' 40 | dofile 'ed/test/test.lua' 41 | 42 | local saved_linears = torch.load(opt.root_data_dir .. 'generated/ed_models/' .. 
opt.test_one_model_file) 43 | unpack_saveable_weights(saved_linears) 44 | 45 | test() 46 | -------------------------------------------------------------------------------- /ed/train.lua: -------------------------------------------------------------------------------- 1 | -- Training of ED models. 2 | 3 | if opt.opt == 'SGD' then 4 | optimState = { 5 | learningRate = opt.lr, 6 | momentum = 0.9, 7 | learningRateDecay = 5e-7 8 | } 9 | optimMethod = optim.sgd 10 | 11 | elseif opt.opt == 'ADAM' then -- See: http://cs231n.github.io/neural-networks-3/#update 12 | optimState = { 13 | learningRate = opt.lr, 14 | } 15 | optimMethod = optim.adam 16 | 17 | elseif opt.opt == 'ADADELTA' then -- See: http://cs231n.github.io/neural-networks-3/#update 18 | -- Run with default parameters, no need for learning rate and other stuff 19 | optimState = {} 20 | optimConfig = {} 21 | optimMethod = optim.adadelta 22 | 23 | elseif opt.opt == 'ADAGRAD' then -- See: http://cs231n.github.io/neural-networks-3/#update 24 | optimMethod = optim.adagrad 25 | optimState = { 26 | learningRate = opt.lr 27 | } 28 | 29 | else 30 | error('unknown optimization method') 31 | end 32 | ---------------------------------------------------------------------- 33 | 34 | -- Each batch is one document, so we test/validate/save the current model after each set of 35 | -- 5000 documents. Since aida-train contains 946 documents, this is equivalent with 5 full epochs. 36 | num_batches_per_epoch = 5000 37 | 38 | function train_and_test() 39 | 40 | print('\nDone testing for ' .. banner) 41 | print('Params serialized = ' .. params_serialized) 42 | 43 | -- epoch tracker 44 | epoch = 1 45 | 46 | local processed_so_far = 0 47 | 48 | local f_bs = 0 49 | local gradParameters_bs = nil 50 | 51 | while true do 52 | local time = sys.clock() 53 | print('\n') 54 | print('One epoch = ' .. (num_batches_per_epoch / 1000) .. ' full passes over AIDA-TRAIN in our case.') 55 | print(green('==> TRAINING EPOCH #' .. epoch .. 
' <==')) 56 | 57 | print_net_weights() 58 | 59 | local processed_mentions = 0 60 | for batch_index = 1,num_batches_per_epoch do 61 | -- Read one mini-batch from one data_thread: 62 | local inputs, targets = get_minibatch() 63 | 64 | local num_mentions = targets:size(1) 65 | processed_mentions = processed_mentions + num_mentions 66 | 67 | local model, _ = get_model(num_mentions) 68 | model:training() 69 | 70 | -- Retrieve parameters and gradients: 71 | -- extracts and flattens all model's parameters into a 1-dim vector 72 | parameters,gradParameters = model:getParameters() 73 | gradParameters:zero() 74 | 75 | -- Just in case: 76 | collectgarbage() 77 | collectgarbage() 78 | 79 | -- Reset gradients 80 | gradParameters:zero() 81 | 82 | -- Evaluate function for complete mini batch 83 | 84 | local outputs = model:forward(inputs) 85 | assert(outputs:size(1) == num_mentions and outputs:size(2) == max_num_cand) 86 | local f = criterion:forward(outputs, targets) 87 | 88 | -- Estimate df/dW 89 | local df_do = criterion:backward(outputs, targets) 90 | 91 | model:backward(inputs, df_do) 92 | 93 | if opt.batch_size == 1 or batch_index % opt.batch_size == 1 then 94 | gradParameters_bs = gradParameters:clone():zero() 95 | f_bs = 0 96 | end 97 | 98 | gradParameters_bs:add(gradParameters) 99 | f_bs = f_bs + f 100 | 101 | if opt.batch_size == 1 or batch_index % opt.batch_size == 0 then 102 | 103 | gradParameters_bs:div(opt.batch_size) 104 | f_bs = f_bs / opt.batch_size 105 | 106 | -- Create closure to evaluate f(X) and df/dX 107 | local feval = function(x) 108 | return f_bs, gradParameters_bs 109 | end 110 | 111 | -- Optimize on current mini-batch 112 | optimState.learningRate = opt.lr 113 | optimMethod(feval, parameters, optimState) 114 | 115 | -- Regularize the f_network with projected SGD. 116 | regularize_f_network() 117 | end 118 | 119 | -- Display progress 120 | processed_so_far = processed_so_far + num_mentions 121 | if processed_so_far > 100000000 then 122 | processed_so_far = processed_so_far - 100000000 123 | end 124 | xlua.progress(processed_so_far, 100000000) 125 | end 126 | 127 | -- Measure time taken 128 | time = sys.clock() - time 129 | time = time / processed_mentions 130 | 131 | print("\n==> time to learn 1 sample = " .. (time*1000) .. 'ms') 132 | 133 | -- Test: 134 | test(epoch) 135 | print('\nDone testing for ' .. banner) 136 | print('Params serialized = ' .. params_serialized) 137 | 138 | -- Save the current model: 139 | if opt.save then 140 | local filename = opt.root_data_dir .. 'generated/ed_models/' .. params_serialized .. '|ep=' .. epoch 141 | print('==> saving model to '..filename) 142 | torch.save(filename, pack_saveable_weights()) 143 | end 144 | 145 | -- Next epoch 146 | epoch = epoch + 1 147 | end 148 | end 149 | -------------------------------------------------------------------------------- /entities/ent_name2id_freq/e_freq_gen.lua: -------------------------------------------------------------------------------- 1 | -- Creates a file that contains entity frequencies. 
2 | 3 | if not opt then 4 | cmd = torch.CmdLine() 5 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.') 6 | cmd:text() 7 | opt = cmd:parse(arg or {}) 8 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.') 9 | end 10 | 11 | 12 | require 'torch' 13 | tds = tds or require 'tds' 14 | 15 | dofile 'utils/utils.lua' 16 | dofile 'entities/ent_name2id_freq/ent_name_id.lua' 17 | 18 | entity_freqs = tds.Hash() 19 | 20 | local num_lines = 0 21 | it, _ = io.open(opt.root_data_dir .. 'generated/crosswikis_wikipedia_p_e_m.txt') 22 | line = it:read() 23 | 24 | while (line) do 25 | num_lines = num_lines + 1 26 | if num_lines % 2000000 == 0 then 27 | print('Processed ' .. num_lines .. ' lines. ') 28 | end 29 | 30 | local parts = split(line , '\t') 31 | local num_parts = table_len(parts) 32 | for i = 3, num_parts do 33 | local ent_str = split(parts[i], ',') 34 | local ent_wikiid = tonumber(ent_str[1]) 35 | local freq = tonumber(ent_str[2]) 36 | assert(ent_wikiid) 37 | assert(freq) 38 | 39 | if (not entity_freqs[ent_wikiid]) then 40 | entity_freqs[ent_wikiid] = 0 41 | end 42 | entity_freqs[ent_wikiid] = entity_freqs[ent_wikiid] + freq 43 | end 44 | line = it:read() 45 | end 46 | 47 | 48 | -- Writing word frequencies 49 | print('Sorting and writing') 50 | sorted_ent_freq = {} 51 | for ent_wikiid,freq in pairs(entity_freqs) do 52 | if freq >= 10 then 53 | table.insert(sorted_ent_freq, {ent_wikiid = ent_wikiid, freq = freq}) 54 | end 55 | end 56 | 57 | table.sort(sorted_ent_freq, function(a,b) return a.freq > b.freq end) 58 | 59 | out_file = opt.root_data_dir .. 'generated/ent_wiki_freq.txt' 60 | ouf = assert(io.open(out_file, "w")) 61 | total_freq = 0 62 | for _,x in pairs(sorted_ent_freq) do 63 | ouf:write(x.ent_wikiid .. '\t' .. get_ent_name_from_wikiid(x.ent_wikiid) .. '\t' .. x.freq .. '\n') 64 | total_freq = total_freq + x.freq 65 | end 66 | ouf:flush() 67 | io.close(ouf) 68 | 69 | print('Total freq = ' .. total_freq .. '\n') 70 | -------------------------------------------------------------------------------- /entities/ent_name2id_freq/e_freq_index.lua: -------------------------------------------------------------------------------- 1 | -- Loads an index containing entity -> frequency pairs. 2 | -- TODO: rewrite this file in a simpler way (is complicated because of some past experiments). 3 | tds = tds or require 'tds' 4 | 5 | print('==> Loading entity freq map') 6 | 7 | local ent_freq_file = opt.root_data_dir .. 'generated/ent_wiki_freq.txt' 8 | 9 | min_freq = 1 10 | e_freq = tds.Hash() 11 | e_freq.ent_f_start = tds.Hash() 12 | e_freq.ent_f_end = tds.Hash() 13 | e_freq.total_freq = 0 14 | e_freq.sorted = tds.Hash() 15 | 16 | cur_start = 1 17 | cnt = 0 18 | for line in io.lines(ent_freq_file) do 19 | local parts = split(line, '\t') 20 | local ent_wikiid = tonumber(parts[1]) 21 | local ent_f = tonumber(parts[3]) 22 | assert(ent_wikiid) 23 | assert(ent_f) 24 | 25 | if (not rewtr) or rewtr.reltd_ents_wikiid_to_rltdid[ent_wikiid] then 26 | e_freq.ent_f_start[ent_wikiid] = cur_start 27 | e_freq.ent_f_end[ent_wikiid] = cur_start + ent_f - 1 28 | e_freq.sorted[cnt] = ent_wikiid 29 | cur_start = cur_start + ent_f 30 | cnt = cnt + 1 31 | end 32 | end 33 | 34 | e_freq.total_freq = cur_start - 1 35 | 36 | print(' Done loading entity freq index. Size = ' .. 
cnt) 37 | 38 | function get_ent_freq(ent_wikiid) 39 | if e_freq.ent_f_start[ent_wikiid] then 40 | return e_freq.ent_f_end[ent_wikiid] - e_freq.ent_f_start[ent_wikiid] + 1 41 | end 42 | return 0 43 | end 44 | 45 | -------------------------------------------------------------------------------- /entities/ent_name2id_freq/ent_name_id.lua: -------------------------------------------------------------------------------- 1 | ------------------ Load entity name-id mappings ------------------ 2 | -- Each entity has: 3 | -- a) a Wikipedia URL referred as 'name' here 4 | -- b) a Wikipedia ID referred as 'ent_wikiid' or 'wikiid' here 5 | -- c) an ID that will be used in the entity embeddings lookup table. Referred as 'ent_thid' or 'thid' here. 6 | 7 | tds = tds or require 'tds' -- saves lots of memory for ent_name_id.lua. Mem overflow with normal {} 8 | local rltd_only = false 9 | if opt and opt.entities and opt.entities ~= 'ALL' then 10 | assert(rewtr.reltd_ents_wikiid_to_rltdid, 'Import relatedness.lua before ent_name_id.lua') 11 | rltd_only = true 12 | end 13 | 14 | -- Unk entity wikid: 15 | unk_ent_wikiid = 1 16 | 17 | local entity_wiki_txtfilename = opt.root_data_dir .. 'basic_data/wiki_name_id_map.txt' 18 | local entity_wiki_t7filename = opt.root_data_dir .. 'generated/ent_name_id_map.t7' 19 | if rltd_only then 20 | entity_wiki_t7filename = opt.root_data_dir .. 'generated/ent_name_id_map_RLTD.t7' 21 | end 22 | 23 | print('==> Loading entity wikiid - name map') 24 | 25 | local e_id_name = nil 26 | 27 | if paths.filep(entity_wiki_t7filename) then 28 | print(' ---> from t7 file: ' .. entity_wiki_t7filename) 29 | e_id_name = torch.load(entity_wiki_t7filename) 30 | 31 | else 32 | print(' ---> t7 file NOT found. Loading from disk (slower). Out f = ' .. entity_wiki_t7filename) 33 | dofile 'data_gen/indexes/wiki_disambiguation_pages_index.lua' 34 | print(' Still loading entity wikiid - name map ...') 35 | 36 | e_id_name = tds.Hash() 37 | 38 | -- map for entity name to entity wiki id 39 | e_id_name.ent_wikiid2name = tds.Hash() 40 | e_id_name.ent_name2wikiid = tds.Hash() 41 | 42 | -- map for entity wiki id to tensor id. 
Size = 4.4M 43 | if not rltd_only then 44 | e_id_name.ent_wikiid2thid = tds.Hash() 45 | e_id_name.ent_thid2wikiid = tds.Hash() 46 | end 47 | 48 | local cnt = 0 49 | local cnt_freq = 0 50 | for line in io.lines(entity_wiki_txtfilename) do 51 | local parts = split(line, '\t') 52 | local ent_name = parts[1] 53 | local ent_wikiid = tonumber(parts[2]) 54 | 55 | if (not wiki_disambiguation_index[ent_wikiid]) then 56 | if (not rltd_only) or rewtr.reltd_ents_wikiid_to_rltdid[ent_wikiid] then 57 | e_id_name.ent_wikiid2name[ent_wikiid] = ent_name 58 | e_id_name.ent_name2wikiid[ent_name] = ent_wikiid 59 | end 60 | if not rltd_only then 61 | cnt = cnt + 1 62 | e_id_name.ent_wikiid2thid[ent_wikiid] = cnt 63 | e_id_name.ent_thid2wikiid[cnt] = ent_wikiid 64 | end 65 | end 66 | end 67 | 68 | if not rltd_only then 69 | cnt = cnt + 1 70 | e_id_name.ent_wikiid2thid[unk_ent_wikiid] = cnt 71 | e_id_name.ent_thid2wikiid[cnt] = unk_ent_wikiid 72 | end 73 | e_id_name.ent_wikiid2name[unk_ent_wikiid] = 'UNK_ENT' 74 | e_id_name.ent_name2wikiid['UNK_ENT'] = unk_ent_wikiid 75 | 76 | torch.save(entity_wiki_t7filename, e_id_name) 77 | end 78 | 79 | if not rltd_only then 80 | unk_ent_thid = e_id_name.ent_wikiid2thid[unk_ent_wikiid] 81 | else 82 | unk_ent_thid = rewtr.reltd_ents_wikiid_to_rltdid[unk_ent_wikiid] 83 | end 84 | 85 | ------------------------ Functions for wikiids and names----------------- 86 | function get_map_all_valid_ents() 87 | local m = tds.Hash() 88 | for ent_wikiid, _ in pairs(e_id_name.ent_wikiid2name) do 89 | m[ent_wikiid] = 1 90 | end 91 | return m 92 | end 93 | 94 | is_valid_ent = function(ent_wikiid) 95 | if e_id_name.ent_wikiid2name[ent_wikiid] then 96 | return true 97 | end 98 | return false 99 | end 100 | 101 | 102 | get_ent_name_from_wikiid = function(ent_wikiid) 103 | local ent_name = e_id_name.ent_wikiid2name[ent_wikiid] 104 | if (not ent_wikiid) or (not ent_name) then 105 | return 'NIL' 106 | end 107 | return ent_name 108 | end 109 | 110 | preprocess_ent_name = function(ent_name) 111 | ent_name = trim1(ent_name) 112 | ent_name = string.gsub(ent_name, '&', '&') 113 | ent_name = string.gsub(ent_name, '"', '"') 114 | ent_name = ent_name:gsub('_', ' ') 115 | ent_name = first_letter_to_uppercase(ent_name) 116 | if get_redirected_ent_title then 117 | ent_name = get_redirected_ent_title(ent_name) 118 | end 119 | return ent_name 120 | end 121 | 122 | get_ent_wikiid_from_name = function(ent_name, not_verbose) 123 | local verbose = (not not_verbose) 124 | ent_name = preprocess_ent_name(ent_name) 125 | local ent_wikiid = e_id_name.ent_name2wikiid[ent_name] 126 | if (not ent_wikiid) or (not ent_name) then 127 | if verbose then 128 | print(red('Entity ' .. ent_name .. ' not found. 
Redirects file needs to be loaded for better performance.')) 129 | end 130 | return unk_ent_wikiid 131 | end 132 | return ent_wikiid 133 | end 134 | 135 | ------------------------ Functions for thids and wikiids ----------------- 136 | -- ent wiki id -> thid 137 | get_thid = function (ent_wikiid) 138 | if rltd_only then 139 | ent_thid = rewtr.reltd_ents_wikiid_to_rltdid[ent_wikiid] 140 | else 141 | ent_thid = e_id_name.ent_wikiid2thid[ent_wikiid] 142 | end 143 | if (not ent_wikiid) or (not ent_thid) then 144 | return unk_ent_thid 145 | end 146 | return ent_thid 147 | end 148 | 149 | contains_thid = function (ent_wikiid) 150 | if rltd_only then 151 | ent_thid = rewtr.reltd_ents_wikiid_to_rltdid[ent_wikiid] 152 | else 153 | ent_thid = e_id_name.ent_wikiid2thid[ent_wikiid] 154 | end 155 | if ent_wikiid == nil or ent_thid == nil then 156 | return false 157 | end 158 | return true 159 | end 160 | 161 | get_total_num_ents = function() 162 | if rltd_only then 163 | assert(table_len(rewtr.reltd_ents_wikiid_to_rltdid) == rewtr.num_rltd_ents) 164 | return table_len(rewtr.reltd_ents_wikiid_to_rltdid) 165 | else 166 | return #e_id_name.ent_thid2wikiid 167 | end 168 | end 169 | 170 | get_wikiid_from_thid = function(ent_thid) 171 | if rltd_only then 172 | ent_wikiid = rewtr.reltd_ents_rltdid_to_wikiid[ent_thid] 173 | else 174 | ent_wikiid = e_id_name.ent_thid2wikiid[ent_thid] 175 | end 176 | if ent_wikiid == nil or ent_thid == nil then 177 | return unk_ent_wikiid 178 | end 179 | return ent_wikiid 180 | end 181 | 182 | -- tensor of ent wiki ids --> tensor of thids 183 | get_ent_thids = function (ent_wikiids_tensor) 184 | local ent_thid_tensor = ent_wikiids_tensor:clone() 185 | if ent_wikiids_tensor:dim() == 2 then 186 | for i = 1,ent_thid_tensor:size(1) do 187 | for j = 1,ent_thid_tensor:size(2) do 188 | ent_thid_tensor[i][j] = get_thid(ent_wikiids_tensor[i][j]) 189 | end 190 | end 191 | elseif ent_wikiids_tensor:dim() == 1 then 192 | for i = 1,ent_thid_tensor:size(1) do 193 | ent_thid_tensor[i] = get_thid(ent_wikiids_tensor[i]) 194 | end 195 | else 196 | print('Tensor with > 2 dimentions not supported') 197 | os.exit() 198 | end 199 | return ent_thid_tensor 200 | end 201 | 202 | print(' Done loading entity name - wikiid. Size thid index = ' .. get_total_num_ents()) 203 | -------------------------------------------------------------------------------- /entities/learn_e2v/batch_dataset_a.lua: -------------------------------------------------------------------------------- 1 | dofile 'utils/utils.lua' 2 | 3 | if opt.entities == 'ALL' then 4 | wiki_words_train_file = opt.root_data_dir .. 'generated/wiki_canonical_words.txt' 5 | wiki_hyp_train_file = opt.root_data_dir .. 'generated/wiki_hyperlink_contexts.csv' 6 | else 7 | wiki_words_train_file = opt.root_data_dir .. 'generated/wiki_canonical_words_RLTD.txt' 8 | wiki_hyp_train_file = opt.root_data_dir .. 
'generated/wiki_hyperlink_contexts_RLTD.csv' 9 | end 10 | 11 | wiki_words_it, _ = io.open(wiki_words_train_file) 12 | wiki_hyp_it, _ = io.open(wiki_hyp_train_file) 13 | 14 | assert(opt.num_passes_wiki_words) 15 | local train_data_source = 'wiki-canonical' 16 | local num_passes_wiki_words = 1 17 | 18 | local function read_one_line() 19 | if train_data_source == 'wiki-canonical' then 20 | line = wiki_words_it:read() 21 | else 22 | assert(train_data_source == 'wiki-canonical-hyperlinks') 23 | line = wiki_hyp_it:read() 24 | end 25 | if (not line) then 26 | if num_passes_wiki_words == opt.num_passes_wiki_words then 27 | train_data_source = 'wiki-canonical-hyperlinks' 28 | print('\n\n' .. 'Start training on Wiki Hyperlinks' .. '\n\n') 29 | end 30 | print('Training file is done. Num passes = ' .. num_passes_wiki_words .. '. Reopening.') 31 | num_passes_wiki_words = num_passes_wiki_words + 1 32 | if train_data_source == 'wiki-canonical' then 33 | wiki_words_it, _ = io.open(wiki_words_train_file) 34 | line = wiki_words_it:read() 35 | else 36 | wiki_hyp_it, _ = io.open(wiki_hyp_train_file) 37 | line = wiki_hyp_it:read() 38 | end 39 | end 40 | return line 41 | end 42 | 43 | local line = nil 44 | 45 | local function patch_of_lines(num) 46 | local lines = {} 47 | local cnt = 0 48 | assert(num > 0) 49 | 50 | while cnt < num do 51 | line = read_one_line() 52 | cnt = cnt + 1 53 | table.insert(lines, line) 54 | end 55 | 56 | assert(table_len(lines) == num) 57 | return lines 58 | end 59 | 60 | 61 | function get_minibatch() 62 | -- Create empty mini batch: 63 | local lines = patch_of_lines(opt.batch_size) 64 | local inputs = empty_minibatch() 65 | local targets = correct_type(torch.ones(opt.batch_size, opt.num_words_per_ent)) 66 | 67 | -- Fill in each example: 68 | for i = 1,opt.batch_size do 69 | local sample_line = lines[i] -- load new example line 70 | local target = process_one_line(sample_line, inputs, i) 71 | targets[i]:copy(target) 72 | end 73 | 74 | --- Minibatch post processing: 75 | postprocess_minibatch(inputs, targets) 76 | targets = targets:view(opt.batch_size * opt.num_words_per_ent) 77 | 78 | -- Special target for the NEG and NCE losses 79 | if opt.loss == 'neg' or opt.loss == 'nce' then 80 | nce_targets = torch.ones(opt.batch_size * opt.num_words_per_ent, opt.num_neg_words):mul(-1) 81 | for j = 1,opt.batch_size * opt.num_words_per_ent do 82 | nce_targets[j][targets[j]] = 1 83 | end 84 | targets = nce_targets 85 | end 86 | 87 | return inputs, targets 88 | end 89 | -------------------------------------------------------------------------------- /entities/learn_e2v/e2v_a.lua: -------------------------------------------------------------------------------- 1 | -- Entity embeddings utilities 2 | 3 | assert(opt.num_words_per_ent) 4 | 5 | -- Word lookup: 6 | geom_w2v_M = w2vutils.M:float() 7 | 8 | -- Stats: 9 | local num_invalid_ent_wikiids = 0 10 | local total_ent_wiki_vec_requests = 0 11 | local last_wrote = 0 12 | local function invalid_ent_wikiids_stats(ent_thid) 13 | total_ent_wiki_vec_requests = total_ent_wiki_vec_requests + 1 14 | if ent_thid == unk_ent_thid then 15 | num_invalid_ent_wikiids = num_invalid_ent_wikiids + 1 16 | end 17 | if (num_invalid_ent_wikiids % 15000 == 0 and num_invalid_ent_wikiids ~= last_wrote) then 18 | last_wrote = num_invalid_ent_wikiids 19 | local perc = 100.0 * num_invalid_ent_wikiids / total_ent_wiki_vec_requests 20 | print(red('*** Perc invalid ent wikiids = ' .. perc .. ' . Absolute num = ' .. 
num_invalid_ent_wikiids)) 21 | end 22 | end 23 | 24 | -- ent id -> vec 25 | function geom_entwikiid2vec(ent_wikiid) 26 | local ent_thid = get_thid(ent_wikiid) 27 | assert(ent_thid) 28 | invalid_ent_wikiids_stats(ent_thid) 29 | local ent_vec = nn.Normalize(2):forward(lookup_ent_vecs.weight[ent_thid]:float()) 30 | return ent_vec 31 | end 32 | 33 | -- ent name -> vec 34 | local function geom_entname2vec(ent_name) 35 | assert(ent_name) 36 | return geom_entwikiid2vec(get_ent_wikiid_from_name(ent_name)) 37 | end 38 | 39 | function entity_similarity(e1_wikiid, e2_wikiid) 40 | local e1_vec = geom_entwikiid2vec(e1_wikiid) 41 | local e2_vec = geom_entwikiid2vec(e2_wikiid) 42 | return e1_vec * e2_vec 43 | end 44 | 45 | 46 | local function geom_top_k_closest_words(ent_name, ent_vec, k) 47 | local tf_map = ent_wiki_words_4EX[ent_name] 48 | local w_not_found = {} 49 | for w,_ in pairs(tf_map) do 50 | if tf_map[w] >= 10 then 51 | w_not_found[w] = tf_map[w] 52 | end 53 | end 54 | 55 | distances = geom_w2v_M * ent_vec 56 | 57 | local best_scores, best_word_ids = topk(distances, k) 58 | local returnwords = {} 59 | local returndistances = {} 60 | for i = 1,k do 61 | local w = get_word_from_id(best_word_ids[i]) 62 | if is_stop_word_or_number(w) then 63 | table.insert(returnwords, red(w)) 64 | elseif tf_map[w] then 65 | if tf_map[w] >= 15 then 66 | table.insert(returnwords, yellow(w .. '{' .. tf_map[w] .. '}')) 67 | else 68 | table.insert(returnwords, skyblue(w .. '{' .. tf_map[w] .. '}')) 69 | end 70 | w_not_found[w] = nil 71 | else 72 | table.insert(returnwords, w) 73 | end 74 | assert(best_scores[i] == distances[best_word_ids[i]], best_scores[i] .. ' ' .. distances[best_word_ids[i]]) 75 | table.insert(returndistances, distances[best_word_ids[i]]) 76 | end 77 | return returnwords, returndistances, w_not_found 78 | end 79 | 80 | 81 | local function geom_most_similar_words_to_ent(ent_name, k) 82 | local ent_wikiid = get_ent_wikiid_from_name(ent_name) 83 | local k = k or 1 84 | local ent_vec = geom_entname2vec(ent_name) 85 | assert(math.abs(1 - ent_vec:norm()) < 0.01 or ent_vec:norm() == 0, ':::: ' .. ent_vec:norm()) 86 | 87 | print('\nTo entity: ' .. blue(ent_name) .. '; vec norm = ' .. ent_vec:norm() .. ':') 88 | neighbors, scores, w_not_found = geom_top_k_closest_words(ent_name, ent_vec, k) 89 | print(green('WORDS MODEL: ') .. list_with_scores_to_str(neighbors, scores)) 90 | 91 | local str = yellow('WORDS NOT FOUND: ') 92 | for w,tf in pairs(w_not_found) do 93 | if tf >= 20 then 94 | str = str .. yellow(w .. '{' .. tf .. '}; ') 95 | else 96 | str = str .. w .. '{' .. tf .. '}; ' 97 | end 98 | end 99 | print('\n' .. str) 100 | print('============================================================================') 101 | end 102 | 103 | 104 | -- Unit tests : 105 | function geom_unit_tests() 106 | print('\n' .. yellow('Words to Entity Similarity test:')) 107 | for i=1,table_len(ent_names_4EX) do 108 | geom_most_similar_words_to_ent(ent_names_4EX[i], 200) 109 | end 110 | end 111 | -------------------------------------------------------------------------------- /entities/learn_e2v/learn_a.lua: -------------------------------------------------------------------------------- 1 | -- Training of entity embeddings. 
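-- In brief (see batch_dataset_a.lua, minibatch_a.lua and model_a.lua): each training
-- example pairs one entity with opt.num_words_per_ent positive words taken from its
-- Wikipedia canonical page or from hyperlink contexts, and contrasts every positive
-- word with opt.num_neg_words negative words sampled from the word unigram
-- distribution raised to the power opt.unig_power. Scores are cosine similarities
-- between normalized word and entity vectors, optimized with the criterion selected
-- by -loss (max-margin by default).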
2 | 3 | -- To run: 4 | -- i) delete all _RLTD files 5 | -- ii) th entities/relatedness/filter_wiki_canonical_words_RLTD.lua ; th entities/relatedness/filter_wiki_hyperlink_contexts_RLTD.lua 6 | -- iii) th entities/learn_e2v/learn_a.lua -root_data_dir /path/to/your/ed_data/files/ 7 | 8 | -- Training of entity vectors 9 | require 'optim' 10 | require 'torch' 11 | require 'gnuplot' 12 | require 'nn' 13 | require 'xlua' 14 | 15 | dofile 'utils/utils.lua' 16 | 17 | cmd = torch.CmdLine() 18 | cmd:text() 19 | cmd:text('Learning entity vectors') 20 | cmd:text() 21 | cmd:text('Options:') 22 | 23 | ---------------- runtime options: 24 | cmd:option('-type', 'cudacudnn', 'Type: double | float | cuda | cudacudnn') 25 | 26 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.') 27 | 28 | cmd:option('-optimization', 'ADAGRAD', 'Optimization method: RMSPROP | ADAGRAD | ADAM | SGD') 29 | 30 | cmd:option('-lr', 0.3, 'Learning rate') 31 | 32 | cmd:option('-batch_size', 500, 'Mini-batch size (1 = pure stochastic)') 33 | 34 | cmd:option('-word_vecs', 'w2v', '300d word vectors type: glove | w2v') 35 | 36 | cmd:option('-num_words_per_ent', 20, 'Num positive words sampled for the given entity at ' .. 37 | 'each iteration.') 38 | 39 | cmd:option('-num_neg_words', 5, 'Num negative words sampled for each positive word.') 40 | 41 | cmd:option('-unig_power', 0.6, 'Negative sampling unigram power (0.75 used in Word2Vec).') 42 | 43 | cmd:option('-entities', 'RLTD', 44 | 'Set of entities for which we train embeddings: 4EX (tiny, for debug) | ' .. 45 | 'RLTD (restricted set) | ALL (all Wiki entities, too big to fit on a single GPU)') 46 | 47 | cmd:option('-init_vecs_title_words', true, 'whether the entity embeddings should be initialized with the average of ' .. 48 | 'title word embeddings. Helps to speed up convergence speed of entity embeddings learning.') 49 | 50 | cmd:option('-loss', 'maxm', 'Loss function: nce (noise contrastive estimation) | ' .. 51 | 'neg (negative sampling) | is (importance sampling) | maxm (max-margin)') 52 | 53 | cmd:option('-data', 'wiki-canonical-hyperlinks', 'Training data: wiki-canonical (only) | ' .. 54 | 'wiki-canonical-hyperlinks') 55 | 56 | -- Only when opt.data = wiki-canonical-hyperlinks 57 | cmd:option('-num_passes_wiki_words', 200, 'Num passes (per entity) over Wiki canonical pages before ' .. 58 | 'changing to using Wiki hyperlinks.') 59 | 60 | cmd:option('-hyp_ctxt_len', 10, 'Left and right context window length for hyperlinks.') 61 | 62 | cmd:option('-banner_header', '', 'Banner header') 63 | 64 | cmd:text() 65 | opt = cmd:parse(arg or {}) 66 | 67 | banner = '' .. opt.banner_header .. ';obj-' .. opt.loss .. ';' .. opt.data 68 | if opt.data ~= 'wiki-canonical' then 69 | banner = banner .. ';hypCtxtL-' .. opt.hyp_ctxt_len 70 | banner = banner .. ';numWWpass-' .. opt.num_passes_wiki_words 71 | end 72 | banner = banner .. ';WperE-' .. opt.num_words_per_ent 73 | banner = banner .. ';' .. opt.word_vecs .. ';negW-' .. opt.num_neg_words 74 | banner = banner .. ';ents-' .. opt.entities .. ';unigP-' .. opt.unig_power 75 | banner = banner .. ';bs-' .. opt.batch_size .. ';' .. opt.optimization .. '-lr-' .. opt.lr 76 | 77 | print('\n' .. blue('BANNER : ' .. banner)) 78 | 79 | print('\n===> RUN TYPE: ' .. 
opt.type) 80 | 81 | torch.setdefaulttensortype('torch.FloatTensor') 82 | if string.find(opt.type, 'cuda') then 83 | print('==> switching to CUDA (GPU)') 84 | require 'cunn' 85 | require 'cutorch' 86 | require 'cudnn' 87 | cudnn.benchmark = true 88 | cudnn.fastest = true 89 | else 90 | print('==> running on CPU') 91 | end 92 | 93 | dofile 'utils/logger.lua' 94 | dofile 'entities/relatedness/relatedness.lua' 95 | dofile 'entities/ent_name2id_freq/ent_name_id.lua' 96 | dofile 'words/load_w_freq_and_vecs.lua' 97 | dofile 'words/w2v/w2v.lua' 98 | dofile 'entities/learn_e2v/minibatch_a.lua' 99 | dofile 'entities/learn_e2v/model_a.lua' 100 | dofile 'entities/learn_e2v/e2v_a.lua' 101 | dofile 'entities/learn_e2v/batch_dataset_a.lua' 102 | 103 | if opt.loss == 'neg' or opt.loss == 'nce' then 104 | criterion = nn.SoftMarginCriterion() 105 | elseif opt.loss == 'maxm' then 106 | criterion = nn.MultiMarginCriterion(1, torch.ones(opt.num_neg_words), 0.1) 107 | elseif opt.loss == 'is' then 108 | criterion = nn.CrossEntropyCriterion() 109 | end 110 | 111 | if string.find(opt.type, 'cuda') then 112 | criterion = criterion:cuda() 113 | end 114 | 115 | 116 | ---------------------------------------------------------------------- 117 | if opt.optimization == 'ADAGRAD' then -- See: http://cs231n.github.io/neural-networks-3/#update 118 | dofile 'utils/optim/adagrad_mem.lua' 119 | optimMethod = adagrad_mem 120 | optimState = { 121 | learningRate = opt.lr 122 | } 123 | elseif opt.optimization == 'RMSPROP' then -- See: cs231n.github.io/neural-networks-3/#update 124 | dofile 'utils/optim/rmsprop_mem.lua' 125 | optimMethod = rmsprop_mem 126 | optimState = { 127 | learningRate = opt.lr 128 | } 129 | elseif opt.optimization == 'SGD' then -- See: http://cs231n.github.io/neural-networks-3/#update 130 | optimState = { 131 | learningRate = opt.lr, 132 | learningRateDecay = 5e-7 133 | } 134 | optimMethod = optim.sgd 135 | elseif opt.optimization == 'ADAM' then -- See: http://cs231n.github.io/neural-networks-3/#update 136 | optimState = { 137 | learningRate = opt.lr, 138 | } 139 | optimMethod = optim.adam 140 | else 141 | error('unknown optimization method') 142 | end 143 | 144 | 145 | ---------------------------------------------------------------------- 146 | function train_ent_vecs() 147 | print('Training entity vectors w/ params: ' .. banner) 148 | 149 | -- Retrieve parameters and gradients: 150 | -- extracts and flattens all model's parameters into a 1-dim vector 151 | parameters,gradParameters = model:getParameters() 152 | gradParameters:zero() 153 | 154 | local processed_so_far = 0 155 | if opt.entities == 'ALL' then 156 | num_batches_per_epoch = 4000 157 | elseif opt.entities == 'RLTD' then 158 | num_batches_per_epoch = 2000 159 | elseif opt.entities == '4EX' then 160 | num_batches_per_epoch = 400 161 | end 162 | local test_every_num_epochs = 1 163 | local save_every_num_epochs = 3 164 | 165 | -- epoch tracker 166 | epoch = 1 167 | 168 | -- Initial testing: 169 | geom_unit_tests() -- Show some examples 170 | 171 | print('Training params: ' .. banner) 172 | 173 | -- do one epoch 174 | print('\n==> doing epoch on training data:') 175 | print("==> online epoch # " .. epoch .. ' [batch size = ' .. opt.batch_size .. ']') 176 | 177 | while true do 178 | 179 | local time = sys.clock() 180 | print(green('\n===> TRAINING EPOCH #' .. epoch .. '; num batches ' .. num_batches_per_epoch .. 
' <===')) 181 | 182 | local avg_loss_before_opt_per_epoch = 0.0 183 | local avg_loss_after_opt_per_epoch = 0.0 184 | 185 | for batch_index = 1,num_batches_per_epoch do 186 | -- Read one mini-batch from one data_thread: 187 | inputs, targets = get_minibatch() 188 | 189 | -- Move data to GPU: 190 | minibatch_to_correct_type(inputs) 191 | targets = correct_type(targets) 192 | 193 | -- create closure to evaluate f(X) and df/dX 194 | local feval = function(x) 195 | -- get new parameters 196 | if x ~= parameters then 197 | parameters:copy(x) 198 | end 199 | 200 | -- reset gradients 201 | gradParameters:zero() 202 | 203 | -- evaluate function for complete mini batch 204 | local outputs = model:forward(inputs) 205 | assert(outputs:size(1) == opt.batch_size * opt.num_words_per_ent and 206 | outputs:size(2) == opt.num_neg_words) 207 | 208 | local f = criterion:forward(outputs, targets) 209 | 210 | -- estimate df/dW 211 | local df_do = criterion:backward(outputs, targets) 212 | local gradInput = model:backward(inputs, df_do) 213 | 214 | -- return f and df/dX 215 | return f,gradParameters 216 | end 217 | 218 | -- Debug info: 219 | local loss_before_opt = criterion:forward(model:forward(inputs), targets) 220 | avg_loss_before_opt_per_epoch = avg_loss_before_opt_per_epoch + loss_before_opt 221 | 222 | -- Optimize on current mini-batch 223 | optimMethod(feval, parameters, optimState) 224 | 225 | local loss_after_opt = criterion:forward(model:forward(inputs), targets) 226 | avg_loss_after_opt_per_epoch = avg_loss_after_opt_per_epoch + loss_after_opt 227 | if loss_after_opt > loss_before_opt then 228 | print(red('!!!!!! LOSS INCREASED: ' .. loss_before_opt .. ' --> ' .. loss_after_opt)) 229 | end 230 | 231 | -- Display progress 232 | train_size = 17000000 ---------- 4 passes over the Wiki entity set 233 | processed_so_far = processed_so_far + opt.batch_size 234 | if processed_so_far > train_size then 235 | processed_so_far = processed_so_far - train_size 236 | end 237 | xlua.progress(processed_so_far, train_size) 238 | end 239 | 240 | avg_loss_before_opt_per_epoch = avg_loss_before_opt_per_epoch / num_batches_per_epoch 241 | avg_loss_after_opt_per_epoch = avg_loss_after_opt_per_epoch / num_batches_per_epoch 242 | print(yellow('\nAvg loss before opt = ' .. avg_loss_before_opt_per_epoch .. 243 | '; Avg loss after opt = ' .. avg_loss_after_opt_per_epoch)) 244 | 245 | -- time taken 246 | time = sys.clock() - time 247 | time = time / (num_batches_per_epoch * opt.batch_size) 248 | print("==> time to learn 1 full entity = " .. (time*1000) .. 'ms') 249 | 250 | geom_unit_tests() -- Show some entity examples 251 | 252 | -- Various testing measures: 253 | if (epoch % test_every_num_epochs == 0) then 254 | if opt.entities ~= '4EX' then 255 | compute_relatedness_metrics(entity_similarity) 256 | end 257 | end 258 | 259 | -- Save model: 260 | if (epoch % save_every_num_epochs == 0) then 261 | print('==> saving model to ' .. opt.root_data_dir .. 'generated/ent_vecs/ent_vecs__ep_' .. epoch .. '.t7') 262 | torch.save(opt.root_data_dir .. 'generated/ent_vecs/ent_vecs__ep_' .. epoch .. '.t7', nn.Normalize(2):forward(lookup_ent_vecs.weight:float())) 263 | end 264 | 265 | print('Training params: ' .. 
banner) 266 | 267 | -- next epoch 268 | epoch = epoch + 1 269 | end 270 | end 271 | 272 | train_ent_vecs() 273 | -------------------------------------------------------------------------------- /entities/learn_e2v/minibatch_a.lua: -------------------------------------------------------------------------------- 1 | assert(opt.entities == '4EX' or opt.entities == 'ALL' or opt.entities == 'RLTD', opt.entities) 2 | 3 | function empty_minibatch() 4 | local ctxt_word_ids = torch.ones(opt.batch_size, opt.num_words_per_ent, opt.num_neg_words):mul(unk_w_id) 5 | local ent_component_words = torch.ones(opt.batch_size, opt.num_words_per_ent):int() 6 | local ent_wikiids = torch.ones(opt.batch_size):int() 7 | local ent_thids = torch.ones(opt.batch_size):int() 8 | return {{ctxt_word_ids}, {ent_component_words}, {ent_thids, ent_wikiids}} 9 | end 10 | 11 | -- Get functions: 12 | function get_pos_and_neg_w_ids(minibatch) 13 | return minibatch[1][1] 14 | end 15 | function get_pos_and_neg_w_vecs(minibatch) 16 | return minibatch[1][2] 17 | end 18 | function get_pos_and_neg_w_unig_at_power(minibatch) 19 | return minibatch[1][3] 20 | end 21 | function get_ent_wiki_w_ids(minibatch) 22 | return minibatch[2][1] 23 | end 24 | function get_ent_wiki_w_vecs(minibatch) 25 | return minibatch[2][2] 26 | end 27 | function get_ent_thids_batch(minibatch) 28 | return minibatch[3][1] 29 | end 30 | function get_ent_wikiids(minibatch) 31 | return minibatch[3][2] 32 | end 33 | 34 | 35 | -- Fills in the minibatch and returns the grd truth word index per each example. 36 | -- An example in our case is an entity, a positive word sampled from \hat{p}(e|m) 37 | -- and several negative words sampled from \hat{p}(w)^\alpha. 38 | function process_one_line(line, minibatch, mb_index) 39 | if opt.entities == '4EX' then 40 | line = ent_lines_4EX[ent_names_4EX[math.random(1, table_len(ent_names_4EX))]] 41 | end 42 | 43 | local parts = split(line, '\t') 44 | local num_parts = table_len(parts) 45 | 46 | if num_parts == 3 then ---------> Words from the Wikipedia canonical page 47 | assert(table_len(parts) == 3, line) 48 | ent_wikiid = tonumber(parts[1]) 49 | words_plus_stop_words = split(parts[3], ' ') 50 | 51 | else --------> Words from Wikipedia hyperlinks 52 | assert(num_parts >= 9, line .. ' --> ' .. num_parts) 53 | assert(parts[6] == 'CANDIDATES', line) 54 | 55 | local last_part = parts[num_parts] 56 | local ent_str = split(last_part, ',') 57 | ent_wikiid = tonumber(ent_str[2]) 58 | 59 | words_plus_stop_words = {} 60 | local left_ctxt_w = split(parts[4], ' ') 61 | local left_ctxt_w_num = table_len(left_ctxt_w) 62 | for i = math.max(1, left_ctxt_w_num - opt.hyp_ctxt_len + 1), left_ctxt_w_num do 63 | table.insert(words_plus_stop_words, left_ctxt_w[i]) 64 | end 65 | local right_ctxt_w = split(parts[5], ' ') 66 | local right_ctxt_w_num = table_len(right_ctxt_w) 67 | for i = 1, math.min(right_ctxt_w_num, opt.hyp_ctxt_len) do 68 | table.insert(words_plus_stop_words, right_ctxt_w[i]) 69 | end 70 | end 71 | 72 | assert(ent_wikiid) 73 | local ent_thid = get_thid(ent_wikiid) 74 | assert(get_wikiid_from_thid(ent_thid) == ent_wikiid) 75 | get_ent_thids_batch(minibatch)[mb_index] = ent_thid 76 | assert(get_ent_thids_batch(minibatch)[mb_index] == ent_thid) 77 | 78 | get_ent_wikiids(minibatch)[mb_index] = ent_wikiid 79 | 80 | 81 | -- Remove stop words from entity wiki words representations. 
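-- (contains_w is defined in the words/ code, not shown here; it is assumed to return
-- false for stop words and for words without a pre-trained embedding, so only usable
-- positive words survive this filter.)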
82 | local positive_words_in_this_iter = {} 83 | local num_positive_words_this_iter = 0 84 | for _,w in pairs(words_plus_stop_words) do 85 | if contains_w(w) then 86 | table.insert(positive_words_in_this_iter, w) 87 | num_positive_words_this_iter = num_positive_words_this_iter + 1 88 | end 89 | end 90 | 91 | -- Try getting some words from the entity title if the canonical page is empty. 92 | if num_positive_words_this_iter == 0 then 93 | local ent_name = parts[2] 94 | words_plus_stop_words = split_in_words(ent_name) 95 | for _,w in pairs(words_plus_stop_words) do 96 | if contains_w(w) then 97 | table.insert(positive_words_in_this_iter, w) 98 | num_positive_words_this_iter = num_positive_words_this_iter + 1 99 | end 100 | end 101 | 102 | -- Still empty ? Get some random words then. 103 | if num_positive_words_this_iter == 0 then 104 | table.insert(positive_words_in_this_iter, get_word_from_id(random_unigram_at_unig_power_w_id())) 105 | end 106 | end 107 | 108 | local targets = torch.zeros(opt.num_words_per_ent):int() 109 | 110 | -- Sample some negative words: 111 | get_pos_and_neg_w_ids(minibatch)[mb_index]:apply( 112 | function(x) 113 | -- Random negative words sampled sampled from \hat{p}(w)^\alpha. 114 | return random_unigram_at_unig_power_w_id() 115 | end 116 | ) 117 | 118 | -- Sample some positive words: 119 | for i = 1,opt.num_words_per_ent do 120 | local positive_w = positive_words_in_this_iter[math.random(1, num_positive_words_this_iter)] 121 | local positive_w_id = get_id_from_word(positive_w) 122 | 123 | -- Set the positive word in a random position. Remember that index (used in training). 124 | local grd_trth = math.random(1, opt.num_neg_words) 125 | get_ent_wiki_w_ids(minibatch)[mb_index][i] = positive_w_id 126 | assert(get_ent_wiki_w_ids(minibatch)[mb_index][i] == positive_w_id) 127 | targets[i] = grd_trth 128 | get_pos_and_neg_w_ids(minibatch)[mb_index][i][grd_trth] = positive_w_id 129 | end 130 | 131 | return targets 132 | end 133 | 134 | 135 | -- Fill minibatch with word and entity vectors: 136 | function postprocess_minibatch(minibatch, targets) 137 | 138 | minibatch[1][1] = get_pos_and_neg_w_ids(minibatch):view(opt.batch_size * opt.num_words_per_ent * opt.num_neg_words) 139 | minibatch[2][1] = get_ent_wiki_w_ids(minibatch):view(opt.batch_size * opt.num_words_per_ent) 140 | 141 | -- ctxt word vecs 142 | minibatch[1][2] = w2vutils:lookup_w_vecs(get_pos_and_neg_w_ids(minibatch)) 143 | 144 | minibatch[1][3] = torch.zeros(opt.batch_size * opt.num_words_per_ent * opt.num_neg_words) 145 | minibatch[1][3]:map(minibatch[1][1], function(_,w_id) return get_w_unnorm_unigram_at_power(w_id) end) 146 | end 147 | 148 | 149 | -- Convert mini batch to correct type (e.g. move data to GPU): 150 | function minibatch_to_correct_type(minibatch) 151 | minibatch[1][1] = correct_type(minibatch[1][1]) 152 | minibatch[2][1] = correct_type(minibatch[2][1]) 153 | minibatch[1][2] = correct_type(minibatch[1][2]) 154 | minibatch[1][3] = correct_type(minibatch[1][3]) 155 | minibatch[3][1] = correct_type(minibatch[3][1]) 156 | end -------------------------------------------------------------------------------- /entities/learn_e2v/model_a.lua: -------------------------------------------------------------------------------- 1 | -- Definition of the neural network used to learn entity embeddings. 
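-- In short: the network L2-normalizes the context-word vectors and the looked-up
-- entity vectors, computes their dot products with nn.MM, and reshapes the scores to
-- (batch_size * num_words_per_ent) x num_neg_words, i.e. one row per positive word and
-- its sampled negatives. For the 'is' and 'nce' losses, the log of the unigram
-- distribution at the chosen power (scaled by num_neg_words - 1 for 'nce') is
-- subtracted from these scores before the criterion is applied.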
2 | -- To run a simple unit test that checks the forward and backward passes, just run : 3 | -- th entities/learn_e2v/model_a.lua 4 | 5 | if not opt then -- unit tests 6 | unit_tests = true 7 | dofile 'utils/utils.lua' 8 | require 'nn' 9 | cmd = torch.CmdLine() 10 | cmd:option('-type', 'double', 'type: double | float | cuda | cudacudnn') 11 | cmd:option('-batch_size', 7, 'mini-batch size (1 = pure stochastic)') 12 | cmd:option('-num_words_per_ent', 100, 'num positive words per entity per iteration.') 13 | cmd:option('-num_neg_words', 25, 'num negative words in the partition function.') 14 | cmd:option('-loss', 'nce', 'nce | neg | is | maxm') 15 | cmd:option('-init_vecs_title_words', false, 'whether the entity embeddings should be initialized with the average of title word embeddings. Helps to speed up convergence speed of entity embeddings learning.') 16 | opt = cmd:parse(arg or {}) 17 | word_vecs_size = 5 18 | ent_vecs_size = word_vecs_size 19 | lookup_ent_vecs = nn.LookupTable(100, ent_vecs_size) 20 | end -- end unit tests 21 | 22 | 23 | if not unit_tests then 24 | ent_vecs_size = word_vecs_size 25 | 26 | -- Init ents vectors 27 | print('\n==> Init entity embeddings matrix. Num ents = ' .. get_total_num_ents()) 28 | lookup_ent_vecs = nn.LookupTable(get_total_num_ents(), ent_vecs_size) 29 | 30 | -- Zero out unk_ent_thid vector for unknown entities. 31 | lookup_ent_vecs.weight[unk_ent_thid]:copy(torch.zeros(ent_vecs_size)) 32 | 33 | -- Init entity vectors with average of title word embeddings. 34 | -- This would help speed-up training. 35 | if opt.init_vecs_title_words then 36 | print('Init entity embeddings with average of title word vectors to speed up learning.') 37 | for ent_thid = 1,get_total_num_ents() do 38 | local init_ent_vec = torch.zeros(ent_vecs_size) 39 | local ent_name = get_ent_name_from_wikiid(get_wikiid_from_thid(ent_thid)) 40 | words_plus_stop_words = split_in_words(ent_name) 41 | local num_words_title = 0 42 | for _,w in pairs(words_plus_stop_words) do 43 | if contains_w(w) then -- Remove stop words. 
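        -- (w2vutils.M is assumed to be the pre-trained word-embedding matrix loaded via
        -- words/w2v/w2v.lua; title-word vectors are accumulated here and averaged below
        -- through :div(num_words_title).)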
44 | init_ent_vec:add(w2vutils.M[get_id_from_word(w)]:float()) 45 | num_words_title = num_words_title + 1 46 | end 47 | end 48 | 49 | if num_words_title > 0 then 50 | if num_words_title > 3 then 51 | assert(init_ent_vec:norm() > 0, ent_name) 52 | end 53 | init_ent_vec:div(num_words_title) 54 | end 55 | 56 | if init_ent_vec:norm() > 0 then 57 | lookup_ent_vecs.weight[ent_thid]:copy(init_ent_vec) 58 | end 59 | end 60 | end 61 | 62 | collectgarbage(); collectgarbage(); 63 | print(' Done init.') 64 | end 65 | 66 | ---------------- Model Definition -------------------------------- 67 | cosine_words_ents = nn.Sequential() 68 | :add(nn.ConcatTable() 69 | :add(nn.Sequential() 70 | :add(nn.SelectTable(1)) 71 | :add(nn.SelectTable(2)) -- ctxt words vectors 72 | :add(nn.Normalize(2)) 73 | :add(nn.View(opt.batch_size, opt.num_words_per_ent * opt.num_neg_words, ent_vecs_size))) 74 | :add(nn.Sequential() 75 | :add(nn.SelectTable(3)) 76 | :add(nn.SelectTable(1)) 77 | :add(lookup_ent_vecs) -- entity vectors 78 | :add(nn.Normalize(2)) 79 | :add(nn.View(opt.batch_size, 1, ent_vecs_size)))) 80 | :add(nn.MM(false, true)) 81 | :add(nn.View(opt.batch_size * opt.num_words_per_ent, opt.num_neg_words)) 82 | 83 | model = nn.Sequential() 84 | :add(cosine_words_ents) 85 | :add(nn.View(opt.batch_size * opt.num_words_per_ent, opt.num_neg_words)) 86 | 87 | 88 | if opt.loss == 'is' then 89 | model = nn.Sequential() 90 | :add(nn.ConcatTable() 91 | :add(model) 92 | :add(nn.Sequential() 93 | :add(nn.SelectTable(1)) 94 | :add(nn.SelectTable(3)) -- unigram distributions at power 95 | :add(nn.Log()) 96 | :add(nn.View(opt.batch_size * opt.num_words_per_ent, opt.num_neg_words)))) 97 | :add(nn.CSubTable()) 98 | 99 | elseif opt.loss == 'nce' then 100 | model = nn.Sequential() 101 | :add(nn.ConcatTable() 102 | :add(model) 103 | :add(nn.Sequential() 104 | :add(nn.SelectTable(1)) 105 | :add(nn.SelectTable(3)) -- unigram distributions at power 106 | :add(nn.MulConstant(opt.num_neg_words - 1)) 107 | :add(nn.Log()) 108 | :add(nn.View(opt.batch_size * opt.num_words_per_ent, opt.num_neg_words)))) 109 | :add(nn.CSubTable()) 110 | end 111 | 112 | --------------------------------------------------------------------------------------------- 113 | 114 | ------- Cuda conversions: 115 | if string.find(opt.type, 'cuda') then 116 | model = model:cuda() --- This has to be called always before cudnn.convert 117 | end 118 | 119 | if string.find(opt.type, 'cudacudnn') then 120 | cudnn.convert(model, cudnn) 121 | end 122 | 123 | 124 | --- Unit tests 125 | if unit_tests then 126 | print('Network model unit tests:') 127 | local inputs = {} 128 | 129 | inputs[1] = {} 130 | inputs[1][1] = correct_type(torch.ones(opt.batch_size * opt.num_words_per_ent * opt.num_neg_words)) -- ctxt words 131 | 132 | inputs[2] = {} 133 | inputs[2][1] = correct_type(torch.ones(opt.batch_size * opt.num_words_per_ent)) -- ent wiki words 134 | 135 | inputs[3] = {} 136 | inputs[3][1] = correct_type(torch.ones(opt.batch_size)) -- ent th ids 137 | inputs[3][2] = torch.ones(opt.batch_size) -- ent wikiids 138 | 139 | -- ctxt word vecs 140 | inputs[1][2] = correct_type(torch.ones(opt.batch_size * opt.num_words_per_ent * opt.num_neg_words, word_vecs_size)) 141 | 142 | inputs[1][3] = correct_type(torch.randn(opt.batch_size * opt.num_words_per_ent * opt.num_neg_words)) 143 | 144 | local outputs = model:forward(inputs) 145 | 146 | assert(outputs:size(1) == opt.batch_size * opt.num_words_per_ent and 147 | outputs:size(2) == opt.num_neg_words) 148 | print('FWD success!') 149 | 150 | 
model:backward(inputs, correct_type(torch.randn(opt.batch_size * opt.num_words_per_ent, opt.num_neg_words))) 151 | print('BKWD success!') 152 | end 153 | -------------------------------------------------------------------------------- /entities/pretrained_e2v/check_ents.lua: -------------------------------------------------------------------------------- 1 | if not opt then 2 | cmd = torch.CmdLine() 3 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.') 4 | cmd:option('-ent_vecs_filename', 'ent_vecs__ep_228.t7', 'File name containing entity vectors generated with entities/learn_e2v/learn_a.lua.') 5 | cmd:text() 6 | opt = cmd:parse(arg or {}) 7 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.') 8 | end 9 | 10 | 11 | require 'optim' 12 | require 'torch' 13 | require 'gnuplot' 14 | require 'nn' 15 | require 'xlua' 16 | 17 | dofile 'utils/utils.lua' 18 | dofile 'ed/args.lua' 19 | 20 | tds = require 'tds' 21 | 22 | 23 | print('===> RUN TYPE: ' .. opt.type) 24 | torch.setdefaulttensortype('torch.FloatTensor') 25 | if string.find(opt.type, 'cuda') then 26 | print('==> switching to CUDA (GPU)') 27 | require 'cunn' 28 | require 'cutorch' 29 | require 'cudnn' 30 | cudnn.benchmark = true 31 | cudnn.fastest = true 32 | else 33 | print('==> running on CPU') 34 | end 35 | 36 | dofile 'utils/logger.lua' 37 | dofile 'entities/relatedness/relatedness.lua' 38 | dofile 'entities/ent_name2id_freq/ent_name_id.lua' 39 | dofile 'entities/ent_name2id_freq/e_freq_index.lua' 40 | dofile 'words/load_w_freq_and_vecs.lua' 41 | dofile 'entities/pretrained_e2v/e2v.lua' -------------------------------------------------------------------------------- /entities/pretrained_e2v/e2v.lua: -------------------------------------------------------------------------------- 1 | -- Loads pre-trained entity vectors trained using the file entity/learn_e2v/learn_a.lua 2 | 3 | assert(opt.ent_vecs_filename) 4 | print('==> Loading pre-trained entity vectors: e2v from file ' .. opt.ent_vecs_filename) 5 | 6 | assert(opt.entities == 'RLTD', 'Only RLTD entities are currently supported. ALL entities would blow the GPU memory.') 7 | 8 | -- Defining variables: 9 | ent_vecs_size = 300 10 | 11 | geom_w2v_M = w2vutils.M:float() 12 | 13 | e2vutils = {} 14 | 15 | -- Lookup table: ids -> tensor of vecs 16 | e2vutils.lookup = torch.load(opt.root_data_dir .. 'generated/ent_vecs/' .. opt.ent_vecs_filename) 17 | e2vutils.lookup = nn.Normalize(2):forward(e2vutils.lookup) -- Needs to be normalized to have norm 1. 18 | 19 | assert(e2vutils.lookup:size(1) == get_total_num_ents() and 20 | e2vutils.lookup:size(2) == ent_vecs_size, e2vutils.lookup:size(1) .. ' ' .. get_total_num_ents()) 21 | assert(e2vutils.lookup[unk_ent_thid]:norm() == 0, e2vutils.lookup[unk_ent_thid]:norm()) 22 | 23 | -- ent wikiid -> vec 24 | e2vutils.entwikiid2vec = function(self, ent_wikiid) 25 | local thid = get_thid(ent_wikiid) 26 | return self.lookup[thid]:float() 27 | end 28 | assert(torch.norm(e2vutils:entwikiid2vec(unk_ent_wikiid)) == 0) 29 | 30 | 31 | e2vutils.entname2vec = function (self,ent_name) 32 | assert(ent_name) 33 | return e2vutils:entwikiid2vec(get_ent_wikiid_from_name(ent_name)) 34 | end 35 | 36 | -- Entity similarity based on cosine distance (note that entity vectors are normalized). 
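-- Since both vectors are unit-norm, the plain dot product below equals the cosine
-- similarity (range [-1, 1]); unknown entities map to the all-zero vector, so their
-- similarity to any entity is 0.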
37 | function entity_similarity(e1_wikiid, e2_wikiid) 38 | local e1_vec = e2vutils:entwikiid2vec(e1_wikiid) 39 | local e2_vec = e2vutils:entwikiid2vec(e2_wikiid) 40 | return e1_vec * e2_vec 41 | end 42 | 43 | 44 | ----------------------------------------------------------------------- 45 | ---- Some unit tests to understand the quality of these embeddings ---- 46 | ----------------------------------------------------------------------- 47 | local function geom_top_k_closest_words(ent_name, ent_vec, k) 48 | local tf_map = ent_wiki_words_4EX[ent_name] 49 | local w_not_found = {} 50 | for w,_ in pairs(tf_map) do 51 | if tf_map[w] >= 10 then 52 | w_not_found[w] = tf_map[w] 53 | end 54 | end 55 | 56 | distances = geom_w2v_M * ent_vec 57 | 58 | local best_scores, best_word_ids = topk(distances, k) 59 | local returnwords = {} 60 | local returndistances = {} 61 | for i = 1,k do 62 | local w = get_word_from_id(best_word_ids[i]) 63 | if get_w_id_freq(best_word_ids[i]) >= 200 then 64 | local w_freq_str = '[fr=' .. get_w_id_freq(best_word_ids[i]) .. ']' 65 | if is_stop_word_or_number(w) then 66 | table.insert(returnwords, red(w .. w_freq_str)) 67 | elseif tf_map[w] then 68 | if tf_map[w] >= 15 then 69 | table.insert(returnwords, yellow(w .. w_freq_str .. '{tf=' .. tf_map[w] .. '}')) 70 | else 71 | table.insert(returnwords, skyblue(w .. w_freq_str .. '{tf=' .. tf_map[w] .. '}')) 72 | end 73 | w_not_found[w] = nil 74 | else 75 | table.insert(returnwords, w .. w_freq_str) 76 | end 77 | assert(best_scores[i] == distances[best_word_ids[i]], best_scores[i] .. ' ' .. distances[best_word_ids[i]]) 78 | table.insert(returndistances, distances[best_word_ids[i]]) 79 | end 80 | end 81 | return returnwords, returndistances, w_not_found 82 | end 83 | 84 | 85 | local function geom_most_similar_words_to_ent(ent_name, k) 86 | local ent_wikiid = get_ent_wikiid_from_name(ent_name) 87 | local k = k or 1 88 | local ent_vec = e2vutils:entname2vec(ent_name) 89 | assert(math.abs(1 - ent_vec:norm()) < 0.01 or ent_vec:norm() == 0, ':::: ' .. ent_vec:norm()) 90 | 91 | print('\nTo entity: ' .. blue(ent_name) .. '; vec norm = ' .. ent_vec:norm() .. ':') 92 | neighbors, scores, w_not_found = geom_top_k_closest_words(ent_name, ent_vec, k) 93 | print(green('TOP CLOSEST WORDS: ') .. list_with_scores_to_str(neighbors, scores)) 94 | 95 | local str = yellow('WORDS NOT FOUND: ') 96 | for w,tf in pairs(w_not_found) do 97 | if tf >= 20 then 98 | str = str .. yellow(w .. '{' .. tf .. '}; ') 99 | else 100 | str = str .. w .. '{' .. tf .. '}; ' 101 | end 102 | end 103 | print('\n' .. str) 104 | print('============================================================================') 105 | end 106 | 107 | function geom_unit_tests() 108 | print('\n' .. yellow('TOP CLOSEST WORDS to a given entity based on cosine distance:')) 109 | print('For each word, we show the unigram frequency [fr] and the cosine similarity.') 110 | print('Infrequent words [fr < 500] have noisy embeddings, thus should be trusted less.') 111 | print('WORDS NOT FOUND contains frequent words from the Wikipedia canonical page that are not found in the TOP CLOSEST WORDS list.') 112 | 113 | for i=1,table_len(ent_names_4EX) do 114 | geom_most_similar_words_to_ent(ent_names_4EX[i], 300) 115 | end 116 | end 117 | 118 | print(' Done reading e2v data. Entity vocab size = ' .. 
e2vutils.lookup:size(1)) 119 | -------------------------------------------------------------------------------- /entities/pretrained_e2v/e2v_txt_reader.lua: -------------------------------------------------------------------------------- 1 | print('==> loading e2v') 2 | 3 | local V = torch.ones(get_total_num_ents(), ent_vecs_size):mul(1e-10) -- not zero because of cosine_distance layer 4 | 5 | local cnt = 0 6 | for line in io.lines(e2v_txtfilename) do 7 | cnt = cnt + 1 8 | if cnt % 1000000 == 0 then 9 | print('=======> processed ' .. cnt .. ' lines') 10 | end 11 | 12 | local parts = split(line, ' ') 13 | assert(table_len(parts) == ent_vecs_size + 1) 14 | local ent_wikiid = tonumber(parts[1]) 15 | local vec = torch.zeros(ent_vecs_size) 16 | for i=1,ent_vecs_size do 17 | vec[i] = tonumber(parts[i + 1]) 18 | end 19 | 20 | if (contains_thid(ent_wikiid)) then 21 | V[get_thid(ent_wikiid)] = vec 22 | else 23 | print('Ent id = ' .. ent_wikiid .. ' does not have a vector. ') 24 | end 25 | end 26 | 27 | print(' Done loading entity vectors. Size = ' .. cnt .. '\n') 28 | 29 | print('Writing t7 File for future usage. Next time Ent2Vec will load faster!') 30 | torch.save(e2v_t7filename, V) 31 | print(' Done saving.\n') 32 | 33 | return V -------------------------------------------------------------------------------- /entities/relatedness/filter_wiki_canonical_words_RLTD.lua: -------------------------------------------------------------------------------- 1 | 2 | if not opt then 3 | cmd = torch.CmdLine() 4 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.') 5 | cmd:text() 6 | opt = cmd:parse(arg or {}) 7 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.') 8 | end 9 | 10 | 11 | dofile 'utils/utils.lua' 12 | dofile 'entities/relatedness/relatedness.lua' 13 | 14 | input = opt.root_data_dir .. 'generated/wiki_canonical_words.txt' 15 | 16 | output = opt.root_data_dir .. 'generated/wiki_canonical_words_RLTD.txt' 17 | ouf = assert(io.open(output, "w")) 18 | 19 | print('\nStarting dataset filtering.') 20 | 21 | local cnt = 0 22 | for line in io.lines(input) do 23 | cnt = cnt + 1 24 | if cnt % 500000 == 0 then 25 | print(' =======> processed ' .. cnt .. ' lines') 26 | end 27 | 28 | local parts = split(line, '\t') 29 | assert(table_len(parts) == 3) 30 | 31 | local ent_wikiid = tonumber(parts[1]) 32 | local ent_name = parts[2] 33 | assert(ent_wikiid) 34 | 35 | if rewtr.reltd_ents_wikiid_to_rltdid[ent_wikiid] then 36 | ouf:write(line .. '\n') 37 | end 38 | end 39 | 40 | ouf:flush() 41 | io.close(ouf) 42 | -------------------------------------------------------------------------------- /entities/relatedness/filter_wiki_hyperlink_contexts_RLTD.lua: -------------------------------------------------------------------------------- 1 | -- Filter all training data s.t. only candidate entities and ground truth entities for which 2 | -- we have a valid entity embedding are kept. 3 | 4 | if not opt then 5 | cmd = torch.CmdLine() 6 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.') 7 | cmd:text() 8 | opt = cmd:parse(arg or {}) 9 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.') 10 | end 11 | 12 | 13 | dofile 'utils/utils.lua' 14 | dofile 'entities/relatedness/relatedness.lua' 15 | 16 | input = opt.root_data_dir .. 'generated/wiki_hyperlink_contexts.csv' 17 | 18 | output = opt.root_data_dir .. 
'generated/wiki_hyperlink_contexts_RLTD.csv' 19 | ouf = assert(io.open(output, "w")) 20 | 21 | print('\nStarting dataset filtering.') 22 | local cnt = 0 23 | for line in io.lines(input) do 24 | cnt = cnt + 1 25 | if cnt % 50000 == 0 then 26 | print(' =======> processed ' .. cnt .. ' lines') 27 | end 28 | 29 | local parts = split(line, '\t') 30 | local grd_str = parts[table_len(parts)] 31 | assert(parts[table_len(parts) - 1] == 'GT:') 32 | local grd_str_parts = split(grd_str, ',') 33 | 34 | local grd_pos = tonumber(grd_str_parts[1]) 35 | assert(grd_pos) 36 | 37 | local grd_ent_wikiid = tonumber(grd_str_parts[2]) 38 | assert(grd_ent_wikiid) 39 | 40 | if rewtr.reltd_ents_wikiid_to_rltdid[grd_ent_wikiid] then 41 | assert(parts[6] == 'CANDIDATES') 42 | 43 | local output_line = parts[1] .. '\t' .. parts[2] .. '\t' .. parts[3] .. '\t' .. parts[4] .. '\t' .. parts[5] .. '\t' .. parts[6] .. '\t' 44 | 45 | local new_grd_pos = -1 46 | local new_grd_str_without_idx = nil 47 | 48 | local i = 1 49 | local added_ents = 0 50 | while (parts[6 + i] ~= 'GT:') do 51 | local str = parts[6 + i] 52 | local str_parts = split(str, ',') 53 | local ent_wikiid = tonumber(str_parts[1]) 54 | if rewtr.reltd_ents_wikiid_to_rltdid[ent_wikiid] then 55 | added_ents = added_ents + 1 56 | output_line = output_line .. str .. '\t' 57 | end 58 | if (i == grd_pos) then 59 | assert(ent_wikiid == grd_ent_wikiid, 'Error for: ' .. line) 60 | new_grd_pos = added_ents 61 | new_grd_str_without_idx = str 62 | end 63 | 64 | i = i + 1 65 | end 66 | 67 | assert(new_grd_pos > 0) 68 | output_line = output_line .. 'GT:\t' .. new_grd_pos .. ',' .. new_grd_str_without_idx 69 | 70 | ouf:write(output_line .. '\n') 71 | end 72 | end 73 | 74 | ouf:flush() 75 | io.close(ouf) 76 | -------------------------------------------------------------------------------- /utils/logger.lua: -------------------------------------------------------------------------------- 1 | --[[ Logger: a simple class to log symbols during training, 2 | and automate plot generation 3 | Example: 4 | logger = optim.Logger('somefile.log') -- file to save stuff 5 | for i = 1,N do -- log some symbols during 6 | train_error = ... -- training/testing 7 | test_error = ... 8 | logger:add{['training error'] = train_error, 9 | ['test error'] = test_error} 10 | end 11 | logger:style{['training error'] = '-', -- define styles for plots 12 | ['test error'] = '-'} 13 | logger:plot() -- and plot 14 | ---- OR --- 15 | logger = optim.Logger('somefile.log') -- file to save stuff 16 | logger:setNames{'training error', 'test error'} 17 | for i = 1,N do -- log some symbols during 18 | train_error = ... -- training/testing 19 | test_error = ... 20 | logger:add{train_error, test_error} 21 | end 22 | logger:style{'-', '-'} -- define styles for plots 23 | logger:plot() -- and plot 24 | ]] 25 | require 'xlua' 26 | local Logger = torch.class('Logger') 27 | 28 | function Logger:__init(filename, timestamp) 29 | if filename then 30 | self.name = filename 31 | os.execute('mkdir ' .. (sys.uname() ~= 'windows' and '-p ' or '') .. ' "' .. paths.dirname(filename) .. '"') 32 | if timestamp then 33 | -- append timestamp to create unique log file 34 | filename = filename .. '-'..os.date("%Y_%m_%d_%X") 35 | end 36 | self.file = io.open(filename,'w') 37 | self.epsfile = self.name .. 
'.eps' 38 | else 39 | self.file = io.stdout 40 | self.name = 'stdout' 41 | print(' warning: no path provided, logging to std out') 42 | end 43 | self.empty = true 44 | self.symbols = {} 45 | self.styles = {} 46 | self.names = {} 47 | self.idx = {} 48 | self.figure = nil 49 | self.showPlot = false 50 | self.plotRawCmd = nil 51 | self.defaultStyle = '+' 52 | end 53 | 54 | function Logger:setNames(names) 55 | self.names = names 56 | self.empty = false 57 | self.nsymbols = #names 58 | for k,key in pairs(names) do 59 | self.file:write(key .. '\t') 60 | self.symbols[k] = {} 61 | self.styles[k] = {self.defaultStyle} 62 | self.idx[key] = k 63 | end 64 | self.file:write('\n') 65 | self.file:flush() 66 | end 67 | 68 | function Logger:add(symbols) 69 | -- (1) first time ? print symbols' names on first row 70 | if self.empty then 71 | self.empty = false 72 | self.nsymbols = #symbols 73 | for k,val in pairs(symbols) do 74 | self.file:write(k .. '\t') 75 | self.symbols[k] = {} 76 | self.styles[k] = {self.defaultStyle} 77 | self.names[k] = k 78 | end 79 | self.idx = self.names 80 | self.file:write('\n') 81 | end 82 | -- (2) print all symbols on one row 83 | for k,val in pairs(symbols) do 84 | if type(val) == 'number' then 85 | self.file:write(string.format('%11.4e',val) .. '\t') 86 | elseif type(val) == 'string' then 87 | self.file:write(val .. '\t') 88 | else 89 | xlua.error('can only log numbers and strings', 'Logger') 90 | end 91 | end 92 | self.file:write('\n') 93 | self.file:flush() 94 | -- (3) save symbols in internal table 95 | for k,val in pairs(symbols) do 96 | table.insert(self.symbols[k], val) 97 | end 98 | end 99 | 100 | function Logger:style(symbols) 101 | for name,style in pairs(symbols) do 102 | if type(style) == 'string' then 103 | self.styles[name] = {style} 104 | elseif type(style) == 'table' then 105 | self.styles[name] = style 106 | else 107 | xlua.error('style should be a string or a table of strings','Logger') 108 | end 109 | end 110 | end 111 | 112 | function Logger:plot(ylabel, xlabel) 113 | if not xlua.require('gnuplot') then 114 | if not self.warned then 115 | print(' warning: cannot plot with this version of Torch') 116 | self.warned = true 117 | end 118 | return 119 | end 120 | local plotit = false 121 | local plots = {} 122 | local plotsymbol = 123 | function(name,list) 124 | if #list > 1 then 125 | local nelts = #list 126 | local plot_y = torch.Tensor(nelts) 127 | for i = 1,nelts do 128 | plot_y[i] = list[i] 129 | end 130 | for _,style in ipairs(self.styles[name]) do 131 | table.insert(plots, {self.names[name], plot_y, style}) 132 | end 133 | plotit = true 134 | end 135 | end 136 | 137 | -- plot all symbols 138 | for name,list in pairs(self.symbols) do 139 | plotsymbol(name,list) 140 | end 141 | 142 | if plotit then 143 | if self.showPlot then 144 | self.figure = gnuplot.figure(self.figure) 145 | gnuplot.plot(plots) 146 | gnuplot.xlabel(xlabel) 147 | gnuplot.ylabel(ylabel) 148 | if self.plotRawCmd then gnuplot.raw(self.plotRawCmd) end 149 | gnuplot.grid('on') 150 | gnuplot.title(banner) 151 | end 152 | 153 | if self.epsfile then 154 | os.execute('rm -f "' .. self.epsfile .. 
'"') 155 | local epsfig = gnuplot.epsfigure(self.epsfile) 156 | gnuplot.plot(plots) 157 | gnuplot.xlabel(xlabel) 158 | gnuplot.ylabel(ylabel) 159 | if self.plotRawCmd then gnuplot.raw(self.plotRawCmd) end 160 | gnuplot.grid('on') 161 | gnuplot.title(banner) 162 | gnuplot.plotflush() 163 | gnuplot.close(epsfig) 164 | end 165 | end 166 | end -------------------------------------------------------------------------------- /utils/optim/adadelta_mem.lua: -------------------------------------------------------------------------------- 1 | -- Memory optimized implementation of optim.adadelta 2 | 3 | --[[ ADADELTA implementation for SGD http://arxiv.org/abs/1212.5701 4 | ARGS: 5 | - `opfunc` : a function that takes a single input (X), the point of 6 | evaluation, and returns f(X) and df/dX 7 | - `x` : the initial point 8 | - `config` : a table of hyper-parameters 9 | - `config.rho` : interpolation parameter 10 | - `config.eps` : for numerical stability 11 | - `config.weightDecay` : weight decay 12 | - `state` : a table describing the state of the optimizer; after each 13 | call the state is modified 14 | - `state.paramVariance` : vector of temporal variances of parameters 15 | - `state.accDelta` : vector of accummulated delta of gradients 16 | RETURN: 17 | - `x` : the new x vector 18 | - `f(x)` : the function, evaluated before the update 19 | ]] 20 | function adadelta_mem(opfunc, x, config, state) 21 | -- (0) get/update state 22 | if config == nil and state == nil then 23 | print('no state table, ADADELTA initializing') 24 | end 25 | local config = config or {} 26 | local state = state or config 27 | local rho = config.rho or 0.9 28 | local eps = config.eps or 1e-6 29 | local wd = config.weightDecay or 0 30 | state.evalCounter = state.evalCounter or 0 31 | -- (1) evaluate f(x) and df/dx 32 | local fx,dfdx = opfunc(x) 33 | 34 | -- (2) weight decay 35 | if wd ~= 0 then 36 | dfdx:add(wd, x) 37 | end 38 | 39 | -- (3) parameter update 40 | if not state.paramVariance then 41 | state.paramVariance = torch.Tensor():typeAs(x):resizeAs(dfdx):zero() 42 | state.delta = torch.Tensor():typeAs(x):resizeAs(dfdx):zero() 43 | state.accDelta = torch.Tensor():typeAs(x):resizeAs(dfdx):zero() 44 | end 45 | state.paramVariance:mul(rho):addcmul(1-rho,dfdx,dfdx) 46 | state.paramVariance:add(eps):sqrt() -- std 47 | state.delta:copy(state.accDelta):add(eps):sqrt():cdiv(state.paramVariance):cmul(dfdx) 48 | x:add(-1, state.delta) 49 | state.accDelta:mul(rho):addcmul(1-rho, state.delta, state.delta) 50 | state.paramVariance:cmul(state.paramVariance):csub(eps) -- recompute variance again 51 | 52 | -- (4) update evaluation counter 53 | state.evalCounter = state.evalCounter + 1 54 | 55 | -- return x*, f(x) before optimization 56 | return x,{fx} 57 | end -------------------------------------------------------------------------------- /utils/optim/adagrad_mem.lua: -------------------------------------------------------------------------------- 1 | -- Memory optimized implementation of optim.adagrad 2 | 3 | --[[ ADAGRAD implementation for SGD 4 | ARGS: 5 | - `opfunc` : a function that takes a single input (X), the point of 6 | evaluation, and returns f(X) and df/dX 7 | - `x` : the initial point 8 | - `state` : a table describing the state of the optimizer; after each 9 | call the state is modified 10 | - `state.learningRate` : learning rate 11 | - `state.paramVariance` : vector of temporal variances of parameters 12 | - `state.weightDecay` : scalar that controls weight decay 13 | RETURN: 14 | - `x` : the new x vector 15 | - `f(x)` : 
the function, evaluated before the update 16 | ]] 17 | 18 | function adagrad_mem(opfunc, x, config, state) 19 | -- (0) get/update state 20 | if config == nil and state == nil then 21 | print('no state table, ADAGRAD initializing') 22 | end 23 | local config = config or {} 24 | local state = state or config 25 | local lr = config.learningRate or 1e-3 26 | local lrd = config.learningRateDecay or 0 27 | local wd = config.weightDecay or 0 28 | state.evalCounter = state.evalCounter or 0 29 | local nevals = state.evalCounter 30 | 31 | -- (1) evaluate f(x) and df/dx 32 | local fx,dfdx = opfunc(x) 33 | 34 | -- (2) weight decay with a single parameter 35 | if wd ~= 0 then 36 | dfdx:add(wd, x) 37 | end 38 | 39 | -- (3) learning rate decay (annealing) 40 | local clr = lr / (1 + nevals*lrd) 41 | 42 | -- (4) parameter update with single or individual learning rates 43 | if not state.paramVariance then 44 | state.paramVariance = torch.Tensor():typeAs(x):resizeAs(dfdx):zero() 45 | end 46 | 47 | state.paramVariance:addcmul(1,dfdx,dfdx) 48 | state.paramVariance:add(1e-10) 49 | state.paramVariance:sqrt() -- Keeps the std 50 | x:addcdiv(-clr, dfdx, state.paramVariance) 51 | state.paramVariance:cmul(state.paramVariance) -- Keeps the variance again 52 | state.paramVariance:add(-1e-10) 53 | 54 | 55 | -- (5) update evaluation counter 56 | state.evalCounter = state.evalCounter + 1 57 | 58 | -- return x*, f(x) before optimization 59 | return x,{fx} 60 | end -------------------------------------------------------------------------------- /utils/optim/rmsprop_mem.lua: -------------------------------------------------------------------------------- 1 | -- Memory optimized implementation of optim.rmsprop 2 | 3 | --[[ An implementation of RMSprop 4 | ARGS: 5 | - 'opfunc' : a function that takes a single input (X), the point 6 | of a evaluation, and returns f(X) and df/dX 7 | - 'x' : the initial point 8 | - 'config` : a table with configuration parameters for the optimizer 9 | - 'config.learningRate' : learning rate 10 | - 'config.alpha' : smoothing constant 11 | - 'config.epsilon' : value with which to initialise m 12 | - 'config.weightDecay' : weight decay 13 | - 'state' : a table describing the state of the optimizer; 14 | after each call the state is modified 15 | - 'state.m' : leaky sum of squares of parameter gradients, 16 | - 'state.tmp' : and the square root (with epsilon smoothing) 17 | RETURN: 18 | - `x` : the new x vector 19 | - `f(x)` : the function, evaluated before the update 20 | ]] 21 | 22 | function rmsprop_mem(opfunc, x, config, state) 23 | -- (0) get/update state 24 | local config = config or {} 25 | local state = state or config 26 | local lr = config.learningRate or 1e-2 27 | local alpha = config.alpha or 0.99 28 | local epsilon = config.epsilon or 1e-8 29 | local wd = config.weightDecay or 0 30 | local mfill = config.initialMean or 0 31 | 32 | -- (1) evaluate f(x) and df/dx 33 | local fx, dfdx = opfunc(x) 34 | 35 | -- (2) weight decay 36 | if wd ~= 0 then 37 | dfdx:add(wd, x) 38 | end 39 | 40 | -- (3) initialize mean square values and square gradient storage 41 | if not state.m then 42 | state.m = torch.Tensor():typeAs(x):resizeAs(dfdx):fill(mfill) 43 | end 44 | 45 | -- (4) calculate new (leaky) mean squared values 46 | state.m:mul(alpha) 47 | state.m:addcmul(1.0-alpha, dfdx, dfdx) 48 | 49 | -- (5) perform update 50 | state.m:add(epsilon) 51 | state.m:sqrt() 52 | x:addcdiv(-lr, dfdx, state.m) 53 | state.m:cmul(state.m) 54 | state.m:add(-epsilon) 55 | 56 | -- return x*, f(x) before optimization 57 | 
return x, {fx} 58 | end -------------------------------------------------------------------------------- /utils/utils.lua: -------------------------------------------------------------------------------- 1 | function topk(one_dim_tensor, k) 2 | local bestk, indices = torch.topk(one_dim_tensor, k, true) 3 | local sorted, newindices = torch.sort(bestk, true) 4 | local oldindices = torch.LongTensor(k) 5 | for i = 1,k do 6 | oldindices[i] = indices[newindices[i]] 7 | end 8 | return sorted, oldindices 9 | end 10 | 11 | 12 | function list_with_scores_to_str(list, scores) 13 | local str = '' 14 | for i,v in pairs(list) do 15 | str = str .. list[i] .. '[' .. string.format("%.2f", scores[i]) .. ']; ' 16 | end 17 | return str 18 | end 19 | 20 | function table_len(t) 21 | local count = 0 22 | for _ in pairs(t) do count = count + 1 end 23 | return count 24 | end 25 | 26 | 27 | function split(inputstr, sep) 28 | if sep == nil then 29 | sep = "%s" 30 | end 31 | local t={} ; local i=1 32 | for str in string.gmatch(inputstr, "([^"..sep.."]+)") do 33 | t[i] = str 34 | i = i + 1 35 | end 36 | return t 37 | end 38 | -- Unit test: 39 | assert(6 == #split('aa_bb cc__dd ee _ _ __ff' , '_ ')) 40 | 41 | 42 | 43 | function correct_type(data) 44 | if opt.type == 'float' then return data:float() 45 | elseif opt.type == 'double' then return data:double() 46 | elseif string.find(opt.type, 'cuda') then return data:cuda() 47 | else print('Unsupported type') 48 | end 49 | end 50 | 51 | -- color fonts: 52 | function red(s) 53 | return '\27[31m' .. s .. '\27[39m' 54 | end 55 | 56 | function green(s) 57 | return '\27[32m' .. s .. '\27[39m' 58 | end 59 | 60 | function yellow(s) 61 | return '\27[33m' .. s .. '\27[39m' 62 | end 63 | 64 | function blue(s) 65 | return '\27[34m' .. s .. '\27[39m' 66 | end 67 | 68 | function violet(s) 69 | return '\27[35m' .. s .. '\27[39m' 70 | end 71 | 72 | function skyblue(s) 73 | return '\27[36m' .. s .. '\27[39m' 74 | end 75 | 76 | 77 | 78 | function split_in_words(inputstr) 79 | local words = {} 80 | for word in inputstr:gmatch("%w+") do table.insert(words, word) end 81 | return words 82 | end 83 | 84 | 85 | function first_letter_to_uppercase(s) 86 | return s:sub(1,1):upper() .. s:sub(2) 87 | end 88 | 89 | function modify_uppercase_phrase(s) 90 | if (s == s:upper()) then 91 | local words = split_in_words(s:lower()) 92 | local res = {} 93 | for _,w in pairs(words) do 94 | table.insert(res, first_letter_to_uppercase(w)) 95 | end 96 | return table.concat(res, ' ') 97 | else 98 | return s 99 | end 100 | end 101 | 102 | function blue_num_str(n) 103 | return blue(string.format("%.3f", n)) 104 | end 105 | 106 | 107 | 108 | function string_starts(s, m) 109 | return string.sub(s,1,string.len(m)) == m 110 | end 111 | 112 | -- trim: 113 | function trim1(s) 114 | return (s:gsub("^%s*(.-)%s*$", "%1")) 115 | end 116 | 117 | 118 | function nice_print_red_green(a,b) 119 | local s = string.format("%.3f", a) .. ':' .. string.format("%.3f", b) .. '[' 120 | if a > b then 121 | return s .. red(string.format("%.3f", a-b)) .. ']' 122 | elseif a < b then 123 | return s .. green(string.format("%.3f", b-a)) .. ']' 124 | else 125 | return s .. '0]' 126 | end 127 | end 128 | -------------------------------------------------------------------------------- /words/load_w_freq_and_vecs.lua: -------------------------------------------------------------------------------- 1 | -- Loads all common words in both Wikipedia and Word2vec/Glove, their unigram frequencies and their pre-trained Word2Vec embeddings.
2 | 3 | -- To load this as a standalone file do: 4 | -- th> opt = {word_vecs = 'w2v', root_data_dir = '$DATA_PATH'} 5 | -- th> dofile 'words/load_w_freq_and_vecs.lua' 6 | 7 | if not opt then 8 | cmd = torch.CmdLine() 9 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.') 10 | cmd:text() 11 | opt = cmd:parse(arg or {}) 12 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.') 13 | end 14 | 15 | 16 | default_path = opt.root_data_dir .. 'basic_data/wordEmbeddings/' 17 | 18 | tds = tds or require 'tds' 19 | torch.setdefaulttensortype('torch.FloatTensor') 20 | if not list_with_scores_to_str then 21 | dofile 'utils/utils.lua' 22 | end 23 | if not is_stop_word_or_number then 24 | dofile 'words/stop_words.lua' 25 | end 26 | 27 | assert(opt, 'Define opt') 28 | assert(opt.word_vecs, 'Define opt.word_vecs') 29 | 30 | print('==> Loading common w2v + top freq list of words') 31 | 32 | local output_t7filename = opt.root_data_dir .. 'generated/common_top_words_freq_vectors_' .. opt.word_vecs .. '.t7' 33 | 34 | if paths.filep(output_t7filename) then 35 | print(' ---> from t7 file.') 36 | common_w2v_freq_words = torch.load(output_t7filename) 37 | 38 | else 39 | print(' ---> t7 file NOT found. Loading from disk instead (slower). Out file = ' .. output_t7filename) 40 | local freq_words = tds.Hash() 41 | 42 | print(' word freq index ...') 43 | local num_freq_words = 1 44 | local w_freq_file = opt.root_data_dir .. 'generated/word_wiki_freq.txt' 45 | for line in io.lines(w_freq_file) do 46 | local parts = split(line, '\t') 47 | local w = parts[1] 48 | local w_f = tonumber(parts[2]) 49 | if not is_stop_word_or_number(w) then 50 | freq_words[w] = w_f 51 | num_freq_words = num_freq_words + 1 52 | end 53 | end 54 | 55 | common_w2v_freq_words = tds.Hash() 56 | 57 | print(' word vectors index ...') 58 | if opt.word_vecs == 'glove' then 59 | w2v_txtfilename = default_path .. 'Glove/glove.840B.300d.txt' 60 | local line_num = 0 61 | for line in io.lines(w2v_txtfilename) do 62 | line_num = line_num + 1 63 | if line_num % 200000 == 0 then 64 | print(' Processed ' .. line_num) 65 | end 66 | local parts = split(line, ' ') 67 | local w = parts[1] 68 | if freq_words[w] then 69 | common_w2v_freq_words[w] = 1 70 | end 71 | end 72 | 73 | else 74 | assert(opt.word_vecs == 'w2v') 75 | w2v_binfilename = default_path .. 'Word2Vec/GoogleNews-vectors-negative300.bin' 76 | local word_vecs_size = 300 77 | file = torch.DiskFile(w2v_binfilename,'r') 78 | file:ascii() 79 | local vocab_size = file:readInt() 80 | local size = file:readInt() 81 | assert(size == word_vecs_size, 'Wrong size : ' .. size .. ' vs ' .. word_vecs_size) 82 | 83 | function read_string_w2v(file) 84 | local str = {} 85 | while true do 86 | local char = file:readChar() 87 | if char == 32 or char == 10 or char == 0 then 88 | break 89 | else 90 | str[#str+1] = char 91 | end 92 | end 93 | str = torch.CharStorage(str) 94 | return str:string() 95 | end 96 | 97 | --Reading Contents 98 | file:binary() 99 | local line_num = 0 100 | for i = 1,vocab_size do 101 | line_num = line_num + 1 102 | if line_num % 200000 == 0 then 103 | print('Processed ' .. line_num) 104 | end 105 | local w = read_string_w2v(file) 106 | local v = torch.FloatTensor(file:readFloat(word_vecs_size)) 107 | if freq_words[w] then 108 | common_w2v_freq_words[w] = 1 109 | end 110 | end 111 | end 112 | 113 | print('Writing t7 File for future usage. 
Next time loading will be faster!') 114 | torch.save(output_t7filename, common_w2v_freq_words) 115 | end 116 | 117 | -- Now load the freq and w2v indexes 118 | dofile 'words/w_freq/w_freq_index.lua' 119 | -------------------------------------------------------------------------------- /words/stop_words.lua: -------------------------------------------------------------------------------- 1 | all_stop_words = { ['a'] = 1, ['about'] = 1, ['above'] = 1, ['across'] = 1, ['after'] = 1, ['afterwards'] = 1, ['again'] = 1, ['against'] = 1, ['all'] = 1, ['almost'] = 1, ['alone'] = 1, ['along'] = 1, ['already'] = 1, ['also'] = 1, ['although'] = 1, ['always'] = 1, ['am'] = 1, ['among'] = 1, ['amongst'] = 1, ['amoungst'] = 1, ['amount'] = 1, ['an'] = 1, ['and'] = 1, ['another'] = 1, ['any'] = 1, ['anyhow'] = 1, ['anyone'] = 1, ['anything'] = 1, ['anyway'] = 1, ['anywhere'] = 1, ['are'] = 1, ['around'] = 1, ['as'] = 1, ['at'] = 1, ['back'] = 1, ['be'] = 1, ['became'] = 1, ['because'] = 1, ['become'] = 1, ['becomes'] = 1, ['becoming'] = 1, ['been'] = 1, ['before'] = 1, ['beforehand'] = 1, ['behind'] = 1, ['being'] = 1, ['below'] = 1, ['beside'] = 1, ['besides'] = 1, ['between'] = 1, ['beyond'] = 1, ['both'] = 1, ['bottom'] = 1, ['but'] = 1, ['by'] = 1, ['call'] = 1, ['can'] = 1, ['cannot'] = 1, ['cant'] = 1, ['dont'] = 1, ['co'] = 1, ['con'] = 1, ['could'] = 1, ['couldnt'] = 1, ['cry'] = 1, ['de'] = 1, ['describe'] = 1, ['detail'] = 1, ['do'] = 1, ['done'] = 1, ['down'] = 1, ['due'] = 1, ['during'] = 1, ['each'] = 1, ['eg'] = 1, ['eight'] = 1, ['either'] = 1, ['eleven'] = 1, ['else'] = 1, ['elsewhere'] = 1, ['empty'] = 1, ['enough'] = 1, ['etc'] = 1, ['even'] = 1, ['ever'] = 1, ['every'] = 1, ['everyone'] = 1, ['everything'] = 1, ['everywhere'] = 1, ['except'] = 1, ['few'] = 1, ['fifteen'] = 1, ['fify'] = 1, ['fill'] = 1, ['find'] = 1, ['fire'] = 1, ['first'] = 1, ['five'] = 1, ['for'] = 1, ['former'] = 1, ['formerly'] = 1, ['forty'] = 1, ['found'] = 1, ['four'] = 1, ['from'] = 1, ['front'] = 1, ['full'] = 1, ['further'] = 1, ['get'] = 1, ['give'] = 1, ['go'] = 1, ['had'] = 1, ['has'] = 1, ['hasnt'] = 1, ['have'] = 1, ['he'] = 1, ['hence'] = 1, ['her'] = 1, ['here'] = 1, ['hereafter'] = 1, ['hereby'] = 1, ['herein'] = 1, ['hereupon'] = 1, ['hers'] = 1, ['herself'] = 1, ['him'] = 1, ['himself'] = 1, ['his'] = 1, ['how'] = 1, ['however'] = 1, ['hundred'] = 1, ['i'] = 1, ['ie'] = 1, ['if'] = 1, ['in'] = 1, ['inc'] = 1, ['indeed'] = 1, ['interest'] = 1, ['into'] = 1, ['is'] = 1, ['it'] = 1, ['its'] = 1, ['itself'] = 1, ['keep'] = 1, ['last'] = 1, ['latter'] = 1, ['latterly'] = 1, ['least'] = 1, ['less'] = 1, ['ltd'] = 1, ['made'] = 1, ['many'] = 1, ['may'] = 1, ['me'] = 1, ['meanwhile'] = 1, ['might'] = 1, ['mill'] = 1, ['mine'] = 1, ['more'] = 1, ['moreover'] = 1, ['most'] = 1, ['mostly'] = 1, ['move'] = 1, ['much'] = 1, ['must'] = 1, ['my'] = 1, ['myself'] = 1, ['name'] = 1, ['namely'] = 1, ['neither'] = 1, ['never'] = 1, ['nevertheless'] = 1, ['next'] = 1, ['nine'] = 1, ['no'] = 1, ['nobody'] = 1, ['none'] = 1, ['noone'] = 1, ['nor'] = 1, ['not'] = 1, ['nothing'] = 1, ['now'] = 1, ['nowhere'] = 1, ['of'] = 1, ['off'] = 1, ['often'] = 1, ['on'] = 1, ['once'] = 1, ['one'] = 1, ['only'] = 1, ['onto'] = 1, ['or'] = 1, ['other'] = 1, ['others'] = 1, ['otherwise'] = 1, ['our'] = 1, ['ours'] = 1, ['ourselves'] = 1, ['out'] = 1, ['over'] = 1, ['own'] = 1, ['part'] = 1, ['per'] = 1, ['perhaps'] = 1, ['please'] = 1, ['put'] = 1, ['rather'] = 1, ['re'] = 1, ['same'] = 1, ['see'] = 1, ['seem'] = 1, ['seemed'] = 1, 
['seeming'] = 1, ['seems'] = 1, ['serious'] = 1, ['several'] = 1, ['she'] = 1, ['should'] = 1, ['show'] = 1, ['side'] = 1, ['since'] = 1, ['sincere'] = 1, ['six'] = 1, ['sixty'] = 1, ['so'] = 1, ['some'] = 1, ['somehow'] = 1, ['someone'] = 1, ['something'] = 1, ['sometime'] = 1, ['sometimes'] = 1, ['somewhere'] = 1, ['still'] = 1, ['such'] = 1, ['system'] = 1, ['take'] = 1, ['ten'] = 1, ['than'] = 1, ['that'] = 1, ['the'] = 1, ['their'] = 1, ['them'] = 1, ['themselves'] = 1, ['then'] = 1, ['thence'] = 1, ['there'] = 1, ['thereafter'] = 1, ['thereby'] = 1, ['therefore'] = 1, ['therein'] = 1, ['thereupon'] = 1, ['these'] = 1, ['they'] = 1, ['thick'] = 1, ['thin'] = 1, ['third'] = 1, ['this'] = 1, ['those'] = 1, ['though'] = 1, ['three'] = 1, ['through'] = 1, ['throughout'] = 1, ['thru'] = 1, ['thus'] = 1, ['to'] = 1, ['together'] = 1, ['too'] = 1, ['top'] = 1, ['toward'] = 1, ['towards'] = 1, ['twelve'] = 1, ['twenty'] = 1, ['two'] = 1, ['un'] = 1, ['under'] = 1, ['until'] = 1, ['up'] = 1, ['upon'] = 1, ['us'] = 1, ['very'] = 1, ['via'] = 1, ['was'] = 1, ['we'] = 1, ['well'] = 1, ['were'] = 1, ['what'] = 1, ['whatever'] = 1, ['when'] = 1, ['whence'] = 1, ['whenever'] = 1, ['where'] = 1, ['whereafter'] = 1, ['whereas'] = 1, ['whereby'] = 1, ['wherein'] = 1, ['whereupon'] = 1, ['wherever'] = 1, ['whether'] = 1, ['which'] = 1, ['while'] = 1, ['whither'] = 1, ['who'] = 1, ['whoever'] = 1, ['whole'] = 1, ['whom'] = 1, ['whose'] = 1, ['why'] = 1, ['will'] = 1, ['with'] = 1, ['within'] = 1, ['without'] = 1, ['would'] = 1, ['yet'] = 1, ['you'] = 1, ['your'] = 1, ['yours'] = 1, ['yourself'] = 1, ['yourselves'] = 1, ['st'] = 1, ['years'] = 1, ['yourselves'] = 1, ['new'] = 1, ['used'] = 1, ['known'] = 1, ['year'] = 1, ['later'] = 1, ['including'] = 1, ['used'] = 1, ['end'] = 1, ['did'] = 1, ['just'] = 1, ['best'] = 1, ['using'] = 1} 2 | 3 | function is_stop_word_or_number(w) 4 | assert(type(w) == 'string', w) 5 | if (all_stop_words[w:lower()]) or tonumber(w) or w:len() <= 1 then 6 | return true 7 | else 8 | return false 9 | end 10 | end -------------------------------------------------------------------------------- /words/w2v/glove_reader.lua: -------------------------------------------------------------------------------- 1 | local M = torch.zeros(total_num_words(), word_vecs_size):float() 2 | 3 | --Reading Contents 4 | for line in io.lines(w2v_txtfilename) do 5 | local parts = split(line, ' ') 6 | local w = parts[1] 7 | local w_id = get_id_from_word(w) 8 | if w_id ~= unk_w_id then 9 | for i=2, #parts do 10 | M[w_id][i-1] = tonumber(parts[i]) 11 | end 12 | end 13 | end 14 | 15 | return M 16 | 17 | 18 | -------------------------------------------------------------------------------- /words/w2v/w2v.lua: -------------------------------------------------------------------------------- 1 | -- Loads pre-trained word embeddings from either Word2Vec or Glove 2 | 3 | assert(get_id_from_word) 4 | assert(common_w2v_freq_words) 5 | assert(total_num_words) 6 | 7 | word_vecs_size = 300 8 | 9 | -- Loads pre-trained glove or word2vec embeddings: 10 | if opt.word_vecs == 'glove' then 11 | -- Glove downloaded from: http://nlp.stanford.edu/projects/glove/ 12 | w2v_txtfilename = default_path .. 'Glove/glove.840B.300d.txt' 13 | w2v_t7filename = opt.root_data_dir .. 'generated/glove.840B.300d.t7' 14 | w2v_reader = 'words/w2v/glove_reader.lua' 15 | elseif opt.word_vecs == 'w2v' then 16 | -- Word2Vec downloaded from: https://code.google.com/archive/p/word2vec/ 17 | w2v_binfilename = default_path .. 
'Word2Vec/GoogleNews-vectors-negative300.bin' 18 | w2v_t7filename = opt.root_data_dir .. 'generated/GoogleNews-vectors-negative300.t7' 19 | w2v_reader = 'words/w2v/word2vec_reader.lua' 20 | end 21 | 22 | ---------------------- Code: ----------------------- 23 | w2vutils = {} 24 | 25 | print('==> Loading ' .. opt.word_vecs .. ' vectors') 26 | if not paths.filep(w2v_t7filename) then 27 | print(' ---> t7 file NOT found. Loading w2v from the bin/txt file instead (slower).') 28 | w2vutils.M = require(w2v_reader) 29 | print('Writing t7 File for future usage. Next time Word2Vec loading will be faster!') 30 | torch.save(w2v_t7filename, w2vutils.M) 31 | else 32 | print(' ---> from t7 file.') 33 | w2vutils.M = torch.load(w2v_t7filename) 34 | end 35 | 36 | -- Move the word embedding matrix on the GPU if we do some training. 37 | -- In this way we can perform word embedding lookup much faster. 38 | if opt and string.find(opt.type, 'cuda') then 39 | w2vutils.M = w2vutils.M:cuda() 40 | end 41 | 42 | ---------- Define additional functions ----------------- 43 | -- word -> vec 44 | w2vutils.get_w_vec = function (self,word) 45 | local w_id = get_id_from_word(word) 46 | return w2vutils.M[w_id]:clone() 47 | end 48 | 49 | -- word_id -> vec 50 | w2vutils.get_w_vec_from_id = function (self,w_id) 51 | return w2vutils.M[w_id]:clone() 52 | end 53 | 54 | w2vutils.lookup_w_vecs = function (self,word_id_tensor) 55 | assert(word_id_tensor:dim() <= 2, 'Only word id tensors w/ 1 or 2 dimensions are supported.') 56 | local output = torch.FloatTensor() 57 | local word_ids = word_id_tensor:long() 58 | if opt and string.find(opt.type, 'cuda') then 59 | output = output:cuda() 60 | word_ids = word_ids:cuda() 61 | end 62 | 63 | if word_ids:dim() == 2 then 64 | output:index(w2vutils.M, 1, word_ids:view(-1)) 65 | output = output:view(word_ids:size(1), word_ids:size(2), w2vutils.M:size(2)) 66 | elseif word_ids:dim() == 1 then 67 | output:index(w2vutils.M, 1, word_ids) 68 | output = output:view(word_ids:size(1), w2vutils.M:size(2)) 69 | end 70 | 71 | return output 72 | end 73 | 74 | -- Normalize word vectors to have norm 1 . 75 | w2vutils.renormalize = function (self) 76 | w2vutils.M[unk_w_id]:mul(0) 77 | w2vutils.M[unk_w_id]:add(1) 78 | w2vutils.M:cdiv(w2vutils.M:norm(2,2):expand(w2vutils.M:size())) 79 | local x = w2vutils.M:norm(2,2):view(-1) - 1 80 | assert(x:norm() < 0.1, x:norm()) 81 | assert(w2vutils.M[100]:norm() < 1.001 and w2vutils.M[100]:norm() > 0.99) 82 | w2vutils.M[unk_w_id]:mul(0) 83 | end 84 | 85 | w2vutils:renormalize() 86 | 87 | print(' Done reading w2v data. Word vocab size = ' .. 
w2vutils.M:size(1)) 88 | 89 | -- Phrase embedding using average of vectors of words in the phrase 90 | w2vutils.phrase_avg_vec = function(self, phrase) 91 | local words = split_in_words(phrase) 92 | local num_words = table_len(words) 93 | local num_existent_words = 0 94 | local vec = torch.zeros(word_vecs_size) 95 | for i = 1,num_words do 96 | local w = words[i] 97 | local w_id = get_id_from_word(w) 98 | if w_id ~= unk_w_id then 99 | vec:add(w2vutils:get_w_vec_from_id(w_id)) 100 | num_existent_words = num_existent_words + 1 101 | end 102 | end 103 | if (num_existent_words > 0) then 104 | vec:div(num_existent_words) 105 | end 106 | return vec 107 | end 108 | 109 | w2vutils.top_k_closest_words = function (self,vec, k, mat) 110 | local k = k or 1 111 | vec = vec:float() 112 | local distances = torch.mv(mat, vec) 113 | local best_scores, best_word_ids = topk(distances, k) 114 | local returnwords = {} 115 | local returndistances = {} 116 | for i = 1,k do 117 | local w = get_word_from_id(best_word_ids[i]) 118 | if is_stop_word_or_number(w) then 119 | table.insert(returnwords, red(w)) 120 | else 121 | table.insert(returnwords, w) 122 | end 123 | assert(best_scores[i] == distances[best_word_ids[i]], best_scores[i] .. ' ' .. distances[best_word_ids[i]]) 124 | table.insert(returndistances, distances[best_word_ids[i]]) 125 | end 126 | return returnwords, returndistances 127 | end 128 | 129 | w2vutils.most_similar2word = function(self, word, k) 130 | local k = k or 1 131 | local v = w2vutils:get_w_vec(word) 132 | neighbors, scores = w2vutils:top_k_closest_words(v, k, w2vutils.M) 133 | print('To word ' .. skyblue(word) .. ' : ' .. list_with_scores_to_str(neighbors, scores)) 134 | end 135 | 136 | w2vutils.most_similar2vec = function(self, vec, k) 137 | local k = k or 1 138 | neighbors, scores = w2vutils:top_k_closest_words(vec, k, w2vutils.M) 139 | print(list_with_scores_to_str(neighbors, scores)) 140 | end 141 | 142 | 143 | --------------------- Unit tests ---------------------------------------- 144 | local unit_tests = opt.unit_tests or false 145 | if (unit_tests) then 146 | print('\nWord to word similarity test:') 147 | w2vutils:most_similar2word('nice', 5) 148 | w2vutils:most_similar2word('france', 5) 149 | w2vutils:most_similar2word('hello', 5) 150 | end 151 | 152 | -- Computes for each word w : \sum_v exp() and \sum_v 153 | w2vutils.total_word_correlation = function(self, k, j) 154 | local exp_Z = torch.zeros(w2vutils.M:narrow(1, 1, j):size(1)) 155 | 156 | local sum_t = w2vutils.M:narrow(1, 1, j):sum(1) -- 1 x d 157 | local sum_Z = (w2vutils.M:narrow(1, 1, j) * sum_t:t()):view(-1) -- num_w 158 | 159 | print(red('Top words by sum_Z:')) 160 | best_sum_Z, best_word_ids = topk(sum_Z, k) 161 | for i = 1,k do 162 | local w = get_word_from_id(best_word_ids[i]) 163 | assert(best_sum_Z[i] == sum_Z[best_word_ids[i]]) 164 | print(w .. ' [' .. best_sum_Z[i] .. ']; ') 165 | end 166 | 167 | print('\n' .. red('Bottom words by sum_Z:')) 168 | best_sum_Z, best_word_ids = topk(- sum_Z, k) 169 | for i = 1,k do 170 | local w = get_word_from_id(best_word_ids[i]) 171 | assert(best_sum_Z[i] == - sum_Z[best_word_ids[i]]) 172 | print(w .. ' [' .. sum_Z[best_word_ids[i]] .. 
']; ') 173 | end 174 | end 175 | 176 | 177 | -- Plot with gnuplot: 178 | -- set palette model RGB defined ( 0 'white', 1 'pink', 2 'green' , 3 'blue', 4 'red' ) 179 | -- plot 'tsne-w2v-vecs.txt_1000' using 1:2:3 with labels offset 0,1, '' using 1:2:4 w points pt 7 ps 2 palette 180 | w2vutils.tsne = function(self, num_rand_words) 181 | local topic1 = {'japan', 'china', 'france', 'switzerland', 'romania', 'india', 'australia', 'country', 'city', 'tokyo', 'nation', 'capital', 'continent', 'europe', 'asia', 'earth', 'america'} 182 | local topic2 = {'football', 'striker', 'goalkeeper', 'basketball', 'coach', 'championship', 'cup', 183 | 'soccer', 'player', 'captain', 'qualifier', 'goal', 'under-21', 'halftime', 'standings', 'basketball', 184 | 'games', 'league', 'rugby', 'hockey', 'fifa', 'fans', 'maradona', 'mutu', 'hagi', 'beckham', 'injury', 'game', 185 | 'kick', 'penalty'} 186 | local topic_avg = {'japan national football team', 'germany national football team', 187 | 'china national football team', 'brazil soccer', 'japan soccer', 'germany soccer', 'china soccer', 188 | 'fc barcelona', 'real madrid'} 189 | 190 | local stop_words_array = {} 191 | for w,_ in pairs(stop_words) do 192 | table.insert(stop_words_array, w) 193 | end 194 | 195 | local topic1_len = table_len(topic1) 196 | local topic2_len = table_len(topic2) 197 | local topic_avg_len = table_len(topic_avg) 198 | local stop_words_len = table_len(stop_words_array) 199 | 200 | torch.setdefaulttensortype('torch.DoubleTensor') 201 | w2vutils.M = w2vutils.M:double() 202 | 203 | local tensor = torch.zeros(num_rand_words + stop_words_len + topic1_len + topic2_len + topic_avg_len, word_vecs_size) 204 | local tensor_w_ids = torch.zeros(num_rand_words) 205 | local tensor_colors = torch.zeros(tensor:size(1)) 206 | 207 | for i = 1,num_rand_words do 208 | tensor_w_ids[i] = math.random(1,25000) 209 | tensor_colors[i] = 0 210 | tensor[i]:copy(w2vutils.M[tensor_w_ids[i]]) 211 | end 212 | 213 | for i = 1, stop_words_len do 214 | tensor_colors[num_rand_words + i] = 1 215 | tensor[num_rand_words + i]:copy(w2vutils:phrase_avg_vec(stop_words_array[i])) 216 | end 217 | 218 | for i = 1, topic1_len do 219 | tensor_colors[num_rand_words + stop_words_len + i] = 2 220 | tensor[num_rand_words + stop_words_len + i]:copy(w2vutils:phrase_avg_vec(topic1[i])) 221 | end 222 | 223 | for i = 1, topic2_len do 224 | tensor_colors[num_rand_words + stop_words_len + topic1_len + i] = 3 225 | tensor[num_rand_words + stop_words_len + topic1_len + i]:copy(w2vutils:phrase_avg_vec(topic2[i])) 226 | end 227 | 228 | for i = 1, topic_avg_len do 229 | tensor_colors[num_rand_words + stop_words_len + topic1_len + topic2_len + i] = 4 230 | tensor[num_rand_words + stop_words_len + topic1_len + topic2_len + i]:copy(w2vutils:phrase_avg_vec(topic_avg[i])) 231 | end 232 | 233 | local manifold = require 'manifold' 234 | opts = {ndims = 2, perplexity = 30, pca = 50, use_bh = false} 235 | mapped_x1 = manifold.embedding.tsne(tensor, opts) 236 | assert(mapped_x1:size(1) == tensor:size(1) and mapped_x1:size(2) == 2) 237 | ouf_vecs = assert(io.open('tsne-w2v-vecs.txt_' .. 
num_rand_words, "w")) 238 | for i = 1,mapped_x1:size(1) do 239 | local w = nil 240 | if tensor_colors[i] == 0 then 241 | w = get_word_from_id(tensor_w_ids[i]) 242 | elseif tensor_colors[i] == 1 then 243 | w = stop_words_array[i - num_rand_words]:gsub(' ', '-') 244 | elseif tensor_colors[i] == 2 then 245 | w = topic1[i - num_rand_words - stop_words_len]:gsub(' ', '-') 246 | elseif tensor_colors[i] == 3 then 247 | w = topic2[i - num_rand_words - stop_words_len - topic1_len]:gsub(' ', '-') 248 | elseif tensor_colors[i] == 4 then 249 | w = topic_avg[i - num_rand_words - stop_words_len - topic1_len - topic2_len]:gsub(' ', '-') 250 | end 251 | assert(w) 252 | 253 | local v = mapped_x1[i] 254 | for j = 1,2 do 255 | ouf_vecs:write(v[j] .. ' ') 256 | end 257 | ouf_vecs:write(w .. ' ' .. tensor_colors[i] .. '\n') 258 | end 259 | io.close(ouf_vecs) 260 | print(' DONE') 261 | end 262 | -------------------------------------------------------------------------------- /words/w2v/word2vec_reader.lua: -------------------------------------------------------------------------------- 1 | -- Adapted from https://github.com/rotmanmi/word2vec.torch 2 | function read_string_w2v(file) 3 | local str = {} 4 | while true do 5 | local char = file:readChar() 6 | if char == 32 or char == 10 or char == 0 then 7 | break 8 | else 9 | str[#str+1] = char 10 | end 11 | end 12 | str = torch.CharStorage(str) 13 | return str:string() 14 | end 15 | 16 | 17 | file = torch.DiskFile(w2v_binfilename,'r') 18 | 19 | --Reading Header 20 | file:ascii() 21 | local vocab_size = file:readInt() 22 | local size = file:readInt() 23 | assert(size == word_vecs_size, 'Wrong size : ' .. size .. ' vs ' .. word_vecs_size) 24 | 25 | local M = torch.zeros(total_num_words(), word_vecs_size):float() 26 | 27 | --Reading Contents 28 | file:binary() 29 | local num_phrases = 0 30 | for i = 1,vocab_size do 31 | local w = read_string_w2v(file) 32 | local v = torch.FloatTensor(file:readFloat(word_vecs_size)) 33 | local w_id = get_id_from_word(w) 34 | if w_id ~= unk_w_id then 35 | M[w_id]:copy(v) 36 | end 37 | end 38 | 39 | print('Num words = ' .. total_num_words() .. '. Num phrases = ' .. num_phrases) 40 | 41 | return M 42 | -------------------------------------------------------------------------------- /words/w_freq/w_freq_gen.lua: -------------------------------------------------------------------------------- 1 | -- Computes an unigram frequency of each word in the Wikipedia corpus 2 | 3 | if not opt then 4 | cmd = torch.CmdLine() 5 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.') 6 | cmd:text() 7 | opt = cmd:parse(arg or {}) 8 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.') 9 | end 10 | 11 | 12 | require 'torch' 13 | dofile 'utils/utils.lua' 14 | 15 | tds = tds or require 'tds' 16 | 17 | word_freqs = tds.Hash() 18 | 19 | local num_lines = 0 20 | it, _ = io.open(opt.root_data_dir .. 'generated/wiki_canonical_words.txt') 21 | line = it:read() 22 | 23 | while (line) do 24 | num_lines = num_lines + 1 25 | if num_lines % 100000 == 0 then 26 | print('Processed ' .. num_lines .. ' lines. 
') 27 | end 28 | 29 | local parts = split(line , '\t') 30 | local words = split(parts[3], ' ') 31 | for _,w in pairs(words) do 32 | if (not word_freqs[w]) then 33 | word_freqs[w] = 0 34 | end 35 | word_freqs[w] = word_freqs[w] + 1 36 | end 37 | line = it:read() 38 | end 39 | 40 | 41 | -- Writing word frequencies 42 | print('Sorting and writing') 43 | sorted_word_freq = {} 44 | for w,freq in pairs(word_freqs) do 45 | if freq >= 10 then 46 | table.insert(sorted_word_freq, {w = w, freq = freq}) 47 | end 48 | end 49 | 50 | table.sort(sorted_word_freq, function(a,b) return a.freq > b.freq end) 51 | 52 | out_file = opt.root_data_dir .. 'generated/word_wiki_freq.txt' 53 | ouf = assert(io.open(out_file, "w")) 54 | total_freq = 0 55 | for _,x in pairs(sorted_word_freq) do 56 | ouf:write(x.w .. '\t' .. x.freq .. '\n') 57 | total_freq = total_freq + x.freq 58 | end 59 | ouf:flush() 60 | io.close(ouf) 61 | 62 | print('Total freq = ' .. total_freq .. '\n') 63 | -------------------------------------------------------------------------------- /words/w_freq/w_freq_index.lua: -------------------------------------------------------------------------------- 1 | -- Loads all words and their frequencies and IDs from a dictionary. 2 | assert(common_w2v_freq_words) 3 | if not opt.unig_power then 4 | opt.unig_power = 0.6 5 | end 6 | 7 | print('==> Loading word freq map with unig power ' .. red(opt.unig_power)) 8 | local w_freq_file = opt.root_data_dir .. 'generated/word_wiki_freq.txt' 9 | 10 | local w_freq = {} 11 | w_freq.id2word = tds.Hash() 12 | w_freq.word2id = tds.Hash() 13 | 14 | w_freq.w_f_start = tds.Hash() 15 | w_freq.w_f_end = tds.Hash() 16 | w_freq.total_freq = 0.0 17 | 18 | w_freq.w_f_at_unig_power_start = tds.Hash() 19 | w_freq.w_f_at_unig_power_end = tds.Hash() 20 | w_freq.total_freq_at_unig_power = 0.0 21 | 22 | -- UNK word id 23 | unk_w_id = 1 24 | w_freq.word2id['UNK_W'] = unk_w_id 25 | w_freq.id2word[unk_w_id] = 'UNK_W' 26 | 27 | local tmp_wid = 1 28 | for line in io.lines(w_freq_file) do 29 | local parts = split(line, '\t') 30 | local w = parts[1] 31 | if common_w2v_freq_words[w] then 32 | tmp_wid = tmp_wid + 1 33 | local w_id = tmp_wid 34 | w_freq.id2word[w_id] = w 35 | w_freq.word2id[w] = w_id 36 | 37 | local w_f = tonumber(parts[2]) 38 | assert(w_f) 39 | if w_f < 100 then 40 | w_f = 100 41 | end 42 | w_freq.w_f_start[w_id] = w_freq.total_freq 43 | w_freq.total_freq = w_freq.total_freq + w_f 44 | w_freq.w_f_end[w_id] = w_freq.total_freq 45 | 46 | w_freq.w_f_at_unig_power_start[w_id] = w_freq.total_freq_at_unig_power 47 | w_freq.total_freq_at_unig_power = w_freq.total_freq_at_unig_power + math.pow(w_f, opt.unig_power) 48 | w_freq.w_f_at_unig_power_end[w_id] = w_freq.total_freq_at_unig_power 49 | end 50 | end 51 | 52 | w_freq.total_num_words = tmp_wid 53 | 54 | print(' Done loading word freq index. Num words = ' .. w_freq.total_num_words .. '; total freq = ' .. 
w_freq.total_freq) 55 | 56 | -------------------------------------------- 57 | total_num_words = function() 58 | return w_freq.total_num_words 59 | end 60 | 61 | contains_w_id = function(w_id) 62 | assert(w_id >= 1 and w_id <= total_num_words(), w_id) 63 | return (w_id ~= unk_w_id) 64 | end 65 | 66 | -- id -> word 67 | get_word_from_id = function(w_id) 68 | assert(w_id >= 1 and w_id <= total_num_words(), w_id) 69 | return w_freq.id2word[w_id] 70 | end 71 | 72 | -- word -> id 73 | get_id_from_word = function(w) 74 | local w_id = w_freq.word2id[w] 75 | if w_id == nil then 76 | return unk_w_id 77 | end 78 | return w_id 79 | end 80 | 81 | contains_w = function(w) 82 | return contains_w_id(get_id_from_word(w)) 83 | end 84 | 85 | -- word frequency: 86 | function get_w_id_freq(w_id) 87 | assert(contains_w_id(w_id), w_id) 88 | return w_freq.w_f_end[w_id] - w_freq.w_f_start[w_id] + 1 89 | end 90 | 91 | -- p(w) prior: 92 | function get_w_id_unigram(w_id) 93 | return get_w_id_freq(w_id) / w_freq.total_freq 94 | end 95 | 96 | function get_w_tensor_log_unigram(vec_w_ids) 97 | assert(vec_w_ids:dim() == 2) 98 | local v = torch.zeros(vec_w_ids:size(1), vec_w_ids:size(2)) 99 | for i= 1,vec_w_ids:size(1) do 100 | for j = 1,vec_w_ids:size(2) do 101 | v[i][j] = math.log(get_w_id_unigram(vec_w_ids[i][j])) 102 | end 103 | end 104 | return v 105 | end 106 | 107 | 108 | if (opt.unit_tests) then 109 | print(get_w_id_unigram(get_id_from_word('the'))) 110 | print(get_w_id_unigram(get_id_from_word('of'))) 111 | print(get_w_id_unigram(get_id_from_word('and'))) 112 | print(get_w_id_unigram(get_id_from_word('romania'))) 113 | end 114 | 115 | -- Generates an random word sampled from the word unigram frequency. 116 | local function random_w_id(total_freq, w_f_start, w_f_end) 117 | local j = math.random() * total_freq 118 | local i_start = 2 119 | local i_end = total_num_words() 120 | 121 | while i_start <= i_end do 122 | local i_mid = math.floor((i_start + i_end) / 2) 123 | local w_id_mid = i_mid 124 | if w_f_start[w_id_mid] <= j and j <= w_f_end[w_id_mid] then 125 | return w_id_mid 126 | elseif (w_f_start[w_id_mid] > j) then 127 | i_end = i_mid - 1 128 | elseif (w_f_end[w_id_mid] < j) then 129 | i_start = i_mid + 1 130 | end 131 | end 132 | print(red('Binary search error !!')) 133 | end 134 | 135 | -- Frequent word subsampling procedure from the Word2Vec paper. 136 | function random_unigram_at_unig_power_w_id() 137 | return random_w_id(w_freq.total_freq_at_unig_power, w_freq.w_f_at_unig_power_start, w_freq.w_f_at_unig_power_end) 138 | end 139 | 140 | 141 | function get_w_unnorm_unigram_at_power(w_id) 142 | return math.pow(get_w_id_unigram(w_id), opt.unig_power) 143 | end 144 | 145 | 146 | function unit_test_random_unigram_at_unig_power_w_id(k_samples) 147 | local empirical_dist = {} 148 | for i=1,k_samples do 149 | local w_id = random_unigram_at_unig_power_w_id() 150 | assert(w_id ~= unk_w_id) 151 | if not empirical_dist[w_id] then 152 | empirical_dist[w_id] = 0 153 | end 154 | empirical_dist[w_id] = empirical_dist[w_id] + 1 155 | end 156 | print('Now sorting ..') 157 | local sorted_empirical_dist = {} 158 | for k,v in pairs(empirical_dist) do 159 | table.insert(sorted_empirical_dist, {w_id = k, f = v}) 160 | end 161 | table.sort(sorted_empirical_dist, function(a,b) return a.f > b.f end) 162 | 163 | local str = '' 164 | for i = 1,math.min(100, table_len(sorted_empirical_dist)) do 165 | str = str .. get_word_from_id(sorted_empirical_dist[i].w_id) .. '{' .. sorted_empirical_dist[i].f .. 
'}; ' 166 | end 167 | print('Unit test random sampling: ' .. str) 168 | end 169 | --------------------------------------------------------------------------------
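
## Usage sketches

The sketches below only use functions defined in the files listed above; any paths, model names, ids or field values that do not appear in this repository are placeholders.

The script entities/relatedness/filter_wiki_hyperlink_contexts_RLTD.lua parses tab-separated rows whose layout can be reconstructed from its asserts and splits: five leading fields that are copied through unchanged, the literal marker `CANDIDATES`, one candidate per field whose first comma-separated component is an entity wikiid, the literal marker `GT:`, and a final field of the form `<position>,<ent_wikiid>,...`. A minimal sketch with invented ids, scores and names, only to illustrate the structure that its parsing loop expects:

```
-- Hypothetical row in the layout consumed by filter_wiki_hyperlink_contexts_RLTD.lua;
-- the ids, scores and entity names are invented.
dofile 'utils/utils.lua'   -- for split() and table_len()

local example_row = table.concat({
  'field1', 'field2', 'field3', 'field4', 'field5',  -- parts[1..5], copied through unchanged
  'CANDIDATES',
  '111,0.7,Some_Entity',                             -- candidate 1: ent_wikiid comes first
  '222,0.3,Other_Entity',                            -- candidate 2
  'GT:',
  '1,111,0.7,Some_Entity'                            -- ground truth: position, then its candidate string
}, '\t')

local parts = split(example_row, '\t')
assert(parts[6] == 'CANDIDATES')
assert(parts[table_len(parts) - 1] == 'GT:')
```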
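The three optimizers under utils/optim/ keep the calling convention of their optim counterparts, `f(opfunc, x, config, state)`, where opfunc returns the loss and the gradient; the running statistics (state.paramVariance, state.m, ...) are allocated on the first call and then squared/un-squared in place rather than kept in extra buffers. A minimal sketch of driving adagrad_mem on a toy nn model (the model, criterion and data below are placeholders, not part of this repository):

```
require 'nn'
dofile 'utils/optim/adagrad_mem.lua'

-- Toy model and data, only to illustrate the opfunc/state contract.
local model = nn.Linear(10, 2)
local criterion = nn.MSECriterion()
local inputs, targets = torch.randn(4, 10), torch.randn(4, 2)

local params, gradParams = model:getParameters()
local state = {learningRate = 1e-2}   -- state.paramVariance is created inside adagrad_mem on the first call

local function feval(x)
  if x ~= params then params:copy(x) end
  gradParams:zero()
  local out = model:forward(inputs)
  local loss = criterion:forward(out, targets)
  model:backward(inputs, criterion:backward(out, targets))
  return loss, gradParams
end

for step = 1, 100 do
  local _, fs = adagrad_mem(feval, params, state)   -- the same table serves as config and state
  if step % 20 == 0 then print(step, fs[1]) end
end
```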
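Once the data-generation steps from the README have produced the files under $DATA_PATH/generated/, the word index and embeddings can also be loaded interactively. A sketch of such a session, run from the repository root inside a torch shell: the `opt` fields mirror the standalone-loading comment at the top of words/load_w_freq_and_vecs.lua, `type = 'double'` is added because words/w2v/w2v.lua inspects opt.type, and the queried words are just examples.

```
-- Inside th, after the README data-generation steps:
opt = {root_data_dir = '$DATA_PATH', word_vecs = 'w2v', type = 'double'}
dofile 'words/load_w_freq_and_vecs.lua'   -- defines get_id_from_word, total_num_words, ...
dofile 'words/w2v/w2v.lua'                -- defines w2vutils

-- Nearest neighbours of a word (same call as the unit tests in w2v.lua):
w2vutils:most_similar2word('france', 5)

-- Average-of-word-vectors embedding of a phrase:
print(w2vutils:phrase_avg_vec('new york city'):size())

-- Batched lookup: a 2D tensor of word ids -> a 3D tensor of embeddings.
local ids = torch.LongTensor{{get_id_from_word('france'), get_id_from_word('paris')}}
print(w2vutils:lookup_w_vecs(ids):size())
```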
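words/w_freq/w_freq_index.lua also exposes the smoothed unigram distribution, p(w) raised to opt.unig_power (0.6 by default), used for word sampling as noted in the file's own comments. A short sketch, assuming the index has been loaded as in the previous sketch and that the queried word is in the vocabulary (the file's own unit test queries the same word):

```
-- Unigram statistics and sampling, on top of the index loaded above.
local w_id = get_id_from_word('romania')
print(get_w_id_freq(w_id), get_w_id_unigram(w_id))   -- word frequency and unigram prior p(w)

-- Draw a few words from p(w)^opt.unig_power; the UNK id is never returned.
for _ = 1, 5 do
  print(get_word_from_id(random_unigram_at_unig_power_w_id()))
end
```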