├── LICENSE
├── README.md
├── data_gen
│   ├── gen_p_e_m
│   │   ├── gen_p_e_m_from_wiki.lua
│   │   ├── gen_p_e_m_from_yago.lua
│   │   ├── merge_crosswikis_wiki.lua
│   │   └── unicode_map.lua
│   ├── gen_test_train_data
│   │   ├── gen_ace_msnbc_aquaint_csv.lua
│   │   ├── gen_aida_test.lua
│   │   ├── gen_aida_train.lua
│   │   └── gen_all.lua
│   ├── gen_wiki_data
│   │   ├── gen_ent_wiki_w_repr.lua
│   │   └── gen_wiki_hyp_train_data.lua
│   ├── indexes
│   │   ├── wiki_disambiguation_pages_index.lua
│   │   ├── wiki_redirects_index.lua
│   │   └── yago_crosswikis_wiki.lua
│   └── parse_wiki_dump
│       └── parse_wiki_dump_tools.lua
├── ed
│   ├── args.lua
│   ├── ed.lua
│   ├── loss.lua
│   ├── minibatch
│   │   ├── build_minibatch.lua
│   │   └── data_loader.lua
│   ├── models
│   │   ├── SetConstantDiag.lua
│   │   ├── linear_layers.lua
│   │   ├── model.lua
│   │   ├── model_global.lua
│   │   └── model_local.lua
│   ├── test
│   │   ├── check_coref.lua
│   │   ├── coref_persons.lua
│   │   ├── ent_freq_stats_test.lua
│   │   ├── ent_p_e_m_stats_test.lua
│   │   ├── test.lua
│   │   └── test_one_loaded_model.lua
│   └── train.lua
├── entities
│   ├── ent_name2id_freq
│   │   ├── e_freq_gen.lua
│   │   ├── e_freq_index.lua
│   │   └── ent_name_id.lua
│   ├── learn_e2v
│   │   ├── 4EX_wiki_words.lua
│   │   ├── batch_dataset_a.lua
│   │   ├── e2v_a.lua
│   │   ├── learn_a.lua
│   │   ├── minibatch_a.lua
│   │   └── model_a.lua
│   ├── pretrained_e2v
│   │   ├── check_ents.lua
│   │   ├── e2v.lua
│   │   └── e2v_txt_reader.lua
│   └── relatedness
│       ├── filter_wiki_canonical_words_RLTD.lua
│       ├── filter_wiki_hyperlink_contexts_RLTD.lua
│       └── relatedness.lua
├── our_system_annotations.txt
├── utils
│   ├── logger.lua
│   ├── optim
│   │   ├── adadelta_mem.lua
│   │   ├── adagrad_mem.lua
│   │   └── rmsprop_mem.lua
│   └── utils.lua
└── words
    ├── load_w_freq_and_vecs.lua
    ├── stop_words.lua
    ├── w2v
    │   ├── glove_reader.lua
    │   ├── w2v.lua
    │   └── word2vec_reader.lua
    └── w_freq
        ├── w_freq_gen.lua
        └── w_freq_index.lua
/README.md:
--------------------------------------------------------------------------------
1 | # Source code for "Deep Joint Entity Disambiguation with Local Neural Attention"
2 |
3 | [O-E. Ganea and T. Hofmann, full paper @ EMNLP 2017](https://arxiv.org/abs/1704.04920)
4 |
5 | Slides and poster can be accessed [here](http://people.inf.ethz.ch/ganeao/).
6 |
7 | ## Pre-trained entity embeddings
8 |
9 | Entity embeddings trained with our method using pre-trained 300-dimensional Word2Vec word vectors (GoogleNews-vectors-negative300.bin). They have unit norm and are restricted to entities appearing in the training, validation and test sets described in our paper. Available [here](https://polybox.ethz.ch/index.php/s/sH2JSB2c1OSj7yv).
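As a quick sanity check after downloading them, something like the following can be run in a `th` terminal. This is a minimal sketch, not part of the repository: it assumes the download is a torch-serialized 2D tensor of entity vectors (one row per entity) and uses a placeholder file name.

```lua
-- Minimal sketch (assumptions: the download is a torch-serialized 2D tensor of
-- entity vectors, one row per entity; 'ent_vecs.t7' is a placeholder file name).
require 'torch'
local ent_vecs = torch.load('ent_vecs.t7')
print(ent_vecs:size())                      -- expected: num_entities x 300
print(ent_vecs:narrow(1, 1, 5):norm(2, 2))  -- row norms of the first 5 vectors, expected ~1
```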
10 |
11 | ## Full set of annotations made by one of our global models
12 |
13 | See the file our_system_annotations.txt. It is best viewed with its color scheme in a bash terminal. It contains the full set of annotations for the following datasets:
14 |
15 | ```
16 | $ cat our_system_annotations.txt | grep 'Micro '
17 | ==> AQUAINT AQUAINT ; EPOCH = 307: Micro recall = 88.03% ; Micro F1 = 89.51%
18 | ==> MSNBC MSNBC ; EPOCH = 307: Micro recall = 93.29% ; Micro F1 = 93.65%
19 | ==> ACE04 ACE04 ; EPOCH = 307: Micro recall = 84.05% ; Micro F1 = 86.92%
20 | ==> aida-B aida-B ; EPOCH = 307: Micro recall = 92.08% ; Micro F1 = 92.08%
21 | ==> aida-A aida-A ; EPOCH = 307: Micro recall = 91.00% ; Micro F1 = 91.01%
22 | ```
23 |
24 | The global model was trained on AIDA-train with pre-trained entity embeddings learned from Wikipedia. See below for details of how to run our code.
25 |
26 | Detailed statistics per dataset, as in Table 6 of our paper, can be accessed with:
27 |
28 | ```
29 | $ cat our_system_annotations.txt | grep -A20 'Micro '
30 | ```
31 |
32 |
33 | ## How to run the system and reproduce our results
34 |
35 | 1) Install [Torch](http://torch.ch/)
36 |
37 |
38 | 2) Install torch libraries: cudnn, cutorch, [tds](https://github.com/torch/tds), gnuplot, xlua
39 |
40 | ```luarocks install lib_name```
41 |
42 | Check that each of these libraries can be imported in a torch terminal.
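For example, the following should run without errors in a `th` terminal:

```lua
-- Check that all required libraries can be loaded.
require 'cudnn'
require 'cutorch'
require 'tds'
require 'gnuplot'
require 'xlua'
print('all libraries loaded')
```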
43 |
44 |
45 | 3) Create a $DATA_PATH directory (assumed to end in '/' in the next steps). Create a directory $DATA_PATH/generated/ that will contain all files generated in the next steps.
46 |
47 |
48 | 4) Download the data files needed for training and testing from [this link](https://drive.google.com/uc?id=0Bx8d3azIm_ZcbHMtVmRVc1o5TWM&export=download).
49 | Download basic_data.zip, unzip it and place the basic_data directory in $DATA_PATH/. All generated files will be built from files in this basic_data/ directory.
50 |
51 |
52 | 5) Download pre-trained Word2Vec vectors GoogleNews-vectors-negative300.bin.gz from https://code.google.com/archive/p/word2vec/.
53 | Unzip it and place the bin file in the folder $DATA_PATH/basic_data/wordEmbeddings/Word2Vec.
54 |
55 |
56 | Now we start creating additional data files needed in our pipeline:
57 |
58 | 6) Create wikipedia_p_e_m.txt:
59 |
60 | ```th data_gen/gen_p_e_m/gen_p_e_m_from_wiki.lua -root_data_dir $DATA_PATH```
61 |
62 |
63 | 7) Merge wikipedia_p_e_m.txt and crosswikis_p_e_m.txt:
64 |
65 | ```th data_gen/gen_p_e_m/merge_crosswikis_wiki.lua -root_data_dir $DATA_PATH```
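Each line of the resulting crosswikis_wikipedia_p_e_m.txt has the format `mention \t total_freq \t ent_wikiid,freq,ent_name \t ...`, with candidates sorted by decreasing frequency (see merge_crosswikis_wiki.lua below). A minimal sketch, not part of the repository, that prints the top candidate of the first few mentions:

```lua
-- Minimal sketch (not part of the repo): print the top candidate of the first few
-- mentions in the merged p(e|m) file.
local path = 'generated/crosswikis_wikipedia_p_e_m.txt'  -- relative to $DATA_PATH
local n = 0
for line in io.lines(path) do
  local fields = {}
  for f in line:gmatch('[^\t]+') do table.insert(fields, f) end
  local mention, total_freq = fields[1], tonumber(fields[2])
  if fields[3] and total_freq > 0 then  -- skip mentions without valid candidates
    local ent_wikiid, freq, ent_name = fields[3]:match('([^,]+),([^,]+),(.+)')
    print(string.format('%s -> %s (p(e|m) ~ %.3f)', mention, ent_name, tonumber(freq) / total_freq))
  end
  n = n + 1
  if n >= 5 then break end
end
```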
66 |
67 |
68 | 8) Create yago_p_e_m.txt:
69 |
70 | ```th data_gen/gen_p_e_m/gen_p_e_m_from_yago.lua -root_data_dir $DATA_PATH ```
71 |
72 |
73 | 9) Create a file ent_wiki_freq.txt with entity frequencies:
74 |
75 | ```th entities/ent_name2id_freq/e_freq_gen.lua -root_data_dir $DATA_PATH```
76 |
77 |
78 | 10) Generate all entity disambiguation datasets in a CSV format needed in our training stage:
79 |
80 | ```
81 | mkdir $DATA_PATH/generated/test_train_data/
82 | th data_gen/gen_test_train_data/gen_all.lua -root_data_dir $DATA_PATH
83 | ```
84 |
85 | Verify the statistics of these files as explained in the header comments of gen_ace_msnbc_aquaint_csv.lua and gen_aida_test.lua.
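All of these CSV files share the format documented in the header comments of the generation scripts: `doc_name \t doc_name \t mention \t left_ctxt \t right_ctxt \t CANDIDATES \t [ent_wikiid,p_e_m,ent_name]+ \t GT: \t pos,ent_wikiid,p_e_m,ent_name`. A minimal, repo-independent sketch of splitting one such line into its parts:

```lua
-- Minimal sketch (not part of the repo): split one line of a generated ED CSV file.
local function parse_ed_line(line)
  local f = {}
  for field in line:gmatch('[^\t]+') do table.insert(f, field) end
  local rec = { doc_name = f[1], mention = f[3], left_ctxt = f[4], right_ctxt = f[5] }
  assert(f[6] == 'CANDIDATES', 'unexpected format')
  rec.candidates = {}
  local i = 7
  while f[i] ~= 'GT:' do
    if f[i] ~= 'EMPTYCAND' then
      table.insert(rec.candidates, f[i])  -- each candidate is 'ent_wikiid,p_e_m,ent_name'
    end
    i = i + 1
  end
  rec.gt = f[i + 1]  -- 'pos,ent_wikiid,p_e_m,ent_name', or starts with '-1' if the gold entity is not a candidate
  return rec
end
```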
86 |
87 |
88 | 11) Create training data for learning entity embeddings:
89 |
90 | i) From Wiki canonical pages:
91 |
92 | ```th data_gen/gen_wiki_data/gen_ent_wiki_w_repr.lua -root_data_dir $DATA_PATH```
93 |
94 | ii) From context windows surrounding Wiki hyperlinks:
95 |
96 | ```th data_gen/gen_wiki_data/gen_wiki_hyp_train_data.lua -root_data_dir $DATA_PATH```
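The first file, wiki_canonical_words.txt, has one line per entity in the format `ent_wikiid \t ent_name \t space-separated words of its canonical page` (see gen_ent_wiki_w_repr.lua below); the second, wiki_hyperlink_contexts.csv, uses a similar CSV format with `ent_wikiid \t ent_name` as the first two fields (see gen_wiki_hyp_train_data.lua). A minimal sketch, not part of the repository, for inspecting the first file:

```lua
-- Minimal sketch (not part of the repo): print entity id, name and word count for
-- the first few lines of wiki_canonical_words.txt.
local n = 0
for line in io.lines('generated/wiki_canonical_words.txt') do  -- relative to $DATA_PATH
  local ent_wikiid, ent_name, words = line:match('([^\t]+)\t([^\t]+)\t(.+)')
  local num_words = select(2, words:gsub('%S+', ''))  -- count non-space tokens
  print(ent_wikiid, ent_name, num_words .. ' words')
  n = n + 1
  if n >= 3 then break end
end
```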
97 |
98 |
99 | 12) Compute the unigram frequency of each word in the Wikipedia corpus:
100 |
101 | ```th words/w_freq/w_freq_gen.lua -root_data_dir $DATA_PATH```
102 |
103 |
104 | 13) Compute the restricted training data for learning entity embeddings by using only candidate entities from the relatedness datasets and all ED sets:
105 | i) From Wiki canonical pages:
106 |
107 | ```th entities/relatedness/filter_wiki_canonical_words_RLTD.lua -root_data_dir $DATA_PATH```
108 |
109 | ii) From context windows surrounding Wiki hyperlinks:
110 |
111 | ```th entities/relatedness/filter_wiki_hyperlink_contexts_RLTD.lua -root_data_dir $DATA_PATH```
112 |
113 | All files in the $DATA_PATH/generated/ folder containing the substring "_RLTD" are restricted to this set of entities (which should contain 276030 entities).
114 |
115 | Your $DATA_PATH/generated/ folder should now contain the following files:
116 |
117 | ```
118 | $DATA_PATH/generated $ ls -lah ./
119 | total 147G
120 | 9.5M all_candidate_ents_ed_rltd_datasets_RLTD.t7
121 | 775M crosswikis_wikipedia_p_e_m.txt
122 | 5.0M empty_page_ents.txt
123 | 520M ent_name_id_map.t7
124 | 95M ent_wiki_freq.txt
125 | 7.3M relatedness_test.t7
126 | 8.9M relatedness_validate.t7
127 | 220 test_train_data
128 | 1.5G wiki_canonical_words_RLTD.txt
129 | 8.4G wiki_canonical_words.txt
130 | 88G wiki_hyperlink_contexts.csv
131 | 48G wiki_hyperlink_contexts_RLTD.csv
132 | 329M wikipedia_p_e_m.txt
133 | 11M word_wiki_freq.txt
134 | 749M yago_p_e_m.txt
135 |
136 | $DATA_PATH//generated/test_train_data $ ls -lah ./
137 | total 124M
138 | 14M aida_testA.csv
139 | 13M aida_testB.csv
140 | 50M aida_train.csv
141 | 723K wned-ace2004.csv
142 | 1.6M wned-aquaint.csv
143 | 31M wned-clueweb.csv
144 | 1.6M wned-msnbc.csv
145 | 15M wned-wikipedia.csv
146 | ```
147 |
148 | 14) Now we train entity embeddings for the restricted set of entities (written in all_candidate_ents_ed_rltd_datasets_RLTD.t7). This is the step described in Section 3 of our paper.
149 |
150 | To check the full list of parameters run:
151 |
152 | ```th entities/learn_e2v/learn_a.lua -help```
153 |
154 | Optimal parameters (see entities/learn_e2v/learn_a.lua):
155 |
156 | ```
157 | optimization = 'ADAGRAD'
158 | lr = 0.3
159 | batch_size = 500
160 | word_vecs = 'w2v'
161 | num_words_per_ent = 20
162 | num_neg_words = 5
163 | unig_power = 0.6
164 | entities = 'RLTD'
165 | loss = 'maxm'
166 | data = 'wiki-canonical-hyperlinks'
167 | num_passes_wiki_words = 200
168 | hyp_ctxt_len = 10
169 | ```
170 |
171 | To run the embedding training on one GPU:
172 |
173 | ```
174 | mkdir $DATA_PATH/generated/ent_vecs
175 | CUDA_VISIBLE_DEVICES=0 th entities/learn_e2v/learn_a.lua -root_data_dir $DATA_PATH |& tee log_train_entity_vecs
176 | ```
177 |
178 | Warning: this code is not optimized for maximum speed or GPU utilization; sorry for the inconvenience. It uses only the main thread to load data and perform word embedding lookups, so it can be made to run much faster.
179 |
180 | During training, you will see (in the log file log_train_entity_vecs) the validation score of the current set of entity embeddings on the entity relatedness dataset. After around 24 hours (this code is not optimized!), you will see no further improvement and can stop the training script. Pick the set of saved entity vectors from the folder generated/ent_vecs/ corresponding to the best validation score on the entity relatedness dataset, i.e. the sum of all validation metrics (the TOTAL VALIDATION column). Table 1 of our paper reports the test-set results corresponding to this best validation score.
181 |
182 | You should get some numbers similar to the following (may vary a little bit due to random initialization):
183 |
184 | ```
185 | Entity Relatedness quality measure:
186 | measure = NDCG1 NDCG5 NDCG10 MAP TOTAL VALIDATION
187 | our (vald) = 0.681 0.639 0.671 0.619 2.610
188 | our (test) = 0.650 0.609 0.641 0.579
189 | Yamada'16 = 0.59 0.56 0.59 0.52
190 | WikiMW = 0.54 0.52 0.55 0.48
191 | ==> saving model to $DATA_PATH/generated/ent_vecs/ent_vecs__ep_69.t7
192 | ```
193 |
194 | We will refer to this entity embeddings file as $ENTITY_VECS. In our case, it is 'ent_vecs__ep_69.t7'.
195 |
196 | This code uses a simple initialization of entity embeddings based on the average of the entities' title words (excluding stop words). This helps speed up training and avoid getting stuck in local minima. We found that a random initialization might result in a slight quality decrease for ED only (up to 1%), and requires a longer training time until reaching the same quality on entity relatedness (~60 hours).
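A minimal sketch of this kind of initialization (hypothetical helper arguments; the actual implementation lives in entities/learn_e2v/):

```lua
-- Minimal sketch (hypothetical helpers, not the repo's learn_a.lua code): initialize an
-- entity vector as the normalized average of its title word vectors, skipping stop words.
require 'torch'

-- title_words: array of words in the entity title
-- word_vecs:   table mapping word -> torch vector of dimension dim
-- stop_words:  set of stop words (table with stop_words[w] = true)
local function init_ent_vec(title_words, word_vecs, stop_words, dim)
  local v = torch.zeros(dim)
  local n = 0
  for _, w in ipairs(title_words) do
    if (not stop_words[w]) and word_vecs[w] then
      v:add(word_vecs[w])
      n = n + 1
    end
  end
  if n > 0 then v:div(n) end
  local norm = v:norm()
  if norm > 0 then v:div(norm) end  -- unit norm, matching the released embeddings
  return v
end
```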
197 |
198 | 15) Run the training for the global/local ED neural network. Arguments file: ed/args.lua . To list all arguments:
199 |
200 | ```th ed/ed.lua -help```
201 |
202 | Command to run the training:
203 |
204 | ```
205 | mkdir $DATA_PATH/generated/ed_models/
206 | mkdir $DATA_PATH/generated/ed_models/training_plots/
207 | CUDA_VISIBLE_DEVICES=0 th ed/ed.lua -root_data_dir $DATA_PATH -ent_vecs_filename $ENTITY_VECS -model 'global' |& tee log_train_ed
208 | ```
209 |
210 | Let it train for at least 48 hours (or 400 epochs as defined in our code), until the validation accuracy no longer improves or starts dropping. As we wrote in the paper, we stop learning if
211 | the validation F1 does not increase after 500 full epochs over the AIDA train dataset. Validation F1 can be followed using the command:
212 |
213 | ```cat log_train_ed | grep -A20 'Micro F1' | grep -A20 'aida-A'```
214 |
215 | The best ED models will be saved in the folder generated/ed_models. This will only happen after the model reaches > 90% F1 score on the validation set (see test.lua).
216 |
217 | Statistics, weights and scores will be written to the log_train_ed file. Plots of micro F1 scores on all the validation and test sets will be written to the folder $DATA_PATH/generated/ed_models/training_plots/.
218 |
219 | Results variability: Results reported in our ED paper in Tables 3 and 4 are averaged over different training runs of the ED neural architecture, using the same set of entity embeddings. However, we found that the variance of ED results across different trainings of the entity embeddings can be a little higher, up to 0.5%.
220 |
221 | 16) After training is terminated, one can re-load and test the best ED model using the command:
222 |
223 | ```
224 | CUDA_VISIBLE_DEVICES=0 th ed/test/test_one_loaded_model.lua -root_data_dir $DATA_PATH -model global -ent_vecs_filename $ENTITY_VECS -test_one_model_file $ED_MODEL_FILENAME
225 | ```
226 |
227 | where $ED_MODEL_FILENAME is a file in $DATA_PATH/generated/ed_models/ .
228 |
229 |
230 | 17) Enjoy!
231 |
--------------------------------------------------------------------------------
/data_gen/gen_p_e_m/gen_p_e_m_from_wiki.lua:
--------------------------------------------------------------------------------
1 | -- Generate p(e|m) index from Wikipedia
2 | -- Run: th data_gen/gen_p_e_m/gen_p_e_m_from_wiki.lua -root_data_dir $DATA_PATH
3 |
4 | cmd = torch.CmdLine()
5 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.')
6 | cmd:text()
7 | opt = cmd:parse(arg or {})
8 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.')
9 |
10 |
11 | require 'torch'
12 | dofile 'utils/utils.lua'
13 | dofile 'data_gen/parse_wiki_dump/parse_wiki_dump_tools.lua'
14 | tds = tds or require 'tds'
15 |
16 | print('\nComputing Wikipedia p_e_m')
17 |
18 | it, _ = io.open(opt.root_data_dir .. 'basic_data/textWithAnchorsFromAllWikipedia2014Feb.txt')
19 | line = it:read()
20 |
21 | wiki_e_m_counts = tds.Hash()
22 |
23 | -- Find anchors, e.g. anarchism
24 | local num_lines = 0
25 | local parsing_errors = 0
26 | local list_ent_errors = 0
27 | local diez_ent_errors = 0
28 | local disambiguation_ent_errors = 0
29 | local num_valid_hyperlinks = 0
30 |
31 | while (line) do
32 | num_lines = num_lines + 1
33 | if num_lines % 5000000 == 0 then
34 | print('Processed ' .. num_lines .. ' lines. Parsing errs = ' ..
35 | parsing_errors .. ' List ent errs = ' ..
36 | list_ent_errors .. ' diez errs = ' .. diez_ent_errors ..
37 | ' disambig errs = ' .. disambiguation_ent_errors ..
38 | ' . Num valid hyperlinks = ' .. num_valid_hyperlinks)
39 | end
40 |
41 | if not line:find(' b.freq end)
79 |
80 | local str = ''
81 | local total_freq = 0
82 | for _,el in pairs(tbl) do
83 | str = str .. el.ent_wikiid .. ',' .. el.freq
84 | str = str .. ',' .. get_ent_name_from_wikiid(el.ent_wikiid):gsub(' ', '_') .. '\t'
85 | total_freq = total_freq + el.freq
86 | end
87 | ouf:write(mention .. '\t' .. total_freq .. '\t' .. str .. '\n')
88 | end
89 | ouf:flush()
90 | io.close(ouf)
91 |
92 | print(' Done sorting and writing.')
93 |
--------------------------------------------------------------------------------
/data_gen/gen_p_e_m/gen_p_e_m_from_yago.lua:
--------------------------------------------------------------------------------
1 | -- Generate p(e|m) index from YAGO
2 | -- Run: th data_gen/gen_p_e_m/gen_p_e_m_from_yago.lua -root_data_dir $DATA_PATH
3 |
4 | cmd = torch.CmdLine()
5 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.')
6 | cmd:text()
7 | opt = cmd:parse(arg or {})
8 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.')
9 |
10 |
11 | require 'torch'
12 | dofile 'utils/utils.lua'
13 | dofile 'data_gen/gen_p_e_m/unicode_map.lua'
14 | if not get_redirected_ent_title then
15 | dofile 'data_gen/indexes/wiki_redirects_index.lua'
16 | end
17 | if not get_ent_name_from_wikiid then
18 | dofile 'entities/ent_name2id_freq/ent_name_id.lua'
19 | end
20 |
21 | tds = tds or require 'tds'
22 |
23 | print('\nComputing YAGO p_e_m')
24 |
25 | local it, _ = io.open(opt.root_data_dir .. 'basic_data/p_e_m_data/aida_means.tsv')
26 | local line = it:read()
27 |
28 | local num_lines = 0
29 | local wiki_e_m_counts = tds.Hash()
30 |
31 | while (line) do
32 | num_lines = num_lines + 1
33 | if num_lines % 5000000 == 0 then
34 | print('Processed ' .. num_lines .. ' lines.')
35 | end
36 | local parts = split(line, '\t')
37 | assert(table_len(parts) == 2)
38 | assert(parts[1]:sub(1,1) == '"')
39 | assert(parts[1]:sub(parts[1]:len(),parts[1]:len()) == '"')
40 |
41 | local mention = parts[1]:sub(2, parts[1]:len() - 1)
42 | local ent_name = parts[2]
43 | ent_name = string.gsub(ent_name, '&amp;', '&')
44 | ent_name = string.gsub(ent_name, '&quot;', '"')
45 | while ent_name:find('\\u') do
46 | local x = ent_name:find('\\u')
47 | local code = ent_name:sub(x, x + 5)
48 | assert(unicode2ascii[code], code)
49 | replace = unicode2ascii[code]
50 | if(replace == "%") then
51 | replace = "%%"
52 | end
53 | ent_name = string.gsub(ent_name, code, replace)
54 | end
55 |
56 | ent_name = preprocess_ent_name(ent_name)
57 | local ent_wikiid = get_ent_wikiid_from_name(ent_name, true)
58 | if ent_wikiid ~= unk_ent_wikiid then
59 | if not wiki_e_m_counts[mention] then
60 | wiki_e_m_counts[mention] = tds.Hash()
61 | end
62 | wiki_e_m_counts[mention][ent_wikiid] = 1
63 | end
64 |
65 | line = it:read()
66 | end
67 |
68 |
69 | print('Now sorting and writing ..')
70 | out_file = opt.root_data_dir .. 'generated/yago_p_e_m.txt'
71 | ouf = assert(io.open(out_file, "w"))
72 |
73 | for mention, list in pairs(wiki_e_m_counts) do
74 | local str = ''
75 | local total_freq = 0
76 | for ent_wikiid, _ in pairs(list) do
77 | str = str .. ent_wikiid .. ',' .. get_ent_name_from_wikiid(ent_wikiid):gsub(' ', '_') .. '\t'
78 | total_freq = total_freq + 1
79 | end
80 | ouf:write(mention .. '\t' .. total_freq .. '\t' .. str .. '\n')
81 | end
82 | ouf:flush()
83 | io.close(ouf)
84 |
85 | print(' Done sorting and writing.')
86 |
--------------------------------------------------------------------------------
/data_gen/gen_p_e_m/merge_crosswikis_wiki.lua:
--------------------------------------------------------------------------------
1 | -- Merge Wikipedia and Crosswikis p(e|m) indexes
2 | -- Run: th data_gen/gen_p_e_m/merge_crosswikis_wiki.lua -root_data_dir $DATA_PATH
3 |
4 | cmd = torch.CmdLine()
5 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.')
6 | cmd:text()
7 | opt = cmd:parse(arg or {})
8 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.')
9 |
10 |
11 | require 'torch'
12 | dofile 'utils/utils.lua'
13 | dofile 'entities/ent_name2id_freq/ent_name_id.lua'
14 |
15 | print('\nMerging Wikipedia and Crosswikis p_e_m')
16 |
17 | tds = tds or require 'tds'
18 | merged_e_m_counts = tds.Hash()
19 |
20 | print('Process Wikipedia')
21 | it, _ = io.open(opt.root_data_dir .. 'generated/wikipedia_p_e_m.txt')
22 | line = it:read()
23 |
24 | while (line) do
25 | local parts = split(line, "\t")
26 | local mention = parts[1]
27 |
28 | if (not mention:find('Wikipedia')) and (not mention:find('wikipedia')) then
29 | if not merged_e_m_counts[mention] then
30 | merged_e_m_counts[mention] = tds.Hash()
31 | end
32 |
33 | local total_freq = tonumber(parts[2])
34 | assert(total_freq)
35 | local num_ents = table_len(parts)
36 | for i = 3, num_ents do
37 | local ent_str = split(parts[i], ",")
38 | local ent_wikiid = tonumber(ent_str[1])
39 | assert(ent_wikiid)
40 | local freq = tonumber(ent_str[2])
41 | assert(freq)
42 |
43 | if not merged_e_m_counts[mention][ent_wikiid] then
44 | merged_e_m_counts[mention][ent_wikiid] = 0
45 | end
46 | merged_e_m_counts[mention][ent_wikiid] = merged_e_m_counts[mention][ent_wikiid] + freq
47 | end
48 | end
49 | line = it:read()
50 | end
51 |
52 |
53 | print('Process Crosswikis')
54 |
55 | it, _ = io.open(opt.root_data_dir .. 'basic_data/p_e_m_data/crosswikis_p_e_m.txt')
56 | line = it:read()
57 |
58 | while (line) do
59 | local parts = split(line, "\t")
60 | local mention = parts[1]
61 |
62 | if (not mention:find('Wikipedia')) and (not mention:find('wikipedia')) then
63 | if not merged_e_m_counts[mention] then
64 | merged_e_m_counts[mention] = tds.Hash()
65 | end
66 |
67 | local total_freq = tonumber(parts[2])
68 | assert(total_freq)
69 | local num_ents = table_len(parts)
70 | for i = 3, num_ents do
71 | local ent_str = split(parts[i], ",")
72 | local ent_wikiid = tonumber(ent_str[1])
73 | assert(ent_wikiid)
74 | local freq = tonumber(ent_str[2])
75 | assert(freq)
76 |
77 | if not merged_e_m_counts[mention][ent_wikiid] then
78 | merged_e_m_counts[mention][ent_wikiid] = 0
79 | end
80 | merged_e_m_counts[mention][ent_wikiid] = merged_e_m_counts[mention][ent_wikiid] + freq
81 | end
82 | end
83 | line = it:read()
84 | end
85 |
86 |
87 | print('Now sorting and writing ..')
88 | out_file = opt.root_data_dir .. 'generated/crosswikis_wikipedia_p_e_m.txt'
89 | ouf = assert(io.open(out_file, "w"))
90 |
91 | for mention, list in pairs(merged_e_m_counts) do
92 | if mention:len() >= 1 then
93 | local tbl = {}
94 | for ent_wikiid, freq in pairs(list) do
95 | table.insert(tbl, {ent_wikiid = ent_wikiid, freq = freq})
96 | end
97 | table.sort(tbl, function(a,b) return a.freq > b.freq end)
98 |
99 | local str = ''
100 | local total_freq = 0
101 | local num_ents = 0
102 | for _,el in pairs(tbl) do
103 | if is_valid_ent(el.ent_wikiid) then
104 | str = str .. el.ent_wikiid .. ',' .. el.freq
105 | str = str .. ',' .. get_ent_name_from_wikiid(el.ent_wikiid):gsub(' ', '_') .. '\t'
106 | num_ents = num_ents + 1
107 | total_freq = total_freq + el.freq
108 |
109 | if num_ents >= 100 then -- At most 100 candidates
110 | break
111 | end
112 | end
113 | end
114 | ouf:write(mention .. '\t' .. total_freq .. '\t' .. str .. '\n')
115 | end
116 | end
117 | ouf:flush()
118 | io.close(ouf)
119 |
120 | print(' Done sorting and writing.')
121 |
122 |
--------------------------------------------------------------------------------
/data_gen/gen_p_e_m/unicode_map.lua:
--------------------------------------------------------------------------------
1 | local unicode = {'\\u00bb', '\\u007d', '\\u00a1', '\\u0259', '\\u0641', '\\u0398', '\\u00fd', '\\u0940', '\\u00f9', '\\u02bc', '\\u00f6', '\\u00f8', '\\u0107', '\\u0648', '\\u0105', '\\u002c', '\\u6768', '\\u0160', '\\u015b', '\\u00c0', '\\u266f', '\\u0430', '\\u0141', '\\u1ea3', '\\u00df', '\\u2212', '\\u8650', '\\u012d', '\\u1e47', '\\u00c5', '\\u00ab', '\\u0226', '\\u0930', '\\u04a4', '\\u030d', '\\u0631', '\\u207f', '\\u00bf', '\\u2010', '\\u1e6e', '\\u00cd', '\\u00c6', '\\u30e8', '\\u1e63', '\\u6a5f', '\\u03c3', '\\u00d5', '\\u0644', '\\u2020', '\\u0104', '\\u010a', '\\u013c', '\\u0123', '\\u0159', '\\u1e6f', '\\u003a', '\\u06af', '\\u00fc', '\\u4e09', '\\u0028', '\\u03b2', '\\u2103', '\\u0191', '\\u03bc', '\\u00d1', '\\u207a', '\\u79d2', '\\u6536', '\\u1ed3', '\\u0329', '\\u0196', '\\u00fb', '\\u0435', '\\u01b0', '\\u007e', '\\u1e62', '\\u0181', '\\u1ea7', '\\u2011', '\\u03c9', '\\u201d', '\\u0165', '\\u0422', '\\u1e33', '\\u0144', '\\u00fa', '\\u1ed5', '\\u0632', '\\u0643', '\\u1ea1', '\\u011e', '\\u062a', '\\u00ee', '\\u00c2', '\\u016d', '\\u003d', '\\u2202', '\\u2605', '\\u0112', '\\u73cd', '\\u03a1', '\\u0182', '\\u00d2', '\\u0153', '\\u016f', '\\u00de', '\\u00a3', '\\u1e45', '\\u1ef1', '\\u4e45', '\\u06cc', '\\u1ea2', '\\u0152', '\\uff09', '\\u0219', '\\u0457', '\\u0283', '\\u1e5a', '\\u064e', '\\u0164', '\\u0116', '\\u018b', '\\u1ec3', '\\u00b4', '\\u002b', '\\u7248', '\\u0937', '\\u203a', '\\u4eba', '\\u002f', '\\u0136', '\\u01bc', '\\u017b', '\\u00a9', '\\u03a9', '\\u00f2', '\\u2026', '\\u00c7', '\\u0969', '\\u0198', '\\u011f', '\\u00e0', '\\u0126', '\\u018a', '\\u1edf', '\\u005e', '\\u03b4', '\\u0137', '\\u01f5', '\\u1e34', '\\u007b', '\\u00f3', '\\u01c0', '\\u00f1', '\\u1ef9', '\\u03c8', '\\ub8e8', '\\u0119', '\\u014f', '\\uff5e', '\\u016c', '\\u0358', '\\u529f', '\\u2606', '\\u00e4', '\\u012b', '\\u00c9', '\\u0173', '\\u092f', '\\u5957', '\\u1e49', '\\u1ec1', '\\u019f', '\\u02bb', '\\u0399', '\\u01c1', '\\u03d5', '\\u017a', '\\u1e0d', '\\u0148', '\\u01a4', '\\u00a2', '\\u011c', '\\u1e92', '\\u01b1', '\\u0443', '\\u00f4', '\\u2122', '\\u82e5', '\\u0967', '\\u1eaf', '\\u013e', '\\u1e46', '\\u03b1', '\\u884c', '\\u0328', '\\u0021', '\\u00aa', '\\u014d', '\\u002e', '\\u00cb', '\\u062f', '\\u0102', '\\u0155', '\\u00cf', '\\u0446', '\\u1e80', '\\u003b', '\\u6c38', '\\u0103', '\\u1e6c', '\\u203c', '\\u00dc', '\\u00b3', '\\u0145', '\\u0122', '\\u8fdb', '\\u015e', '\\u017c', '\\u043f', '\\u0442', '\\u0629', '\\u6176', '\\u1edd', '\\u2018', '\\u00ea', '\\u0060', '\\u0147', '\\u00e7', '\\u00dd', '\\u00d6', '\\u043a', '\\u00ec', '\\ufb01', '\\u0124', '\\u1e5f', '\\u0431', '\\u1e94', '\\u1ea8', '\\u01b3', '\\u018f', '\\u0627', '\\u1ef3', '\\u091f', '\\u03a6', '\\u674e', '\\u016b', '\\u039b', '\\u2032', '\\u002a', '\\u2033', '\\u00b7', '\\u00ce', '\\u2075', '\\u043c', '\\u2116', '\\u1e6d', '\\u00be', '\\u0171', '\\u0433', '\\u0635', '\\u0640', '\\u01a1', '\\u00a7', '\\u019d', '\\u301c', '\\u00f5', '\\u0187', '\\u1ef6', '\\u1e0c', '\\u1ec5', '\\u01b8', '\\u00f0', '\\u1e93', '\\u00eb', '\\u03bd', '\\u0108', '\\u010d', '\\u5229', '\\u2080', '\\u00c8', '\\u9ece', '\\u0917', '\\u85cf', '\\u1edb', '\\u1e25', '\\u045b', '\\u1ec9', '\\u1ecd', '\\u0633', '\\u2022', '\\u01e8', '\\u0197', '\\u2019', '\\u0421', '\\u1ecb', '\\u9910', '\\uc2a4', '\\u1e31', '\\u00c1', '\\u1ea9', '\\u1eeb', '\\u01f4', '\\u01c2', '\\u0146', '\\u0162', '\\ufb02', '\\u01ac', '\\u0025', '\\u015c', '\\u01ce', '\\u01c3', '\\u0179', '\\u01e5', '\\u58eb', '\\u745e', '\\uff08', '\\u016a', '\\u03a5', '\\u039c', '\\u00bd', 
'\\u0169', '\\u55f7', '\\u89d2', '\\u00e6', '\\u9752', '\\u005b', '\\u003f', '\\u041f', '\\u06a9', '\\u0027', '\\u1ebf', '\\u0646', '\\u0130', '\\u1eb1', '\\u0395', '\\u03b3', '\\u01d4', '\\u00ed', '\\u041a', '\\u0149', '\\u0143', '\\u010f', '\\u1e24', '\\u0121', '\\u1ecf', '\\u06c1', '\\u00d3', '\\u0029', '\\u02bf', '\\u010e', '\\u1e0e', '\\u01d2', '\\u00ff', '\\u00fe', '\\u03a0', '\\u1ebc', '\\u2153', '\\u00e1', '\\u00ca', '\\u012c', '\\u017d', '\\u0110', '\\u266d', '\\u0639', '\\u014e', '\\u1e43', '\\u0026', '\\u20ac', '\\u0024', '\\u011d', '\\u003e', '\\u0163', '\\u0939', '\\u221a', '\\u00e3', '\\u65f6', '\\u0118', '\\u0101', '\\u0628', '\\u221e', '\\u1ed1', '\\u0393', '\\u00c4', '\\u0161', '\\u00e9', '\\u0220', '\\u0115', '\\u002d', '\\u03c0', '\\u0177', '\\u00ba', '\\u0158', '\\u01a7', '\\u0117', '\\u043b', '\\u00d7', '\\u00e2', '\\u0175', '\\u0420', '\\u0391', '\\u00b2', '\\u014c', '\\u2013', '\\u00b9', '\\u1ef7', '\\u064a', '\\u0301', '\\u95a2', '\\u0113', '\\u013b', '\\u094d', '\\u03b5', '\\u1eef', '\\u2c6b', '\\u00da', '\\u00d8', '\\u0432', '\\u0109', '\\u00d9', '\\u00d4', '\\u011b', '\\u0303', '\\u0392', '\\u1eed', '\\u0444', '\\u026a', '\\u0218', '\\u00ef', '\\u1ed9', '\\u00b0', '\\u010c', '\\uac00', '\\u02be', '\\u2012', '\\u5baa', '\\u00e8', '\\u1ebd', '\\u30fb', '\\u0127', '\\u010b', '\\u0131', '\\u1ebb', '\\u0150', '\\u0327', '\\u0100', '\\u1ee7', '\\u1ed7', '\\u0129', '\\u00c3', '\\u003c', '\\u2260', '\\u0106', '\\u6625', '\\u0184', '\\u1eb5', '\\u4fdd', '\\u00b1', '\\u021b', '\\u014b', '\\uff0d', '\\u1e2a', '\\u00e5', '\\u017e', '\\u011a', '\\u1eab', '\\u200e', '\\u1e35', '\\u1e5b', '\\u2192', '\\u0040', '\\u1eb7', '\\u01b2', '\\u5b58', '\\u201c', '\\u015f', '\\u01e6', '\\u0111', '\\u738b', '\\u03a7', '\\u1ead', '\\u1ec7', '\\u0324', '\\u2665', '\\ub9c8', '\\u6bba', '\\u0151', '\\u2661', '\\u03ba', '\\ua784', '\\u2014', '\\u1ee9', '\\u0120', '\\u012a', '\\u7433', '\\u0134', '\\u039a', '\\u1ee3', '\\u1ea5', '\\u1ee5', '\\u0142', '\\u043e', '\\u01eb', '\\u0440', '\\u03a3', '\\u093e', '\\u00d0', '\\u092e', '\\u00b5', '\\u013d', '\\u1ecc', '\\u0394', '\\u00bc', '\\u01d0', '\\u015a', '\\u02b9', '\\u0645', '\\u043d', '\\u00cc'}
2 |
3 | local ascii = {'»', '}', '¡', 'ə', 'ف', 'Θ', 'ý', 'ी', 'ù', 'ʼ', 'ö', 'ø', 'ć', 'و', 'ą', ',', '杨', 'Š', 'ś', 'À', '♯', 'а', 'Ł', 'ả', 'ß', '−', '虐', 'ĭ', 'ṇ', 'Å', '«', 'Ȧ', 'र', 'Ҥ', 'ʼ', 'ر', 'ⁿ', '¿', '‐', 'Ṯ', 'Í', 'Æ', 'ヨ', 'ṣ', '機', 'σ', 'Õ', 'ل', '†', 'Ą', 'Ċ', 'ļ', 'ģ', 'ř', 'ṯ', ':', 'گ', 'ü', '三', '(', 'β', '℃', 'Ƒ', 'μ', 'Ñ', '⁺', '秒', '收', 'ồ', '̩', 'Ɩ', 'û', 'е', 'ư', '~', 'Ṣ', 'Ɓ', 'ầ', '‑', 'ω', '”', 'ť', 'Т', 'ḳ', 'ń', 'ú', 'ổ', 'ز', 'ك', 'ạ', 'Ğ', 'ت', 'î', 'Â', 'ŭ', '=', '∂', '★', 'Ē', '珍', 'Ρ', 'Ƃ', 'Ò', 'œ', 'ů', 'Þ', '£', 'ṅ', 'ự', '久', 'ی', 'Ả', 'Œ', ')', 'ș', 'ї', 'ʃ', 'Ṛ', 'َ', 'Ť', 'Ė', 'Ƌ', 'ể', '´', '+', '版', 'ष', '›', '人', '/', 'Ķ', 'Ƽ', 'Ż', '©', 'Ω', 'ò', '…', 'Ç', '३', 'Ƙ', 'ğ', 'à', 'Ħ', 'Ɗ', 'ở', '^', 'δ', 'ķ', 'ǵ', 'Ḵ', '{', 'ó', 'ǀ', 'ñ', 'ỹ', 'ψ', '루', 'ę', 'ŏ', '~', 'Ŭ', '͘', '功', '☆', 'ä', 'ī', 'É', 'ų', 'य', '套', 'ṉ', 'ề', 'Ɵ', 'ʻ', 'Ι', 'ǁ', 'ϕ', 'ź', 'ḍ', 'ň', 'Ƥ', '¢', 'Ĝ', 'Ẓ', 'Ʊ', 'у', 'ô', '™', '若', '१', 'ắ', 'ľ', 'Ṇ', 'α', '行', '̨', '!', 'ª', 'ō', '.', 'Ë', 'د', 'Ă', 'ŕ', 'Ï', 'ц', 'Ẁ', ';', '永', 'ă', 'Ṭ', '‼', 'Ü', '³', 'Ņ', 'Ģ', '进', 'Ş', 'ż', 'п', 'т', 'ة', '慶', 'ờ', '‘', 'ê', '`', 'Ň', 'ç', 'Ý', 'Ö', 'к', 'ì', 'fi', 'Ĥ', 'ṟ', 'б', 'Ẕ', 'Ẩ', 'Ƴ', 'Ə', 'ا', 'ỳ', 'ट', 'Φ', '李', 'ū', 'Λ', '′', '*', '″', '·', 'Î', '⁵', 'м', '№', 'ṭ', '¾', 'ű', 'г', 'ص', 'ـ', 'ơ', '§', 'Ɲ', '〜', 'õ', 'Ƈ', 'Ỷ', 'Ḍ', 'ễ', 'Ƹ', 'ð', 'ẓ', 'ë', 'ν', 'Ĉ', 'č', '利', '₀', 'È', '黎', 'ग', '藏', 'ớ', 'ḥ', 'ћ', 'ỉ', 'ọ', 'س', '•', 'Ǩ', 'Ɨ', '’', 'С', 'ị', '餐', '스', 'ḱ', 'Á', 'ẩ', 'ừ', 'Ǵ', 'ǂ', 'ņ', 'Ţ', 'fl', 'Ƭ', '%', 'Ŝ', 'ǎ', 'ǃ', 'Ź', 'ǥ', '士', '瑞', '(', 'Ū', 'Υ', 'Μ', '½', 'ũ', '嗷', '角', 'æ', '青', '[', '?', 'П', 'ک', '\'', 'ế', 'ن', 'İ', 'ằ', 'Ε', 'γ', 'ǔ', 'í', 'К', 'ʼn', 'Ń', 'ď', 'Ḥ', 'ġ', 'ỏ', 'ہ', 'Ó', ')', 'ʿ', 'Ď', 'Ḏ', 'ǒ', 'ÿ', 'þ', 'Π', 'Ẽ', '⅓', 'á', 'Ê', 'Ĭ', 'Ž', 'Đ', '♭', 'ع', 'Ŏ', 'ṃ', '&', '€', '$', 'ĝ', '>', 'ţ', 'ह', '√', 'ã', '时', 'Ę', 'ā', 'ب', '∞', 'ố', 'Γ', 'Ä', 'š', 'é', 'Ƞ', 'ĕ', '-', 'π', 'ŷ', 'º', 'Ř', 'Ƨ', 'ė', 'л', '×', 'â', 'ŵ', 'Р', 'Α', '²', 'Ō', '–', '¹', 'ỷ', 'ي', '́', '関', 'ē', 'Ļ', '्', 'ε', 'ữ', 'Ⱬ', 'Ú', 'Ø', 'в', 'ĉ', 'Ù', 'Ô', 'ě', '̃', 'Β', 'ử', 'ф', 'ɪ', 'Ș', 'ï', 'ộ', '°', 'Č', '가', 'ʾ', '‒', '宪', 'è', 'ẽ', '・', 'ħ', 'ċ', 'ı', 'ẻ', 'Ő', '̧', 'Ā', 'ủ', 'ỗ', 'ĩ', 'Ã', '<', '≠', 'Ć', '春', 'Ƅ', 'ẵ', '保', '±', 'ț', 'ŋ', '-', 'Ḫ', 'å', 'ž', 'Ě', 'ẫ', '', 'ḵ', 'ṛ', '→', '@', 'ặ', 'Ʋ', '存', '“', 'ş', 'Ǧ', 'đ', '王', 'Χ', 'ậ', 'ệ', '̤', '♥', '마', '殺', 'ő', '♡', 'κ', 'Ꞅ', '—', 'ứ', 'Ġ', 'Ī', '琳', 'Ĵ', 'Κ', 'ợ', 'ấ', 'ụ', 'ł', 'о', 'ǫ', 'р', 'Σ', 'ा', 'Ð', 'म', 'µ', 'Ľ', 'Ọ', 'Δ', '¼', 'ǐ', 'Ś', 'ʹ', 'م', 'н', 'Ì'}
4 |
5 | dofile 'utils/utils.lua'
6 | assert(table_len(unicode) == table_len(ascii))
7 |
8 | unicode2ascii = {}
9 | for i,letter in pairs(ascii) do
10 | unicode2ascii[unicode[i]] = ascii[i]
11 | unicode2ascii['\\u0022'] = '"'
12 | unicode2ascii['\\u0023'] = '#'
13 | unicode2ascii['\\u005c'] = '\\'
14 | unicode2ascii['\\u00a0'] = ''
15 | end
16 |
17 |
--------------------------------------------------------------------------------
/data_gen/gen_test_train_data/gen_ace_msnbc_aquaint_csv.lua:
--------------------------------------------------------------------------------
1 | -- Generate test data from the ACE/MSNBC/AQUAINT datasets by keeping the context and
2 | -- entity candidates for each annotated mention
3 |
4 | -- Format:
5 | -- doc_name \t doc_name \t mention \t left_ctxt \t right_ctxt \t CANDIDATES \t [ent_wikiid,p_e_m,ent_name]+ \t GT: \t pos,ent_wikiid,p_e_m,ent_name
6 |
7 | -- Stats:
8 | --cat wned-ace2004.csv | wc -l
9 | --257
10 | --cat wned-ace2004.csv | grep -P 'GT:\t-1' | wc -l
11 | --20
12 | --cat wned-ace2004.csv | grep -P 'GT:\t1,' | wc -l
13 | --217
14 |
15 | --cat wned-aquaint.csv | wc -l
16 | --727
17 | --cat wned-aquaint.csv | grep -P 'GT:\t-1' | wc -l
18 | --33
19 | --cat wned-aquaint.csv | grep -P 'GT:\t1,' | wc -l
20 | --604
21 |
22 | --cat wned-msnbc.csv | wc -l
23 | --656
24 | --cat wned-msnbc.csv | grep -P 'GT:\t-1' | wc -l
25 | --22
26 | --cat wned-msnbc.csv | grep -P 'GT:\t1,' | wc -l
27 | --496
28 |
29 | if not ent_p_e_m_index then
30 | require 'torch'
31 | dofile 'data_gen/indexes/wiki_redirects_index.lua'
32 | dofile 'data_gen/indexes/yago_crosswikis_wiki.lua'
33 | dofile 'utils/utils.lua'
34 | end
35 |
36 | tds = tds or require 'tds'
37 |
38 | local function gen_test_ace(dataset)
39 |
40 | print('\nGenerating test data from ' .. dataset .. ' set ')
41 |
42 | local path = opt.root_data_dir .. 'basic_data/test_datasets/wned-datasets/' .. dataset .. '/'
43 |
44 | out_file = opt.root_data_dir .. 'generated/test_train_data/wned-' .. dataset .. '.csv'
45 | ouf = assert(io.open(out_file, "w"))
46 |
47 | annotations, _ = io.open(path .. dataset .. '.xml')
48 |
49 | local num_nonexistent_ent_id = 0
50 | local num_correct_ents = 0
51 |
52 | local cur_doc_text = ''
53 | local cur_doc_name = ''
54 |
55 | local line = annotations:read()
56 | while (line) do
57 | if (not line:find('document docName=\"')) then
58 | if line:find('<annotation>') then
59 | line = annotations:read()
60 | local x,y = line:find('<mention>')
61 | local z,t = line:find('</mention>')
62 | local cur_mention = line:sub(y + 1, z - 1)
63 | cur_mention = string.gsub(cur_mention, '&amp;', '&')
64 |
65 | line = annotations:read()
66 | x,y = line:find('<wikiName>')
67 | z,t = line:find('</wikiName>')
68 | cur_ent_title = ''
69 | if not line:find('<wikiName/>') then
70 | cur_ent_title = line:sub(y + 1, z - 1)
71 | end
72 |
73 | line = annotations:read()
74 | x,y = line:find('<offset>')
75 | z,t = line:find('</offset>')
76 | local offset = 1 + tonumber(line:sub(y + 1, z - 1))
77 |
78 | line = annotations:read()
79 | x,y = line:find('<length>')
80 | z,t = line:find('</length>')
81 | local length = tonumber(line:sub(y + 1, z - 1))
82 | length = cur_mention:len()
83 |
84 | line = annotations:read()
85 | if line:find('') then
86 | line = annotations:read()
87 | end
88 |
89 | assert(line:find('</annotation>'))
90 |
91 | offset = math.max(1, offset - 10)
92 | while (cur_doc_text:sub(offset, offset + length - 1) ~= cur_mention) do
93 | -- print(cur_mention .. ' ---> ' .. cur_doc_text:sub(offset, offset + length - 1))
94 | offset = offset + 1
95 | end
96 |
97 | cur_mention = preprocess_mention(cur_mention)
98 |
99 | if cur_ent_title ~= 'NIL' and cur_ent_title ~= '' and cur_ent_title:len() > 0 then
100 | local cur_ent_wikiid = get_ent_wikiid_from_name(cur_ent_title)
101 | if cur_ent_wikiid == unk_ent_wikiid then
102 | num_nonexistent_ent_id = num_nonexistent_ent_id + 1
103 | print(green(cur_ent_title))
104 | else
105 | num_correct_ents = num_correct_ents + 1
106 | end
107 |
108 | assert(cur_mention:len() > 0)
109 | local str = cur_doc_name .. '\t' .. cur_doc_name .. '\t' .. cur_mention .. '\t'
110 |
111 | local left_words = split_in_words(cur_doc_text:sub(1, offset - 1))
112 | local num_left_words = table_len(left_words)
113 | local left_ctxt = {}
114 | for i = math.max(1, num_left_words - 100 + 1), num_left_words do
115 | table.insert(left_ctxt, left_words[i])
116 | end
117 | if table_len(left_ctxt) == 0 then
118 | table.insert(left_ctxt, 'EMPTYCTXT')
119 | end
120 | str = str .. table.concat(left_ctxt, ' ') .. '\t'
121 |
122 | local right_words = split_in_words(cur_doc_text:sub(offset + length))
123 | local num_right_words = table_len(right_words)
124 | local right_ctxt = {}
125 | for i = 1, math.min(num_right_words, 100) do
126 | table.insert(right_ctxt, right_words[i])
127 | end
128 | if table_len(right_ctxt) == 0 then
129 | table.insert(right_ctxt, 'EMPTYCTXT')
130 | end
131 | str = str .. table.concat(right_ctxt, ' ') .. '\tCANDIDATES\t'
132 |
133 |
134 | -- Entity candidates from p(e|m) dictionary
135 | if ent_p_e_m_index[cur_mention] and #(ent_p_e_m_index[cur_mention]) > 0 then
136 |
137 | local sorted_cand = {}
138 | for ent_wikiid,p in pairs(ent_p_e_m_index[cur_mention]) do
139 | table.insert(sorted_cand, {ent_wikiid = ent_wikiid, p = p})
140 | end
141 | table.sort(sorted_cand, function(a,b) return a.p > b.p end)
142 |
143 | local candidates = {}
144 | local gt_pos = -1
145 | for pos,e in pairs(sorted_cand) do
146 | if pos <= 100 then
147 | table.insert(candidates, e.ent_wikiid .. ',' .. string.format("%.3f", e.p) .. ',' .. get_ent_name_from_wikiid(e.ent_wikiid))
148 | if e.ent_wikiid == cur_ent_wikiid then
149 | gt_pos = pos
150 | end
151 | else
152 | break
153 | end
154 | end
155 | str = str .. table.concat(candidates, '\t') .. '\tGT:\t'
156 |
157 | if gt_pos > 0 then
158 | ouf:write(str .. gt_pos .. ',' .. candidates[gt_pos] .. '\n')
159 | else
160 | if cur_ent_wikiid ~= unk_ent_wikiid then
161 | ouf:write(str .. '-1,' .. cur_ent_wikiid .. ',' .. cur_ent_title .. '\n')
162 | else
163 | ouf:write(str .. '-1\n')
164 | end
165 | end
166 | else
167 | if cur_ent_wikiid ~= unk_ent_wikiid then
168 | ouf:write(str .. 'EMPTYCAND\tGT:\t-1,' .. cur_ent_wikiid .. ',' .. cur_ent_title .. '\n')
169 | else
170 | ouf:write(str .. 'EMPTYCAND\tGT:\t-1\n')
171 | end
172 | end
173 |
174 | end
175 | end
176 | else
177 | local x,y = line:find('document docName=\"')
178 | local z,t = line:find('\">')
179 | cur_doc_name = line:sub(y + 1, z - 1)
180 | cur_doc_name = string.gsub(cur_doc_name, '&amp;', '&')
181 |
182 | local it,_ = io.open(path .. 'RawText/' .. cur_doc_name)
183 | cur_doc_text = ''
184 | local cur_line = it:read()
185 | while cur_line do
186 | cur_doc_text = cur_doc_text .. cur_line .. ' '
187 | cur_line = it:read()
188 | end
189 | cur_doc_text = string.gsub(cur_doc_text, '&amp;', '&')
190 | end
191 | line = annotations:read()
192 | end
193 |
194 | ouf:flush()
195 | io.close(ouf)
196 |
197 | print('Done ' .. dataset .. '.')
198 | print('num_nonexistent_ent_id = ' .. num_nonexistent_ent_id .. '; num_correct_ents = ' .. num_correct_ents)
199 | end
200 |
201 |
202 | gen_test_ace('wikipedia')
203 | gen_test_ace('clueweb')
204 | gen_test_ace('ace2004')
205 | gen_test_ace('msnbc')
206 | gen_test_ace('aquaint')
207 |
--------------------------------------------------------------------------------
/data_gen/gen_test_train_data/gen_aida_test.lua:
--------------------------------------------------------------------------------
1 | -- Generate test data from the AIDA dataset by keeping the context and
2 | -- entity candidates for each annotated mention
3 |
4 | -- Format:
5 | -- doc_name \t doc_name \t mention \t left_ctxt \t right_ctxt \t CANDIDATES \t [ent_wikiid,p_e_m,ent_name]+ \t GT: \t pos,ent_wikiid,p_e_m,ent_name
6 |
7 | -- Stats:
8 | --cat aida_testA.csv | wc -l
9 | --4791
10 | --cat aida_testA.csv | grep -P 'GT:\t-1' | wc -l
11 | --43
12 | --cat aida_testA.csv | grep -P 'GT:\t1,' | wc -l
13 | --3401
14 |
15 | --cat aida_testB.csv | wc -l
16 | --4485
17 | --cat aida_testB.csv | grep -P 'GT:\t-1' | wc -l
18 | --19
19 | --cat aida_testB.csv | grep -P 'GT:\t1,' | wc -l
20 | --3084
21 |
22 | if not ent_p_e_m_index then
23 | require 'torch'
24 | dofile 'data_gen/indexes/wiki_redirects_index.lua'
25 | dofile 'data_gen/indexes/yago_crosswikis_wiki.lua'
26 | dofile 'utils/utils.lua'
27 | end
28 |
29 | tds = tds or require 'tds'
30 |
31 | print('\nGenerating test data from AIDA set ')
32 |
33 | it, _ = io.open(opt.root_data_dir .. 'basic_data/test_datasets/AIDA/testa_testb_aggregate_original')
34 |
35 | out_file_A = opt.root_data_dir .. 'generated/test_train_data/aida_testA.csv'
36 | out_file_B = opt.root_data_dir .. 'generated/test_train_data/aida_testB.csv'
37 |
38 | ouf_A = assert(io.open(out_file_A, "w"))
39 | ouf_B = assert(io.open(out_file_B, "w"))
40 |
41 | local ouf = ouf_A
42 |
43 | local num_nme = 0
44 | local num_nonexistent_ent_title = 0
45 | local num_nonexistent_ent_id = 0
46 | local num_nonexistent_both = 0
47 | local num_correct_ents = 0
48 | local num_total_ents = 0
49 |
50 | local cur_words_num = 0
51 | local cur_words = {}
52 | local cur_mentions = {}
53 | local cur_mentions_num = 0
54 |
55 | local cur_doc_name = ''
56 |
57 | local function write_results()
58 | -- Write results:
59 | if cur_doc_name ~= '' then
60 | local header = cur_doc_name .. '\t' .. cur_doc_name .. '\t'
61 | for _, hyp in pairs(cur_mentions) do
62 | assert(hyp.mention:len() > 0, line)
63 | local mention = hyp.mention
64 | local str = header .. hyp.mention .. '\t'
65 |
66 | local left_ctxt = {}
67 | for i = math.max(0, hyp.start_off - 100), hyp.start_off - 1 do
68 | table.insert(left_ctxt, cur_words[i])
69 | end
70 | if table_len(left_ctxt) == 0 then
71 | table.insert(left_ctxt, 'EMPTYCTXT')
72 | end
73 | str = str .. table.concat(left_ctxt, ' ') .. '\t'
74 |
75 | local right_ctxt = {}
76 | for i = hyp.end_off + 1, math.min(cur_words_num, hyp.end_off + 100) do
77 | table.insert(right_ctxt, cur_words[i])
78 | end
79 | if table_len(right_ctxt) == 0 then
80 | table.insert(right_ctxt, 'EMPTYCTXT')
81 | end
82 | str = str .. table.concat(right_ctxt, ' ') .. '\tCANDIDATES\t'
83 |
84 | -- Entity candidates from p(e|m) dictionary
85 | if ent_p_e_m_index[mention] and #(ent_p_e_m_index[mention]) > 0 then
86 |
87 | local sorted_cand = {}
88 | for ent_wikiid,p in pairs(ent_p_e_m_index[mention]) do
89 | table.insert(sorted_cand, {ent_wikiid = ent_wikiid, p = p})
90 | end
91 | table.sort(sorted_cand, function(a,b) return a.p > b.p end)
92 |
93 | local candidates = {}
94 | local gt_pos = -1
95 | for pos,e in pairs(sorted_cand) do
96 | if pos <= 100 then
97 | table.insert(candidates, e.ent_wikiid .. ',' .. string.format("%.3f", e.p) .. ',' .. get_ent_name_from_wikiid(e.ent_wikiid))
98 | if e.ent_wikiid == hyp.ent_wikiid then
99 | gt_pos = pos
100 | end
101 | else
102 | break
103 | end
104 | end
105 | str = str .. table.concat(candidates, '\t') .. '\tGT:\t'
106 |
107 | if gt_pos > 0 then
108 | ouf:write(str .. gt_pos .. ',' .. candidates[gt_pos] .. '\n')
109 | else
110 | if hyp.ent_wikiid ~= unk_ent_wikiid then
111 | ouf:write(str .. '-1,' .. hyp.ent_wikiid .. ',' .. get_ent_name_from_wikiid(hyp.ent_wikiid) .. '\n')
112 | else
113 | ouf:write(str .. '-1\n')
114 | end
115 | end
116 | else
117 | if hyp.ent_wikiid ~= unk_ent_wikiid then
118 | ouf:write(str .. 'EMPTYCAND\tGT:\t-1,' .. hyp.ent_wikiid .. ',' .. get_ent_name_from_wikiid(hyp.ent_wikiid) .. '\n')
119 | else
120 | ouf:write(str .. 'EMPTYCAND\tGT:\t-1\n')
121 | end
122 | end
123 | end
124 | end
125 | end
126 |
127 |
128 |
129 | local line = it:read()
130 | while (line) do
131 | if (not line:find('-DOCSTART-')) then
132 | local parts = split(line, '\t')
133 | local num_parts = table_len(parts)
134 | assert(num_parts == 0 or num_parts == 1 or num_parts == 4 or num_parts == 7 or num_parts == 6, line)
135 | if num_parts > 0 then
136 | if num_parts == 4 and parts[2] == 'B' then
137 | num_nme = num_nme + 1
138 | end
139 |
140 | if (num_parts == 7 or num_parts == 6) and parts[2] == 'B' then
141 |
142 | -- Find current mention. A few hacks here.
143 | local cur_mention = preprocess_mention(parts[3])
144 |
145 | local x,y = parts[5]:find('/wiki/')
146 | local cur_ent_title = parts[5]:sub(y + 1)
147 | local cur_ent_wikiid = tonumber(parts[6])
148 | local index_ent_title = get_ent_name_from_wikiid(cur_ent_wikiid)
149 | local index_ent_wikiid = get_ent_wikiid_from_name(cur_ent_title)
150 |
151 | local final_ent_wikiid = index_ent_wikiid
152 | if final_ent_wikiid == unk_ent_wikiid then
153 | final_ent_wikiid = cur_ent_wikiid
154 | end
155 |
156 | if (index_ent_title == cur_ent_title and cur_ent_wikiid == index_ent_wikiid) then
157 | num_correct_ents = num_correct_ents + 1
158 | elseif (index_ent_title ~= cur_ent_title and cur_ent_wikiid ~= index_ent_wikiid) then
159 | num_nonexistent_both = num_nonexistent_both + 1
160 | elseif index_ent_title ~= cur_ent_title then
161 | assert(cur_ent_wikiid == index_ent_wikiid)
162 | num_nonexistent_ent_title = num_nonexistent_ent_title + 1
163 | else
164 | assert(index_ent_title == cur_ent_title)
165 | assert(cur_ent_wikiid ~= index_ent_wikiid)
166 | num_nonexistent_ent_id = num_nonexistent_ent_id + 1
167 | end
168 |
169 | num_total_ents = num_total_ents + 1 -- Keep even incorrect links
170 |
171 | cur_mentions_num = cur_mentions_num + 1
172 | cur_mentions[cur_mentions_num] = {}
173 | cur_mentions[cur_mentions_num].mention = cur_mention
174 | cur_mentions[cur_mentions_num].ent_wikiid = final_ent_wikiid
175 | cur_mentions[cur_mentions_num].start_off = cur_words_num + 1
176 | cur_mentions[cur_mentions_num].end_off = cur_words_num + table_len(split(parts[3], ' '))
177 | end
178 |
179 | local words_on_this_line = split_in_words(parts[1])
180 | for _,w in pairs(words_on_this_line) do
181 | table.insert(cur_words, modify_uppercase_phrase(w))
182 | cur_words_num = cur_words_num + 1
183 | end
184 | end
185 |
186 | else
187 | assert(line:find('-DOCSTART-'))
188 | write_results()
189 |
190 | if cur_doc_name:find('testa') and line:find('testb') then
191 | ouf = ouf_B
192 | print('Done validation testA : ')
193 | print('num_nme = ' .. num_nme .. '; num_nonexistent_ent_title = ' .. num_nonexistent_ent_title)
194 | print('num_nonexistent_ent_id = ' .. num_nonexistent_ent_id .. '; num_nonexistent_both = ' .. num_nonexistent_both)
195 | print('num_correct_ents = ' .. num_correct_ents .. '; num_total_ents = ' .. num_total_ents)
196 | end
197 |
198 | local words = split_in_words(line)
199 | for _,w in pairs(words) do
200 | if w:find('testa') or w:find('testb') then
201 | cur_doc_name = w
202 | break
203 | end
204 | end
205 | cur_words = {}
206 | cur_words_num = 0
207 | cur_mentions = {}
208 | cur_mentions_num = 0
209 | end
210 |
211 | line = it:read()
212 | end
213 |
214 | write_results()
215 |
216 | ouf_A:flush()
217 | io.close(ouf_A)
218 | ouf_B:flush()
219 | io.close(ouf_B)
220 |
221 |
222 | print(' Done AIDA.')
223 | print('num_nme = ' .. num_nme .. '; num_nonexistent_ent_title = ' .. num_nonexistent_ent_title)
224 | print('num_nonexistent_ent_id = ' .. num_nonexistent_ent_id .. '; num_nonexistent_both = ' .. num_nonexistent_both)
225 | print('num_correct_ents = ' .. num_correct_ents .. '; num_total_ents = ' .. num_total_ents)
226 |
--------------------------------------------------------------------------------
/data_gen/gen_test_train_data/gen_aida_train.lua:
--------------------------------------------------------------------------------
1 | -- Generate train data from the AIDA dataset by keeping the context and
2 | -- entity candidates for each annotated mention
3 |
4 | -- Format:
5 | -- doc_name \t doc_name \t mention \t left_ctxt \t right_ctxt \t CANDIDATES \t [ent_wikiid,p_e_m,ent_name]+ \t GT: \t pos,ent_wikiid,p_e_m,ent_name
6 |
7 | if not ent_p_e_m_index then
8 | require 'torch'
9 | dofile 'data_gen/indexes/wiki_redirects_index.lua'
10 | dofile 'data_gen/indexes/yago_crosswikis_wiki.lua'
11 | dofile 'utils/utils.lua'
12 | tds = tds or require 'tds'
13 | end
14 |
15 | print('\nGenerating train data from AIDA set ')
16 |
17 | it, _ = io.open(opt.root_data_dir .. 'basic_data/test_datasets/AIDA/aida_train.txt')
18 |
19 | out_file = opt.root_data_dir .. 'generated/test_train_data/aida_train.csv'
20 | ouf = assert(io.open(out_file, "w"))
21 |
22 | local num_nme = 0
23 | local num_nonexistent_ent_title = 0
24 | local num_nonexistent_ent_id = 0
25 | local num_nonexistent_both = 0
26 | local num_correct_ents = 0
27 | local num_total_ents = 0
28 |
29 | local cur_words_num = 0
30 | local cur_words = {}
31 | local cur_mentions = {}
32 | local cur_mentions_num = 0
33 |
34 | local cur_doc_name = ''
35 |
36 | local function write_results()
37 | -- Write results:
38 | if cur_doc_name ~= '' then
39 | local header = cur_doc_name .. '\t' .. cur_doc_name .. '\t'
40 | for _, hyp in pairs(cur_mentions) do
41 | assert(hyp.mention:len() > 0, line)
42 | local str = header .. hyp.mention .. '\t'
43 |
44 | local left_ctxt = {}
45 | for i = math.max(0, hyp.start_off - 100), hyp.start_off - 1 do
46 | table.insert(left_ctxt, cur_words[i])
47 | end
48 | if table_len(left_ctxt) == 0 then
49 | table.insert(left_ctxt, 'EMPTYCTXT')
50 | end
51 | str = str .. table.concat(left_ctxt, ' ') .. '\t'
52 |
53 | local right_ctxt = {}
54 | for i = hyp.end_off + 1, math.min(cur_words_num, hyp.end_off + 100) do
55 | table.insert(right_ctxt, cur_words[i])
56 | end
57 | if table_len(right_ctxt) == 0 then
58 | table.insert(right_ctxt, 'EMPTYCTXT')
59 | end
60 | str = str .. table.concat(right_ctxt, ' ') .. '\tCANDIDATES\t'
61 |
62 | -- Entity candidates from p(e|m) dictionary
63 | if ent_p_e_m_index[hyp.mention] and #(ent_p_e_m_index[hyp.mention]) > 0 then
64 |
65 | local sorted_cand = {}
66 | for ent_wikiid,p in pairs(ent_p_e_m_index[hyp.mention]) do
67 | table.insert(sorted_cand, {ent_wikiid = ent_wikiid, p = p})
68 | end
69 | table.sort(sorted_cand, function(a,b) return a.p > b.p end)
70 |
71 | local candidates = {}
72 | local gt_pos = -1
73 | for pos,e in pairs(sorted_cand) do
74 | if pos <= 100 then
75 | table.insert(candidates, e.ent_wikiid .. ',' .. string.format("%.3f", e.p) .. ',' .. get_ent_name_from_wikiid(e.ent_wikiid))
76 | if e.ent_wikiid == hyp.ent_wikiid then
77 | gt_pos = pos
78 | end
79 | else
80 | break
81 | end
82 | end
83 | str = str .. table.concat(candidates, '\t') .. '\tGT:\t'
84 |
85 | if gt_pos > 0 then
86 | ouf:write(str .. gt_pos .. ',' .. candidates[gt_pos] .. '\n')
87 | end
88 | end
89 | end
90 | end
91 | end
92 |
93 |
94 |
95 | local line = it:read()
96 | while (line) do
97 | if (not line:find('-DOCSTART-')) then
98 | local parts = split(line, '\t')
99 | local num_parts = table_len(parts)
100 | assert(num_parts == 0 or num_parts == 1 or num_parts == 4 or num_parts == 7 or num_parts == 6, line)
101 | if num_parts > 0 then
102 | if num_parts == 4 and parts[2] == 'B' then
103 | num_nme = num_nme + 1
104 | end
105 |
106 | if (num_parts == 7 or num_parts == 6) and parts[2] == 'B' then
107 |
108 | -- Find current mention. A few hacks here.
109 |
110 | local cur_mention = preprocess_mention(parts[3])
111 |
112 | local x,y = parts[5]:find('/wiki/')
113 | local cur_ent_title = parts[5]:sub(y + 1)
114 | local cur_ent_wikiid = tonumber(parts[6])
115 | local index_ent_title = get_ent_name_from_wikiid(cur_ent_wikiid)
116 | local index_ent_wikiid = get_ent_wikiid_from_name(cur_ent_title)
117 |
118 | local final_ent_wikiid = index_ent_wikiid
119 | if final_ent_wikiid == unk_ent_wikiid then
120 | final_ent_wikiid = cur_ent_wikiid
121 | end
122 |
123 | if (index_ent_title == cur_ent_title and cur_ent_wikiid == index_ent_wikiid) then
124 | num_correct_ents = num_correct_ents + 1
125 | elseif (index_ent_title ~= cur_ent_title and cur_ent_wikiid ~= index_ent_wikiid) then
126 | num_nonexistent_both = num_nonexistent_both + 1
127 | elseif index_ent_title ~= cur_ent_title then
128 | assert(cur_ent_wikiid == index_ent_wikiid)
129 | num_nonexistent_ent_title = num_nonexistent_ent_title + 1
130 | else
131 | assert(index_ent_title == cur_ent_title)
132 | assert(cur_ent_wikiid ~= index_ent_wikiid)
133 | num_nonexistent_ent_id = num_nonexistent_ent_id + 1
134 | end
135 |
136 | num_total_ents = num_total_ents + 1 -- Keep even incorrect links
137 |
138 | cur_mentions_num = cur_mentions_num + 1
139 | cur_mentions[cur_mentions_num] = {}
140 | cur_mentions[cur_mentions_num].mention = cur_mention
141 | cur_mentions[cur_mentions_num].ent_wikiid = final_ent_wikiid
142 | cur_mentions[cur_mentions_num].start_off = cur_words_num + 1
143 | cur_mentions[cur_mentions_num].end_off = cur_words_num + table_len(split(parts[3], ' '))
144 | end
145 |
146 | local words_on_this_line = split_in_words(parts[1])
147 | for _,w in pairs(words_on_this_line) do
148 | table.insert(cur_words, modify_uppercase_phrase(w))
149 | cur_words_num = cur_words_num + 1
150 | end
151 | end
152 |
153 | else
154 | assert(line:find('-DOCSTART-'))
155 | write_results()
156 |
157 | cur_doc_name = line:sub(13)
158 |
159 | cur_words = {}
160 | cur_words_num = 0
161 | cur_mentions = {}
162 | cur_mentions_num = 0
163 | end
164 |
165 | line = it:read()
166 | end
167 |
168 | write_results()
169 |
170 | ouf:flush()
171 | io.close(ouf)
172 |
173 |
174 | print(' Done AIDA.')
175 | print('num_nme = ' .. num_nme .. '; num_nonexistent_ent_title = ' .. num_nonexistent_ent_title)
176 | print('num_nonexistent_ent_id = ' .. num_nonexistent_ent_id .. '; num_nonexistent_both = ' .. num_nonexistent_both)
177 | print('num_correct_ents = ' .. num_correct_ents .. '; num_total_ents = ' .. num_total_ents)
178 |
--------------------------------------------------------------------------------
/data_gen/gen_test_train_data/gen_all.lua:
--------------------------------------------------------------------------------
1 | -- Generates all training and test data for entity disambiguation.
2 |
3 | if not ent_p_e_m_index then
4 | require 'torch'
5 | dofile 'data_gen/indexes/wiki_redirects_index.lua'
6 | dofile 'data_gen/indexes/yago_crosswikis_wiki.lua'
7 | dofile 'utils/utils.lua'
8 | tds = tds or require 'tds'
9 | end
10 |
11 | dofile 'data_gen/gen_test_train_data/gen_aida_test.lua'
12 | dofile 'data_gen/gen_test_train_data/gen_aida_train.lua'
13 | dofile 'data_gen/gen_test_train_data/gen_ace_msnbc_aquaint_csv.lua'
--------------------------------------------------------------------------------
/data_gen/gen_wiki_data/gen_ent_wiki_w_repr.lua:
--------------------------------------------------------------------------------
1 | if not opt then
2 | cmd = torch.CmdLine()
3 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.')
4 | cmd:text()
5 | opt = cmd:parse(arg or {})
6 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.')
7 | end
8 |
9 |
10 | require 'torch'
11 | dofile 'utils/utils.lua'
12 | dofile 'data_gen/parse_wiki_dump/parse_wiki_dump_tools.lua'
13 | dofile 'entities/ent_name2id_freq/e_freq_index.lua'
14 | tds = tds or require 'tds'
15 |
16 | print('\nExtracting text only from Wiki dump. Output is wiki_canonical_words.txt containing on each line a Wiki entity with the list of all words in its canonical Wiki page.')
17 |
18 | it, _ = io.open(opt.root_data_dir .. 'basic_data/textWithAnchorsFromAllWikipedia2014Feb.txt')
19 |
20 | out_file = opt.root_data_dir .. 'generated/wiki_canonical_words.txt'
21 | ouf = assert(io.open(out_file, "w"))
22 |
23 | line = it:read()
24 |
25 | -- Find anchors, e.g. anarchism
26 | local num_lines = 0
27 | local num_valid_ents = 0
28 | local num_error_ents = 0 -- Probably list or disambiguation pages.
29 |
30 | local empty_valid_ents = get_map_all_valid_ents()
31 |
32 | local cur_words = ''
33 | local cur_ent_wikiid = -1
34 |
35 | while (line) do
36 | num_lines = num_lines + 1
37 | if num_lines % 5000000 == 0 then
38 | print('Processed ' .. num_lines .. ' lines. Num valid ents = ' .. num_valid_ents .. '. Num errs = ' .. num_error_ents)
39 | end
40 |
41 | if (not line:find(' 0 and cur_words ~= '') then
48 | if cur_ent_wikiid ~= unk_ent_wikiid and is_valid_ent(cur_ent_wikiid) then
49 | ouf:write(cur_ent_wikiid .. '\t' .. get_ent_name_from_wikiid(cur_ent_wikiid) .. '\t' .. cur_words .. '\n')
50 | empty_valid_ents[cur_ent_wikiid] = nil
51 | num_valid_ents = num_valid_ents + 1
52 | else
53 | num_error_ents = num_error_ents + 1
54 | end
55 | end
56 |
57 | cur_ent_wikiid = extract_page_entity_title(line)
58 | cur_words = ''
59 | end
60 |
61 | line = it:read()
62 | end
63 | ouf:flush()
64 | io.close(ouf)
65 |
66 | -- Num valid ents = 4126137. Num errs = 332944
67 | print(' Done extracting text only from Wiki dump. Num valid ents = ' .. num_valid_ents .. '. Num errs = ' .. num_error_ents)
68 |
69 |
70 | print('Create file with all entities with empty Wikipedia pages.')
71 | local empty_ents = {}
72 | for ent_wikiid, _ in pairs(empty_valid_ents) do
73 | table.insert(empty_ents, {ent_wikiid = ent_wikiid, f = get_ent_freq(ent_wikiid)})
74 | end
75 | table.sort(empty_ents, function(a,b) return a.f > b.f end)
76 |
77 | local ouf2 = assert(io.open(opt.root_data_dir .. 'generated/empty_page_ents.txt', "w"))
78 | for _,x in pairs(empty_ents) do
79 | ouf2:write(x.ent_wikiid .. '\t' .. get_ent_name_from_wikiid(x.ent_wikiid) .. '\t' .. x.f .. '\n')
80 | end
81 | ouf2:flush()
82 | io.close(ouf2)
83 | print(' Done')
--------------------------------------------------------------------------------
/data_gen/gen_wiki_data/gen_wiki_hyp_train_data.lua:
--------------------------------------------------------------------------------
1 | -- Generate training data from Wikipedia hyperlinks by keeping the context and
2 | -- entity candidates for each hyperlink
3 |
4 | -- Format:
5 | -- ent_wikiid \t ent_name \t mention \t left_ctxt \t right_ctxt \t CANDIDATES \t [ent_wikiid,p_e_m,ent_name]+ \t GT: \t pos,ent_wikiid,p_e_m,ent_name
6 |
7 | if not opt then
8 | cmd = torch.CmdLine()
9 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.')
10 | cmd:text()
11 | opt = cmd:parse(arg or {})
12 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.')
13 | end
14 |
15 | require 'torch'
16 | dofile 'data_gen/parse_wiki_dump/parse_wiki_dump_tools.lua'
17 | dofile 'data_gen/indexes/yago_crosswikis_wiki.lua'
18 | tds = tds or require 'tds'
19 |
20 | print('\nGenerating training data from Wiki dump')
21 |
22 | it, _ = io.open(opt.root_data_dir .. 'basic_data/textWithAnchorsFromAllWikipedia2014Feb.txt')
23 |
24 | out_file = opt.root_data_dir .. 'generated/wiki_hyperlink_contexts.csv'
25 | ouf = assert(io.open(out_file, "w"))
26 |
27 | -- Find anchors, e.g. anarchism
28 | local num_lines = 0
29 | local num_valid_hyp = 0
30 |
31 | local cur_words_num = 0
32 | local cur_words = {}
33 | local cur_mentions = {}
34 | local cur_mentions_num = 0
35 | local cur_ent_wikiid = -1
36 |
37 | local line = it:read()
38 | while (line) do
39 | num_lines = num_lines + 1
40 | if num_lines % 1000000 == 0 then
41 | print('Processed ' .. num_lines .. ' lines. Num valid hyp = ' .. num_valid_hyp)
42 | end
43 |
44 | -- If it's a line from the Wiki page, add its text words and its hyperlinks
45 | if (not line:find(' 0 then
88 | assert(hyp.mention:len() > 0, line)
89 | local str = header .. hyp.mention .. '\t'
90 |
91 | local left_ctxt = {}
92 | for i = math.max(0, hyp.start_off - 100), hyp.start_off - 1 do
93 | table.insert(left_ctxt, cur_words[i])
94 | end
95 | if table_len(left_ctxt) == 0 then
96 | table.insert(left_ctxt, 'EMPTYCTXT')
97 | end
98 | str = str .. table.concat(left_ctxt, ' ') .. '\t'
99 |
100 | local right_ctxt = {}
101 | for i = hyp.end_off + 1, math.min(cur_words_num, hyp.end_off + 100) do
102 | table.insert(right_ctxt, cur_words[i])
103 | end
104 | if table_len(right_ctxt) == 0 then
105 | table.insert(right_ctxt, 'EMPTYCTXT')
106 | end
107 | str = str .. table.concat(right_ctxt, ' ') .. '\tCANDIDATES\t'
108 |
109 | -- Entity candidates from p(e|m) dictionary
110 | local unsorted_cand = {}
111 | for ent_wikiid,p in pairs(ent_p_e_m_index[hyp.mention]) do
112 | table.insert(unsorted_cand, {ent_wikiid = ent_wikiid, p = p})
113 | end
114 | table.sort(unsorted_cand, function(a,b) return a.p > b.p end)
115 |
116 | local candidates = {}
117 | local gt_pos = -1
118 | for pos,e in pairs(unsorted_cand) do
119 | if pos <= 32 then
120 | table.insert(candidates, e.ent_wikiid .. ',' .. string.format("%.3f", e.p) .. ',' .. get_ent_name_from_wikiid(e.ent_wikiid))
121 | if e.ent_wikiid == hyp.ent_wikiid then
122 | gt_pos = pos
123 | end
124 | else
125 | break
126 | end
127 | end
128 | str = str .. table.concat(candidates, '\t') .. '\tGT:\t'
129 |
130 | if gt_pos > 0 then
131 | num_valid_hyp = num_valid_hyp + 1
132 | ouf:write(str .. gt_pos .. ',' .. candidates[gt_pos] .. '\n')
133 | end
134 | end
135 | end
136 | end
137 |
138 | cur_ent_wikiid = extract_page_entity_title(line)
139 | cur_words = {}
140 | cur_words_num = 0
141 | cur_mentions = {}
142 | cur_mentions_num = 0
143 | end
144 |
145 | line = it:read()
146 | end
147 | ouf:flush()
148 | io.close(ouf)
149 |
150 | print(' Done generating training data from Wiki dump. Num valid hyp = ' .. num_valid_hyp)
151 |
--------------------------------------------------------------------------------
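The format comment at the top of this file documents each generated line. Below is a hypothetical example (all ids, names and p(e|m) values are made up) together with a small sketch of picking it apart with the helpers from utils/utils.lua:

```
-- Hypothetical hyperlink training line and a minimal parsing sketch.
require 'torch'
dofile 'utils/utils.lua'                 -- provides split() and table_len()
local line = '54321\tHypothetical Page\tanarchism\tleft context words\tright context words\t' ..
    'CANDIDATES\t12,0.900,Anarchism\t301,0.100,Anarchy\tGT:\t1,12,0.900,Anarchism'
local parts = split(line, '\t')
assert(parts[6] == 'CANDIDATES')
local gt = split(parts[table_len(parts)], ',')
print('gold position = ' .. gt[1] .. ' ; gold entity = ' .. gt[4])
```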
/data_gen/indexes/wiki_disambiguation_pages_index.lua:
--------------------------------------------------------------------------------
1 | -- Loads the link disambiguation index from Wikipedia
2 |
3 | if not opt then
4 | cmd = torch.CmdLine()
5 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.')
6 | cmd:text()
7 | opt = cmd:parse(arg or {})
8 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.')
9 | end
10 |
11 |
12 | require 'torch'
13 | dofile 'utils/utils.lua'
14 | tds = tds or require 'tds'
15 |
16 | print('==> Loading disambiguation index')
17 | it, _ = io.open(opt.root_data_dir .. 'basic_data/wiki_disambiguation_pages.txt')
18 | line = it:read()
19 |
20 | wiki_disambiguation_index = tds.Hash()
21 | while (line) do
22 | parts = split(line, "\t")
23 | assert(tonumber(parts[1]))
24 | wiki_disambiguation_index[tonumber(parts[1])] = 1
25 | line = it:read()
26 | end
27 |
28 | assert(wiki_disambiguation_index[579])
29 | assert(wiki_disambiguation_index[41535072])
30 |
31 | print(' Done loading disambiguation index')
32 |
--------------------------------------------------------------------------------
/data_gen/indexes/wiki_redirects_index.lua:
--------------------------------------------------------------------------------
1 | -- Loads the link redirect index from Wikipedia
2 |
3 | if not opt then
4 | cmd = torch.CmdLine()
5 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.')
6 | cmd:text()
7 | opt = cmd:parse(arg or {})
8 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.')
9 | end
10 |
11 | require 'torch'
12 | dofile 'utils/utils.lua'
13 | tds = tds or require 'tds'
14 |
15 | print('==> Loading redirects index')
16 | it, _ = io.open(opt.root_data_dir .. 'basic_data/wiki_redirects.txt')
17 | line = it:read()
18 |
19 | local wiki_redirects_index = tds.Hash()
20 | while (line) do
21 | parts = split(line, "\t")
22 | wiki_redirects_index[parts[1]] = parts[2]
23 | line = it:read()
24 | end
25 |
26 | assert(wiki_redirects_index['Coercive'] == 'Coercion')
27 | assert(wiki_redirects_index['Hosford, FL'] == 'Hosford, Florida')
28 |
29 | print(' Done loading redirects index')
30 |
31 |
32 | function get_redirected_ent_title(ent_name)
33 | if wiki_redirects_index[ent_name] then
34 | return wiki_redirects_index[ent_name]
35 | else
36 | return ent_name
37 | end
38 | end
39 |
--------------------------------------------------------------------------------
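A short usage sketch of the helper defined above; it assumes this index file has already been loaded with a valid -root_data_dir, as parse_wiki_dump_tools.lua does:

```
-- Resolve a redirect title to its canonical page name.
print(get_redirected_ent_title('Coercive'))       -- 'Coercion', per the assert above
print(get_redirected_ent_title('Unknown title'))  -- returned unchanged when no redirect exists
```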
/data_gen/indexes/yago_crosswikis_wiki.lua:
--------------------------------------------------------------------------------
1 | -- Loads the merged p(e|m) index.
2 | if not opt then
3 | cmd = torch.CmdLine()
4 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.')
5 | cmd:text()
6 | opt = cmd:parse(arg or {})
7 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.')
8 | end
9 |
10 |
11 | require 'torch'
12 | tds = tds or require 'tds'
13 |
14 | dofile 'utils/utils.lua'
15 | dofile 'entities/ent_name2id_freq/ent_name_id.lua'
16 |
17 | ent_p_e_m_index = tds.Hash()
18 |
19 | mention_lower_to_one_upper = tds.Hash()
20 |
21 | mention_total_freq = tds.Hash()
22 |
23 | local crosswikis_textfilename = opt.root_data_dir .. 'generated/crosswikis_wikipedia_p_e_m.txt'
24 | print('==> Loading crosswikis_wikipedia from file ' .. crosswikis_textfilename)
25 | local it, _ = io.open(crosswikis_textfilename)
26 | local line = it:read()
27 |
28 | local num_lines = 0
29 | while (line) do
30 | num_lines = num_lines + 1
31 | if num_lines % 2000000 == 0 then
32 | print('Processed ' .. num_lines .. ' lines. ')
33 | end
34 |
35 | local parts = split(line , '\t')
36 | local mention = parts[1]
37 |
38 | local total = tonumber(parts[2])
39 | assert(total)
40 | if total >= 1 then
41 | ent_p_e_m_index[mention] = tds.Hash()
42 | mention_lower_to_one_upper[mention:lower()] = mention
43 | mention_total_freq[mention] = total
44 | local num_parts = table_len(parts)
45 | for i = 3, num_parts do
46 | local ent_str = split(parts[i], ',')
47 | local ent_wikiid = tonumber(ent_str[1])
48 | local freq = tonumber(ent_str[2])
49 | assert(ent_wikiid)
50 | assert(freq)
51 | ent_p_e_m_index[mention][ent_wikiid] = freq / (total + 0.0) -- not sorted
52 | end
53 | end
54 | line = it:read()
55 | end
56 |
57 | local yago_textfilename = opt.root_data_dir .. 'generated/yago_p_e_m.txt'
58 | print('==> Loading yago index from file ' .. yago_textfilename)
59 | it, _ = io.open(yago_textfilename)
60 | line = it:read()
61 |
62 | num_lines = 0
63 | while (line) do
64 | num_lines = num_lines + 1
65 | if num_lines % 2000000 == 0 then
66 | print('Processed ' .. num_lines .. ' lines. ')
67 | end
68 |
69 | local parts = split(line , '\t')
70 | local mention = parts[1]
71 |
72 | local total = tonumber(parts[2])
73 | assert(total)
74 | if total >= 1 then
75 | mention_lower_to_one_upper[mention:lower()] = mention
76 | if not mention_total_freq[mention] then
77 | mention_total_freq[mention] = total
78 | else
79 | mention_total_freq[mention] = total + mention_total_freq[mention]
80 | end
81 |
82 | local yago_ment_ent_idx = tds.Hash()
83 | local num_parts = table_len(parts)
84 | for i = 3, num_parts do
85 | local ent_str = split(parts[i], ',')
86 | local ent_wikiid = tonumber(ent_str[1])
87 | local freq = 1
88 | assert(ent_wikiid)
89 | yago_ment_ent_idx[ent_wikiid] = freq / (total + 0.0) -- not sorted
90 | end
91 |
92 | if not ent_p_e_m_index[mention] then
93 | ent_p_e_m_index[mention] = yago_ment_ent_idx
94 | else
95 | for ent_wikiid,prob in pairs(yago_ment_ent_idx) do
96 | if not ent_p_e_m_index[mention][ent_wikiid] then
97 | ent_p_e_m_index[mention][ent_wikiid] = 0.0
98 | end
99 | ent_p_e_m_index[mention][ent_wikiid] = math.min(1.0, ent_p_e_m_index[mention][ent_wikiid] + prob)
100 | end
101 |
102 | end
103 |
104 | end
105 | line = it:read()
106 | end
107 |
108 | assert(ent_p_e_m_index['Dejan Koturovic'] and ent_p_e_m_index['Jose Luis Caminero'])
109 |
110 | -- Function used to preprocess a given mention so that it has a higher
111 | -- chance of having at least one valid entry in the p(e|m) index.
112 | function preprocess_mention(m)
113 | assert(ent_p_e_m_index and mention_total_freq)
114 | local cur_m = modify_uppercase_phrase(m)
115 | if (not ent_p_e_m_index[cur_m]) then
116 | cur_m = m
117 | end
118 | if (mention_total_freq[m] and (mention_total_freq[m] > mention_total_freq[cur_m])) then
119 | cur_m = m -- Cases like 'U.S.' are handled badly by modify_uppercase_phrase
120 | end
121 | -- If we still cannot find the exact mention in our index, we fall back to a case-insensitive lookup.
122 | if (not ent_p_e_m_index[cur_m]) and mention_lower_to_one_upper[cur_m:lower()] then
123 | cur_m = mention_lower_to_one_upper[cur_m:lower()]
124 | end
125 | return cur_m
126 | end
127 |
128 |
129 | print(' Done loading index')
130 |
--------------------------------------------------------------------------------
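A short usage sketch of the merged index and of preprocess_mention; it assumes this file has been loaded with a valid -root_data_dir so that the generated p(e|m) files exist:

```
-- Look up entity candidates and their p(e|m) priors for a surface mention.
local m = preprocess_mention('U.S.')       -- normalize the surface form first
if ent_p_e_m_index[m] then
  for ent_wikiid, p in pairs(ent_p_e_m_index[m]) do
    print(ent_wikiid, p)                   -- candidate Wikipedia id and its prior
  end
end
```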
/data_gen/parse_wiki_dump/parse_wiki_dump_tools.lua:
--------------------------------------------------------------------------------
1 | -- Utility functions to extract the text and hyperlinks from each page in the Wikipedia corpus.
2 |
3 | if not table_len then
4 | dofile 'utils/utils.lua'
5 | end
6 | if not get_redirected_ent_title then
7 | dofile 'data_gen/indexes/wiki_redirects_index.lua'
8 | end
9 | if not get_ent_name_from_wikiid then
10 | dofile 'entities/ent_name2id_freq/ent_name_id.lua'
11 | end
12 |
13 |
14 | function extract_text_and_hyp(line, mark_mentions)
15 | local list_hyp = {} -- (mention, entity) pairs
16 | local text = ''
17 | local list_ent_errors = 0
18 | local parsing_errors = 0
19 | local disambiguation_ent_errors = 0
20 | local diez_ent_errors = 0
21 |
22 | local end_end_hyp = 0
23 | local begin_end_hyp = 0
24 | local begin_start_hyp, end_start_hyp = line:find('', end_start_hyp + 1)
32 | if next_quotes then
33 | local ent_name = line:sub(end_start_hyp + 1, next_quotes - 1)
34 | begin_end_hyp, end_end_hyp = line:find('', end_quotes + 1)
35 | if begin_end_hyp then
36 | local mention = line:sub(end_quotes + 1, begin_end_hyp - 1)
37 | local mention_marker = false
38 |
39 | local good_mention = true
40 | good_mention = good_mention and (not mention:find('Wikipedia'))
41 | good_mention = good_mention and (not mention:find('wikipedia'))
42 | good_mention = good_mention and (mention:len() >= 1)
43 |
44 | if good_mention then
45 | local i = ent_name:find('wikt:')
46 | if i == 1 then
47 | ent_name = ent_name:sub(6)
48 | end
49 | ent_name = preprocess_ent_name(ent_name)
50 |
51 | i = ent_name:find('List of ')
52 | if (not i) or (i ~= 1) then
53 | if ent_name:find('#') then
54 | diez_ent_errors = diez_ent_errors + 1
55 | else
56 | local ent_wikiid = get_ent_wikiid_from_name(ent_name, true)
57 | if ent_wikiid == unk_ent_wikiid then
58 | disambiguation_ent_errors = disambiguation_ent_errors + 1
59 | else
60 | -- A valid (entity,mention) pair
61 | num_mentions = num_mentions + 1
62 | table.insert(list_hyp, {mention = mention, ent_wikiid = ent_wikiid, cnt = num_mentions})
63 | if mark_mentions then
64 | mention_marker = true
65 | end
66 | end
67 | end
68 | else
69 | list_ent_errors = list_ent_errors + 1
70 | end
71 | end
72 |
73 | if (not mention_marker) then
74 | text = text .. ' ' .. mention .. ' '
75 | else
76 | text = text .. ' MMSTART' .. num_mentions .. ' ' .. mention .. ' MMEND' .. num_mentions .. ' '
77 | end
78 | else
79 | parsing_errors = parsing_errors + 1
80 | begin_start_hyp = nil
81 | end
82 | else
83 | parsing_errors = parsing_errors + 1
84 | begin_start_hyp = nil
85 | end
86 |
87 | if begin_start_hyp then
88 | begin_start_hyp, end_start_hyp = line:find('Anarchism is a political philosophy that advocatesstateless societiesoften defined as self-governed voluntary institutions, but that several authors have defined as more specific institutions based on non-hierarchical free associations..Anarchism'
109 |
110 | local test_line_2 = 'CSF pressure, as measured by lumbar puncture (LP), is 10-18 '
111 | local test_line_3 = 'Anarchism'
112 |
113 | list_hype, text = extract_text_and_hyp(test_line_1, false)
114 | print(list_hype)
115 | print(text)
116 | print()
117 |
118 | list_hype, text = extract_text_and_hyp(test_line_1, true)
119 | print(list_hype)
120 | print(text)
121 | print()
122 |
123 | list_hype, text = extract_text_and_hyp(test_line_2, true)
124 | print(list_hype)
125 | print(text)
126 | print()
127 |
128 | list_hype, text = extract_text_and_hyp(test_line_3, false)
129 | print(list_hype)
130 | print(text)
131 | print()
132 | print(' Done unit tests.')
133 | ---------------------------------------------------------
134 |
135 |
136 | function extract_page_entity_title(line)
137 | local startoff, endoff = line:find(' ' .. line:sub(startoff + 1, startquotes - 1))
142 | local starttitlestartoff, starttitleendoff = line:find(' title="')
143 | local endtitleoff, _ = line:find('">')
144 | local ent_name = line:sub(starttitleendoff + 1, endtitleoff - 1)
145 | if (ent_wikiid ~= get_ent_wikiid_from_name(ent_name, true)) then
146 | -- Most probably this is a disambiguation or list page
147 | local new_ent_wikiid = get_ent_wikiid_from_name(ent_name, true)
148 | -- print(red('Error in Wiki dump: ' .. line .. ' ' .. ent_wikiid .. ' ' .. new_ent_wikiid))
149 | return new_ent_wikiid
150 | end
151 | return ent_wikiid
152 | end
153 |
154 |
155 | local test_line_4 = ''
156 |
157 | print(extract_page_entity_title(test_line_4))
158 |
--------------------------------------------------------------------------------
/ed/args.lua:
--------------------------------------------------------------------------------
1 | -- We add parameter abbreviations to the abbv map for the important hyperparameters
2 | -- used to differentiate between different methods.
3 | abbv = {}
4 |
5 | cmd = torch.CmdLine()
6 | cmd:text()
7 | cmd:text('Deep Joint Entity Disambiguation w/ Local Neural Attention')
8 | cmd:text('Command line options:')
9 |
10 | ---------------- runtime options:
11 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.')
12 |
13 | cmd:option('-unit_tests', false, 'Run unit tests or not')
14 |
15 | ---------------- CUDA:
16 | cmd:option('-type', 'cudacudnn', 'Type: cuda | float | cudacudnn')
17 |
18 | ---------------- train data:
19 | cmd:option('-store_train_data', 'RAM', 'Where to read the training data from: RAM (tensors) | DISK (text, re-parsed each time)')
20 |
21 | ---------------- loss:
22 | cmd:option('-loss', 'max-margin', 'Loss: nll | max-margin')
23 | abbv['-loss'] = ''
24 |
25 | ---------------- optimization:
26 | cmd:option('-opt', 'ADAM', 'Optimization method: SGD | ADADELTA | ADAGRAD | ADAM')
27 | abbv['-opt'] = ''
28 |
29 | cmd:option('-lr', 1e-4, 'Learning rate. Will be divided by 10 after validation F1 >= 90%.')
30 | abbv['-lr'] = 'lr'
31 |
32 | cmd:option('-batch_size', 1, 'Batch size in terms of number of documents.')
33 | abbv['-batch_size'] = 'bs'
34 |
35 | ---------------- word vectors
36 | cmd:option('-word_vecs', 'w2v', 'Word vectors type: glove | w2v (Word2Vec)')
37 | abbv['-word_vecs'] = ''
38 |
39 | ---------------- entity vectors
40 | cmd:option('-entities', 'RLTD', 'Which entity vectors to use, either just those that appear as candidates in all datasets, or all. All is impractical when storing on GPU. RLTD | ALL')
41 |
42 | cmd:option('-ent_vecs_filename', 'ent_vecs__ep_228.t7', 'File name containing entity vectors generated with entities/learn_e2v/learn_a.lua.')
43 |
44 | ---------------- context
45 | cmd:option('-ctxt_window', 100, 'Number of context words at the left plus right of each mention')
46 | abbv['-ctxt_window'] = 'ctxtW'
47 |
48 | cmd:option('-R', 25, 'Hard attention threshold: top R context words are kept, the rest are discarded.')
49 | abbv['-R'] = 'R'
50 |
51 | ---------------- model
52 | cmd:option('-model', 'global', 'ED model: local | global')
53 |
54 | cmd:option('-nn_pem_interm_size', 100, 'Number of hidden units in the f function described in Section 4 - Local score combination.')
55 | abbv['-nn_pem_interm_size'] = 'nnPEMintermS'
56 |
57 | -------------- model regularization:
58 | cmd:option('-mat_reg_norm', 1, 'Maximum norm of columns of matrices of the f network.')
59 | abbv['-mat_reg_norm'] = 'matRegNorm'
60 |
61 | ---------------- global model parameters
62 | cmd:option('-lbp_iter', 10, 'Number of LBP iterations hard-coded (unrolled) in the NN. Referred to as T in the paper.')
63 | abbv['-lbp_iter'] = 'lbpIt'
64 |
65 | cmd:option('-lbp_damp', 0.5, 'Damping factor for LBP')
66 | abbv['-lbp_damp'] = 'lbpDamp'
67 |
68 | ----------------- reranking of candidates
69 | cmd:option('-num_cand_before_rerank', 30, '')
70 | abbv['-num_cand_before_rerank'] = 'numERerank'
71 |
72 | cmd:option('-keep_p_e_m', 4, '')
73 | abbv['-keep_p_e_m'] = 'keepPEM'
74 |
75 | cmd:option('-keep_e_ctxt', 3, '')
76 | abbv['-keep_e_ctxt'] = 'keepEC'
77 |
78 | ----------------- coreference:
79 | cmd:option('-coref', true, 'Coreference heuristic to match person names.')
80 | abbv['-coref'] = 'coref'
81 |
82 | ------------------ test one model with saved pretrained parameters
83 | cmd:option('-test_one_model_file', '', 'Saved pretrained model filename from folder $DATA_PATH/generated/ed_models/.')
84 |
85 | ------------------ banner:
86 | cmd:option('-banner_header', '', ' Banner header to be used for plotting')
87 |
88 | cmd:text()
89 | opt = cmd:parse(arg or {})
90 |
91 | -- Whether to save the current ED model or not during training.
92 | -- It will become true after the model gets > 90% F1 score on validation set (see test.lua).
93 | opt.save = false
94 |
95 | -- Creates a nice banner from the command line arguments
96 | function get_banner(arg, abbv)
97 | local num_args = table_len(arg)
98 | local banner = opt.banner_header
99 |
100 | if opt.model == 'global' then
101 | banner = banner .. 'GLB'
102 | else
103 | banner = banner .. 'LCL'
104 | end
105 |
106 | for i = 1,num_args do
107 | if abbv[arg[i]] then
108 | banner = banner .. '|' .. abbv[arg[i]] .. '=' .. tostring(opt[arg[i]:sub(2)])
109 | end
110 | end
111 | return banner
112 | end
113 |
114 |
115 | function serialize_params(arg)
116 | local num_args = table_len(arg)
117 | local str = opt.banner_header
118 | if opt.banner_header:len() > 0 then
119 | str = str .. '|'
120 | end
121 |
122 | str = str .. 'model=' .. opt.model
123 |
124 | for i = 1,num_args do
125 | if abbv[arg[i]] then
126 | str = str .. '|' .. arg[i]:sub(2) .. '=' .. tostring(opt[arg[i]:sub(2)])
127 | end
128 | end
129 | return str
130 | end
131 |
132 | banner = get_banner(arg, abbv)
133 |
134 | params_serialized = serialize_params(arg)
135 | print('PARAMS SERIALIZED: ' .. params_serialized)
136 | print('BANNER : ' .. banner .. '\n')
137 |
138 | assert(params_serialized:len() < 255, 'Parameters string length should be < 255.')
139 |
140 |
141 | function extract_args_from_model_title(title)
142 | local x,y = title:find('model=')
143 | local parts = split(title:sub(x), '|')
144 | for _,part in pairs(parts) do
145 | if string.find(part, '=') then
146 | local components = split(part, '=')
147 | opt[components[1]] = components[2]
148 | if tonumber(components[2]) then
149 | opt[components[1]] = tonumber(components[2])
150 | end
151 | if components[2] == 'true' then
152 | opt[components[1]] = true
153 | end
154 | if components[2] == 'false' then
155 | opt[components[1]] = false
156 | end
157 | end
158 | end
159 | end
160 |
--------------------------------------------------------------------------------
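As a hypothetical illustration of the two helpers above: only options registered in the abbv map contribute to the banner and to the serialized parameter string.

```
-- Hypothetical invocation:
--   th ed/ed.lua -root_data_dir $DATA_PATH -model global -lr 1e-4 -R 25
-- would yield roughly:
--   get_banner(arg, abbv)  --> 'GLB|lr=0.0001|R=25'
--   serialize_params(arg)  --> 'model=global|lr=0.0001|R=25'
```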
/ed/ed.lua:
--------------------------------------------------------------------------------
1 | require 'optim'
2 | require 'torch'
3 | require 'gnuplot'
4 | require 'nn'
5 | require 'xlua'
6 |
7 | tds = tds or require 'tds'
8 | dofile 'utils/utils.lua'
9 |
10 | print('\n' .. green('==========> TRAINING of ED NEURAL MODELS <==========') .. '\n')
11 |
12 | dofile 'ed/args.lua'
13 |
14 | print('===> RUN TYPE (CPU/GPU): ' .. opt.type)
15 |
16 | torch.setdefaulttensortype('torch.FloatTensor')
17 | if string.find(opt.type, 'cuda') then
18 | print('==> switching to CUDA (GPU)')
19 | require 'cunn'
20 | require 'cutorch'
21 | require 'cudnn'
22 | cudnn.benchmark = true
23 | cudnn.fastest = true
24 | else
25 | print('==> running on CPU')
26 | end
27 |
28 | dofile 'utils/logger.lua'
29 | dofile 'entities/relatedness/relatedness.lua'
30 | dofile 'entities/ent_name2id_freq/ent_name_id.lua'
31 | dofile 'entities/ent_name2id_freq/e_freq_index.lua'
32 | dofile 'words/load_w_freq_and_vecs.lua' -- w ids
33 | dofile 'words/w2v/w2v.lua'
34 | dofile 'entities/pretrained_e2v/e2v.lua'
35 | dofile 'ed/minibatch/build_minibatch.lua'
36 | dofile 'ed/minibatch/data_loader.lua'
37 | dofile 'ed/models/model.lua'
38 | dofile 'ed/loss.lua'
39 | dofile 'ed/train.lua'
40 | dofile 'ed/test/test.lua'
41 |
42 | geom_unit_tests() -- Show some entity examples
43 | compute_relatedness_metrics(entity_similarity) -- UNCOMMENT
44 |
45 | train_and_test()
46 |
--------------------------------------------------------------------------------
/ed/loss.lua:
--------------------------------------------------------------------------------
1 | if opt.loss == 'nll' then
2 | criterion = nn.CrossEntropyCriterion()
3 | else
4 | -- max-margin with margin parameter = 0.01
5 | criterion = nn.MultiMarginCriterion(1, torch.ones(max_num_cand), 0.01)
6 | end
7 |
8 | if string.find(opt.type, 'cuda') then
9 | criterion = criterion:cuda()
10 | end
--------------------------------------------------------------------------------
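A toy sketch of the max-margin criterion above, outside the training pipeline, assuming max_num_cand = 6 and the gold candidate at position 2:

```
-- The gold candidate beats the runner-up by less than the 0.01 margin,
-- so the criterion returns a small positive loss.
require 'nn'
local max_num_cand = 6
local crit = nn.MultiMarginCriterion(1, torch.ones(max_num_cand), 0.01)
local scores = torch.Tensor({0.300, 0.305, 0.10, 0.0, -0.20, 0.05})
print(crit:forward(scores, 2))
```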
/ed/minibatch/data_loader.lua:
--------------------------------------------------------------------------------
1 | -- Data loader for training of ED models.
2 |
3 | train_file = opt.root_data_dir .. 'generated/test_train_data/aida_train.csv'
4 | it_train, _ = io.open(train_file)
5 |
6 | print('==> Loading training data with option ' .. opt.store_train_data)
7 | local function one_doc_to_minibatch(doc_lines)
8 | -- Create empty mini batch:
9 | local num_mentions = #doc_lines
10 | assert(num_mentions > 0)
11 |
12 | local inputs = empty_minibatch_with_ids(num_mentions)
13 | local targets = torch.zeros(num_mentions)
14 |
15 | -- Fill in each example:
16 | for i = 1, num_mentions do
17 | local target = process_one_line(doc_lines[i], inputs, i, true)
18 | targets[i] = target
19 | assert(target >= 1 and target == targets[i])
20 | end
21 |
22 | return inputs, targets
23 | end
24 |
25 | if opt.store_train_data == 'RAM' then
26 | all_docs_inputs = tds.Hash()
27 | all_docs_targets = tds.Hash()
28 | doc2id = tds.Hash()
29 | id2doc = tds.Hash()
30 |
31 | local cur_doc_lines = tds.Hash()
32 | local prev_doc_id = nil
33 |
34 | local line = it_train:read()
35 | while line do
36 | local parts = split(line, '\t')
37 | local doc_name = parts[1]
38 | if not doc2id[doc_name] then
39 | if prev_doc_id then
40 | local inputs, targets = one_doc_to_minibatch(cur_doc_lines)
41 | all_docs_inputs[prev_doc_id] = minibatch_table2tds(inputs)
42 | all_docs_targets[prev_doc_id] = targets
43 | end
44 | local cur_docid = 1 + #doc2id
45 | id2doc[cur_docid] = doc_name
46 | doc2id[doc_name] = cur_docid
47 | cur_doc_lines = tds.Hash()
48 | prev_doc_id = cur_docid
49 | end
50 | cur_doc_lines[1 + #cur_doc_lines] = line
51 | line = it_train:read()
52 | end
53 | if prev_doc_id then
54 | local inputs, targets = one_doc_to_minibatch(cur_doc_lines)
55 | all_docs_inputs[prev_doc_id] = minibatch_table2tds(inputs)
56 | all_docs_targets[prev_doc_id] = targets
57 | end
58 | assert(#doc2id == #all_docs_inputs, #doc2id .. ' ' .. #all_docs_inputs)
59 |
60 | else
61 | all_doc_lines = tds.Hash()
62 | doc2id = tds.Hash()
63 | id2doc = tds.Hash()
64 |
65 | local line = it_train:read()
66 | while line do
67 | local parts = split(line, '\t')
68 | local doc_name = parts[1]
69 | if not doc2id[doc_name] then
70 | local cur_docid = 1 + #doc2id
71 | id2doc[cur_docid] = doc_name
72 | doc2id[doc_name] = cur_docid
73 | all_doc_lines[cur_docid] = tds.Hash()
74 | end
75 | all_doc_lines[doc2id[doc_name]][1 + #all_doc_lines[doc2id[doc_name]]] = line
76 | line = it_train:read()
77 | end
78 | assert(#doc2id == #all_doc_lines)
79 | end
80 |
81 |
82 | get_minibatch = function()
83 | -- Create empty mini batch:
84 | local inputs = nil
85 | local targets = nil
86 |
87 | if opt.store_train_data == 'RAM' then
88 | local random_docid = math.random(#id2doc)
89 | inputs = minibatch_tds2table(all_docs_inputs[random_docid])
90 | targets = all_docs_targets[random_docid]
91 | else
92 | local doc_lines = all_doc_lines[math.random(#id2doc)]
93 | inputs, targets = one_doc_to_minibatch(doc_lines)
94 | end
95 |
96 | -- Move data to GPU:
97 | inputs, targets = minibatch_to_correct_type(inputs, targets, true)
98 | targets = correct_type(targets)
99 |
100 | return inputs, targets
101 | end
102 |
103 | print(' Done loading training data.')
104 |
--------------------------------------------------------------------------------
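Usage sketch: downstream code (ed/train.lua) simply calls get_minibatch(), which returns all mentions of one randomly chosen training document as a single mini-batch:

```
-- One random document per call; tensors are already converted to opt.type.
local inputs, targets = get_minibatch()
print('mentions in this document: ' .. targets:size(1))
```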
/ed/models/SetConstantDiag.lua:
--------------------------------------------------------------------------------
1 | -- Torch layer that receives as input a batch of square matrices and sets their diagonals to a constant value.
2 |
3 | local SetConstantDiag, parent = torch.class('nn.SetConstantDiag', 'nn.Module')
4 |
5 | function SetConstantDiag:__init(constant_scalar, ip)
6 | parent.__init(self)
7 | assert(type(constant_scalar) == 'number', 'input is not scalar!')
8 | self.constant_scalar = constant_scalar
9 |
10 | -- default for inplace is false
11 | self.inplace = ip or false
12 | if (ip and type(ip) ~= 'boolean') then
13 | error('in-place flag must be boolean')
14 | end
15 | if not opt then
16 | opt = {}
17 | opt.type = 'double'
18 | dofile 'utils/utils.lua'
19 | end
20 | end
21 |
22 | function SetConstantDiag:updateOutput(input)
23 | assert(input:dim() == 3)
24 | assert(input:size(2) == input:size(3))
25 | local n = input:size(3)
26 | local prod_mat = torch.ones(n,n) - torch.eye(n)
27 | prod_mat = correct_type(prod_mat)
28 | local sum_mat = torch.eye(n):mul(self.constant_scalar)
29 | sum_mat = correct_type(sum_mat)
30 | if self.inplace then
31 | input:cmul(torch.repeatTensor(prod_mat, input:size(1), 1, 1))
32 | input:add(torch.repeatTensor(sum_mat, input:size(1), 1, 1))
33 | self.output:set(input)
34 | else
35 | self.output:resizeAs(input)
36 | self.output:copy(input)
37 | self.output:cmul(torch.repeatTensor(prod_mat, input:size(1), 1, 1))
38 | self.output:add(torch.repeatTensor(sum_mat, input:size(1), 1, 1))
39 | end
40 | return self.output
41 | end
42 |
43 | function SetConstantDiag:updateGradInput(input, gradOutput)
44 | local n = input:size(3)
45 | local prod_mat = torch.ones(n,n) - torch.eye(n)
46 | prod_mat = correct_type(prod_mat)
47 | if self.inplace then
48 | self.gradInput:set(gradOutput)
49 | else
50 | self.gradInput:resizeAs(gradOutput)
51 | self.gradInput:copy(gradOutput)
52 | end
53 | self.gradInput:cmul(torch.repeatTensor(prod_mat, input:size(1), 1, 1))
54 | return self.gradInput
55 | end
--------------------------------------------------------------------------------
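A toy forward pass of the layer above, run from the repository root; when no global opt exists the layer falls back to opt.type = 'double' and loads correct_type from utils/utils.lua itself:

```
-- Set the diagonal of each 3x3 matrix in a batch of one to the constant 0.
require 'nn'
dofile 'ed/models/SetConstantDiag.lua'
local layer = nn.SetConstantDiag(0, false)
local x = torch.ones(1, 3, 3)
print(layer:forward(x))   -- off-diagonal entries stay 1, the diagonal becomes 0
```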
/ed/models/linear_layers.lua:
--------------------------------------------------------------------------------
1 | -- Define all parametrized layers: linear layers (diagonal matrices) A,B,C + network f
2 |
3 | function new_linear_layer(out_dim)
4 | cmul = nn.CMul(out_dim)
5 | -- init weights with ones to speed up convergence
6 | cmul.weight = torch.ones(out_dim)
7 | return cmul
8 | end
9 |
10 | ---- Create shared weights
11 | A_linear = new_linear_layer(ent_vecs_size)
12 |
13 | -- Local ctxt bilinear weights
14 | B_linear = new_linear_layer(ent_vecs_size)
15 |
16 | -- Used only in the global model
17 | C_linear = new_linear_layer(ent_vecs_size)
18 |
19 | f_network = nn.Sequential()
20 | :add(nn.Linear(2,opt.nn_pem_interm_size))
21 | :add(nn.ReLU())
22 | :add(nn.Linear(opt.nn_pem_interm_size,1))
23 |
24 |
25 | function regularize_f_network()
26 | if opt.mat_reg_norm < 10 then
27 | for i = 1,f_network:size() do
28 | if f_network:get(i).weight and (f_network:get(i).weight:norm() > opt.mat_reg_norm) then
29 | f_network:get(i).weight:mul(opt.mat_reg_norm / f_network:get(i).weight:norm())
30 | end
31 | if f_network:get(i).bias and (f_network:get(i).bias:norm() > opt.mat_reg_norm) then
32 | f_network:get(i).bias:mul(opt.mat_reg_norm / f_network:get(i).bias:norm())
33 | end
34 | end
35 | end
36 | end
37 |
38 | function pack_saveable_weights()
39 | local linears = nn.Sequential():add(A_linear):add(B_linear):add(C_linear):add(f_network)
40 | return linears:float()
41 | end
42 |
43 | function unpack_saveable_weights(saved_linears)
44 | A_linear = saved_linears:get(1)
45 | B_linear = saved_linears:get(2)
46 | C_linear = saved_linears:get(3)
47 | f_network = saved_linears:get(4)
48 | end
49 |
50 |
51 | function print_net_weights()
52 | print('\nNetwork norms of parameter weights :')
53 | print('A (attention mat) = ' .. A_linear.weight:norm())
54 | print('B (ctxt embedding) = ' .. B_linear.weight:norm())
55 | print('C (pairwise mat) = ' .. C_linear.weight:norm())
56 |
57 | if opt.mat_reg_norm < 10 then
58 | print('f_network norm = ' .. f_network:get(1).weight:norm() .. ' ' ..
59 | f_network:get(1).bias:norm() .. ' ' .. f_network:get(3).weight:norm() .. ' ' ..
60 | f_network:get(3).bias:norm())
61 | else
62 | p,gp = f_network:getParameters()
63 | print('f_network norm = ' .. p:norm())
64 | end
65 | end
66 |
--------------------------------------------------------------------------------
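The 'diagonal matrix' layers above are plain nn.CMul modules, i.e. an element-wise scaling of a 300-dimensional entity or context vector. A toy sketch with a smaller dimension:

```
-- Element-wise scaling layer, initialized to the identity like new_linear_layer.
require 'nn'
local d = 5                            -- small stand-in for ent_vecs_size = 300
local layer = nn.CMul(d)
layer.weight = torch.ones(d)
print(layer:forward(torch.randn(d)))   -- identical to the input until the weights are trained
```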
/ed/models/model.lua:
--------------------------------------------------------------------------------
1 | dofile 'ed/models/SetConstantDiag.lua'
2 | dofile 'ed/models/linear_layers.lua'
3 | dofile 'ed/models/model_local.lua'
4 | dofile 'ed/models/model_global.lua'
5 |
6 | function get_model(num_mentions)
7 | local model_ctxt, additional_local_submodels = local_model(num_mentions, A_linear, B_linear)
8 | local model = model_ctxt
9 |
10 | if opt.model == 'global' then
11 | model = global_model(num_mentions, model_ctxt, C_linear, f_network)
12 | else
13 | assert(opt.model == 'local')
14 | end
15 |
16 | return model, additional_local_submodels
17 | end
18 |
--------------------------------------------------------------------------------
/ed/models/model_global.lua:
--------------------------------------------------------------------------------
1 | -- Definition of the neural network used for global (joint) ED. Section 5 of our paper.
2 | -- It unrolls a fixed number of LBP iterations allowing training of the CRF potentials using backprop.
3 | -- To run a simple unit test that checks the forward and backward passes, just run :
4 | -- th ed/models/model_global.lua
5 |
6 | if not opt then -- unit tests
7 | dofile 'ed/models/model_local.lua'
8 | dofile 'ed/models/SetConstantDiag.lua'
9 | dofile 'ed/models/linear_layers.lua'
10 |
11 | opt.lbp_iter = 10
12 | opt.lbp_damp = 0.5
13 | opt.model = 'global'
14 | end
15 |
16 |
17 | function global_model(num_mentions, model_ctxt, param_C_linear, param_f_network)
18 |
19 | assert(num_mentions)
20 | assert(model_ctxt)
21 | assert(param_C_linear)
22 | assert(param_f_network)
23 |
24 | local unary_plus_pairwise = nn.Sequential()
25 | :add(nn.ParallelTable()
26 | :add(nn.Sequential() -- Pairwise scores s_{ij}(y_i, y_j)
27 | :add(nn.View(num_mentions * max_num_cand, ent_vecs_size)) -- e_vecs
28 | :add(nn.ConcatTable()
29 | :add(nn.Identity())
30 | :add(param_C_linear)
31 | )
32 | :add(nn.MM(false, true)) -- s_{ij}(y_i, y_j) is s[i][y_i][j][y_j] = e_{y_i}^T * C * e_{y_j}
33 | :add(nn.View(num_mentions, max_num_cand, num_mentions, max_num_cand))
34 | :add(nn.MulConstant(2.0 / num_mentions, true))
35 | )
36 | :add(nn.Sequential() -- Unary scores s_j(y_j)
37 | :add(nn.Replicate(num_mentions * max_num_cand, 1))
38 | :add(nn.Reshape(num_mentions, max_num_cand, num_mentions, max_num_cand))
39 | )
40 | )
41 | :add(nn.CAddTable()) -- q[i][y_i][j][y_j] = s_j(y_j) + s_{ij}(y_i, y_j): num_mentions x max_num_cand x num_mentions x max_num_cand
42 |
43 |
44 | -- Input is q[i][y_i][j] : num_mentions x max_num_cand x num_mentions
45 | local messages_one_round = nn.ConcatTable()
46 | :add(nn.SelectTable(1)) -- 1. unary_plus_pairwise : num_mentions, max_num_cand, num_mentions, max_num_cand
47 | :add(nn.Sequential()
48 | :add(nn.ConcatTable()
49 | :add(nn.Sequential()
50 | :add(nn.SelectTable(2)) -- old messages
51 | :add(nn.Exp())
52 | :add(nn.MulConstant(1.0 - opt.lbp_damp, false))
53 | )
54 | :add(nn.Sequential()
55 | :add(nn.ParallelTable()
56 | :add(nn.Identity()) -- unary plus pairwise
57 | :add(nn.Sequential() -- old messages: num_mentions, max_num_cand, num_mentions
58 | :add(nn.Sum(3)) -- g[i][y_i] := \sum_{k != i} q[i][y_i][k]
59 | :add(nn.Replicate(num_mentions * max_num_cand, 1))
60 | :add(nn.Reshape(num_mentions, max_num_cand, num_mentions, max_num_cand))
61 | )
62 | )
63 | :add(nn.CAddTable()) -- s_{j}(y_j) + s_{ij}(y_i, y_j) + g[j][y_j] : num_mentions, max_num_cand, num_mentions, max_num_cand
64 | :add(nn.Max(4)) -- unnorm_q[i][y_i][j] : num_mentions x max_num_cand x num_mentions
65 | :add(nn.Transpose({2,3}))
66 | :add(nn.View(num_mentions * num_mentions, max_num_cand))
67 | :add(nn.LogSoftMax()) -- normalization: \sum_{y_i} exp(q[i][y_i][j]) = 1
68 | :add(nn.View(num_mentions, num_mentions, max_num_cand))
69 | :add(nn.Transpose({2,3}))
70 | :add(nn.Transpose({1,2}))
71 | :add(nn.SetConstantDiag(0, true)) -- we make q[i][y_i][i] = 0, \forall i and y_i
72 | :add(nn.Transpose({1,2}))
73 | :add(nn.Exp())
74 | :add(nn.MulConstant(opt.lbp_damp, false))
75 | )
76 | )
77 | :add(nn.CAddTable()) -- 2. messages for next round: num_mentions, max_num_cand, num_mentions
78 | :add(nn.Log())
79 | )
80 |
81 |
82 | local messages_all_rounds = nn.Sequential()
83 | messages_all_rounds:add(nn.Identity())
84 | for i = 1, opt.lbp_iter do
85 | messages_all_rounds:add(messages_one_round:clone('weight','bias','gradWeight','gradBias'))
86 | end
87 |
88 | local model_gl = nn.Sequential()
89 | :add(nn.ConcatTable()
90 | :add(nn.Sequential()
91 | :add(nn.SelectTable(2))
92 | :add(nn.SelectTable(2)) -- e_vecs : num_mentions, max_num_cand, ent_vecs_size
93 | )
94 | :add(model_ctxt) -- unary scores : num_mentions, max_num_cand
95 | )
96 | :add(nn.ConcatTable()
97 | :add(nn.Sequential()
98 | :add(unary_plus_pairwise)
99 | :add(nn.ConcatTable()
100 | :add(nn.Identity())
101 | :add(nn.Sequential()
102 | :add(nn.Max(4))
103 | :add(nn.MulConstant(0, false)) -- first_round_zero_messages
104 | )
105 | )
106 | :add(messages_all_rounds)
107 | :add(nn.SelectTable(2))
108 | :add(nn.Sum(3)) -- \sum_{j} msgs[i][y_i]: num_mentions x max_num_cand
109 | )
110 | :add(nn.SelectTable(2)) -- unary scores : num_mentions x max_num_cand
111 | )
112 | :add(nn.CAddTable())
113 | :add(nn.LogSoftMax()) -- belief[i][y_i] (lbp marginals in log scale)
114 |
115 |
116 | -- Combine lbp marginals with log p(e|m) using the simple f neural network
117 | local pem_layer = nn.SelectTable(3)
118 | model_gl = nn.Sequential()
119 | :add(nn.ConcatTable()
120 | :add(nn.Sequential()
121 | :add(model_gl)
122 | :add(nn.View(num_mentions * max_num_cand, 1))
123 | )
124 | :add(nn.Sequential()
125 | :add(pem_layer)
126 | :add(nn.View(num_mentions * max_num_cand, 1))
127 | )
128 | )
129 | :add(nn.JoinTable(2))
130 | :add(param_f_network)
131 | :add(nn.View(num_mentions, max_num_cand))
132 |
133 | ------- Cuda conversions:
134 | if string.find(opt.type, 'cuda') then
135 | model_gl = model_gl:cuda()
136 | end
137 |
138 | return model_gl
139 | end
140 |
141 |
142 | --- Unit tests
143 | if unit_tests_now then
144 | print('\n Global network model unit tests:')
145 | local num_mentions = 13
146 |
147 | local inputs = {}
148 | -- ctxt_w_vecs
149 | inputs[1] = {}
150 | inputs[1][1] = torch.ones(num_mentions, opt.ctxt_window):int():mul(unk_w_id)
151 | inputs[1][2] = torch.randn(num_mentions, opt.ctxt_window, ent_vecs_size)
152 | -- e_vecs
153 | inputs[2] = {}
154 | inputs[2][1] = torch.ones(num_mentions, max_num_cand):int():mul(unk_ent_thid)
155 | inputs[2][2] = torch.randn(num_mentions, max_num_cand, ent_vecs_size)
156 | -- p(e|m)
157 | inputs[3] = torch.log(torch.rand(num_mentions, max_num_cand))
158 |
159 | local model_ctxt, _ = local_model(num_mentions, A_linear, B_linear, opt)
160 |
161 | local model_gl = global_model(num_mentions, model_ctxt, C_linear, f_network)
162 | local outputs = model_gl:forward(inputs)
163 | print(outputs)
164 | print('MIN: ' .. torch.min(outputs) .. ' MAX: ' .. torch.max(outputs))
165 | assert(outputs:size(1) == num_mentions and outputs:size(2) == max_num_cand)
166 | print('Global FWD success!')
167 |
168 | model_gl:backward(inputs, torch.randn(num_mentions, max_num_cand))
169 | print('Global BKWD success!')
170 |
171 | parameters,gradParameters = model_gl:getParameters()
172 | print(parameters:size())
173 | print(gradParameters:size())
174 | end
--------------------------------------------------------------------------------
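Reading messages_one_round from the inside out: each round computes max-sum messages in log space, normalizes them with nn.LogSoftMax, and mixes them with the previous round's messages using the damping factor opt.lbp_damp. A toy sketch of just that mixing step (not the actual network; the numbers are made up):

```
-- next log-message = log( damp * new_normalized + (1 - damp) * exp(old_log_message) )
require 'torch'
local damp = 0.5                                          -- opt.lbp_damp
local old_log_msg = torch.log(torch.Tensor({0.7, 0.3}))   -- messages from the previous round
local new_norm    = torch.Tensor({0.2, 0.8})              -- exp(LogSoftMax(...)) of this round
local mixed = torch.log(new_norm * damp + torch.exp(old_log_msg) * (1 - damp))
print(mixed)
```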
/ed/models/model_local.lua:
--------------------------------------------------------------------------------
1 | -- Definition of the local neural network with attention used for local (independent per each mention) ED.
2 | -- Section 4 of our paper.
3 | -- To run a simple unit test that checks the forward and backward passes, just run :
4 | -- th ed/models/model_local.lua
5 |
6 | if not opt then -- unit tests
7 | unit_tests_now = true
8 | dofile 'utils/utils.lua'
9 | require 'nn'
10 | opt = {type = 'double', ctxt_window = 100, R = 25, model = 'local', nn_pem_interm_size = 100}
11 |
12 | word_vecs_size = 300
13 | ent_vecs_size = 300
14 | max_num_cand = 6
15 | unk_ent_wikiid = 1
16 | unk_ent_thid = 1
17 | unk_w_id = 1
18 | dofile 'ed/models/linear_layers.lua'
19 | word_lookup_table = nn.LookupTable(5, ent_vecs_size)
20 | ent_lookup_table = nn.LookupTable(5, ent_vecs_size)
21 | else
22 | word_lookup_table = nn.LookupTable(w2vutils.M:size(1), ent_vecs_size)
23 | word_lookup_table.weight = w2vutils.M
24 |
25 | ent_lookup_table = nn.LookupTable(e2vutils.lookup:size(1), ent_vecs_size)
26 | ent_lookup_table.weight = e2vutils.lookup
27 |
28 | if string.find(opt.type, 'cuda') then
29 | word_lookup_table = word_lookup_table:cuda()
30 | ent_lookup_table = ent_lookup_table:cuda()
31 | end
32 | end
33 |
34 | assert(word_vecs_size == 300 and ent_vecs_size == 300)
35 |
36 |
37 | ----------------- Define the model
38 | function local_model(num_mentions, param_A_linear, param_B_linear)
39 |
40 | assert(num_mentions)
41 | assert(param_A_linear)
42 | assert(param_B_linear)
43 |
44 | model = nn.Sequential()
45 |
46 | ctxt_embed_and_ent_lookup = nn.ConcatTable()
47 | :add(nn.Sequential()
48 | :add(nn.SelectTable(1))
49 | :add(nn.SelectTable(2)) -- 1 : Context words W : num_mentions x opt.ctxt_window x ent_vecs_size
50 | )
51 | :add(nn.Sequential()
52 | :add(nn.SelectTable(2))
53 | :add(nn.SelectTable(2)) -- 2 : Candidate entity vectors E : num_mentions x max_num_cand x ent_vecs_size
54 | )
55 | :add(nn.SelectTable(3)) -- 3 : log p(e|m) : num_mentions x max_num_cand
56 |
57 | model:add(ctxt_embed_and_ent_lookup)
58 |
59 |
60 | local mem_weights_p_2 = nn.ConcatTable()
61 | :add(nn.Identity()) -- 1 : {W, E, logp(e|m)}
62 | :add(nn.Sequential()
63 | :add(nn.ConcatTable()
64 | :add(nn.Sequential()
65 | :add(nn.SelectTable(2))
66 | :add(nn.View(num_mentions * max_num_cand, ent_vecs_size)) -- E : num_mentions x max_num_cand x ent_vecs_size
67 | :add(param_A_linear)
68 | :add(nn.View(num_mentions, max_num_cand, ent_vecs_size)) -- (A*) E
69 | )
70 | :add(nn.SelectTable(1)) --- W : num_mentions x opt.ctxt_window x ent_vecs_size
71 | )
72 | :add(nn.MM(false, true)) -- 2 : E^t * A * W : num_mentions x max_num_cand x opt.ctxt_window
73 | )
74 | model:add(mem_weights_p_2)
75 |
76 |
77 | local mem_weights_p_3 = nn.ConcatTable()
78 | :add(nn.SelectTable(1)) -- 1 : {W, E, logp(e|m)}
79 | :add(nn.Sequential()
80 | :add(nn.SelectTable(2))
81 | :add(nn.Max(2)) --- 2 : max(word-entity scores) : num_mentions x opt.ctxt_window
82 | )
83 |
84 | model:add(mem_weights_p_3)
85 |
86 |
87 | local mem_weights_p_4 = nn.ConcatTable()
88 | :add(nn.SelectTable(1)) -- 1 : {W, E, logp(e|m)}
89 | :add(nn.Sequential()
90 | :add(nn.SelectTable(2))
91 | -- keep only top K scored words
92 | :add(nn.ConcatTable()
93 | :add(nn.Identity()) -- all w-e scores
94 | :add(nn.Sequential()
95 | :add(nn.View(num_mentions, opt.ctxt_window, 1))
96 | :add(nn.TemporalDynamicKMaxPooling(opt.R))
97 | :add(nn.Min(2)) -- k-th largest w-e score
98 | :add(nn.View(num_mentions))
99 | :add(nn.Replicate(opt.ctxt_window, 2))
100 | )
101 | )
102 | :add(nn.ConcatTable() -- top k w-e scores (the rest are set to -infty)
103 | :add(nn.SelectTable(2)) -- k-th largest w-e score that we subtract and then add back after nn.Threshold
104 | :add(nn.Sequential()
105 | :add(nn.CSubTable())
106 | :add(nn.Threshold(0, -50, true))
107 | )
108 | )
109 | :add(nn.CAddTable())
110 | :add(nn.SoftMax()) -- 2 : sigma (attention weights normalized): num_mentions x opt.ctxt_window
111 | :add(nn.View(num_mentions, opt.ctxt_window, 1))
112 | )
113 |
114 | model:add(mem_weights_p_4)
115 |
116 |
117 | local ctxt_full_embeddings = nn.ConcatTable()
118 | :add(nn.SelectTable(1)) -- 1 : {W, E, logp(e|m)}
119 | :add(nn.Sequential()
120 | :add(nn.ConcatTable()
121 | :add(nn.Sequential()
122 | :add(nn.SelectTable(1))
123 | :add(nn.SelectTable(1)) -- W
124 | )
125 | :add(nn.SelectTable(2)) -- sigma
126 | )
127 | :add(nn.MM(true, false))
128 | :add(nn.View(num_mentions, ent_vecs_size)) -- 2 : ctxt embedding = (W * B)^\top * sigma : num_mentions x ent_vecs_size
129 | )
130 |
131 | model:add(ctxt_full_embeddings)
132 |
133 |
134 | local entity_context_sim_scores = nn.ConcatTable()
135 | :add(nn.SelectTable(1)) -- 1 : {W, E, logp(e|m)}
136 | :add(nn.Sequential()
137 | :add(nn.ConcatTable()
138 | :add(nn.Sequential()
139 | :add(nn.SelectTable(1))
140 | :add(nn.SelectTable(2)) -- E
141 | )
142 | :add(nn.Sequential()
143 | :add(nn.SelectTable(2))
144 | :add(param_B_linear)
145 | :add(nn.View(num_mentions, ent_vecs_size, 1)) -- context vectors
146 | )
147 | )
148 | :add(nn.MM()) --> 2. context * E^T
149 | :add(nn.View(num_mentions, max_num_cand))
150 | )
151 |
152 | model:add(entity_context_sim_scores)
153 |
154 | if opt.model == 'local' then
155 | model = nn.Sequential()
156 | :add(nn.ConcatTable()
157 | :add(nn.Sequential()
158 | :add(model)
159 | :add(nn.SelectTable(2)) -- context - entity similarity scores
160 | :add(nn.View(num_mentions * max_num_cand, 1))
161 | )
162 | :add(nn.Sequential()
163 | :add(nn.SelectTable(3)) -- log p(e|m) scores
164 | :add(nn.View(num_mentions * max_num_cand, 1))
165 | )
166 | )
167 | :add(nn.JoinTable(2))
168 | :add(f_network)
169 | :add(nn.View(num_mentions, max_num_cand))
170 | else
171 | model:add(nn.SelectTable(2)) -- context - entity similarity scores
172 | end
173 |
174 |
175 | ------------- Visualizing weights:
176 |
177 | -- sigma (attention weights normalized): num_mentions x opt.ctxt_window
178 | local model_debug_softmax_word_weights = nn.Sequential()
179 | :add(ctxt_embed_and_ent_lookup)
180 | :add(mem_weights_p_2)
181 | :add(mem_weights_p_3)
182 | :add(mem_weights_p_4)
183 | :add(nn.SelectTable(2))
184 | :add(nn.View(num_mentions, opt.ctxt_window))
185 |
186 | ------- Cuda conversions:
187 | if string.find(opt.type, 'cuda') then
188 | model = model:cuda()
189 | model_debug_softmax_word_weights = model_debug_softmax_word_weights:cuda()
190 | end
191 |
192 | local additional_local_submodels = {
193 | model_final_local = model,
194 | model_debug_softmax_word_weights = model_debug_softmax_word_weights,
195 | }
196 |
197 | return model, additional_local_submodels
198 | end
199 |
200 |
201 | --- Unit tests
202 | if unit_tests_now then
203 | print('Network model unit tests:')
204 | local num_mentions = 13
205 |
206 | local inputs = {}
207 |
208 | -- ctxt_w_vecs
209 | inputs[1] = {}
210 | inputs[1][1] = torch.ones(num_mentions, opt.ctxt_window):int():mul(unk_w_id)
211 | inputs[1][2] = torch.randn(num_mentions, opt.ctxt_window, ent_vecs_size)
212 | -- e_vecs
213 | inputs[2] = {}
214 | inputs[2][1] = torch.ones(num_mentions, max_num_cand):int():mul(unk_ent_thid)
215 | inputs[2][2] = torch.randn(num_mentions, max_num_cand, ent_vecs_size)
216 | -- p(e|m)
217 | inputs[3] = torch.zeros(num_mentions, max_num_cand)
218 |
219 | local model, additional_local_submodels = local_model(num_mentions, A_linear, B_linear, opt)
220 |
221 | print(additional_local_submodels.model_debug_softmax_word_weights:forward(inputs):size())
222 |
223 | local outputs = model:forward(inputs)
224 | assert(outputs:size(1) == num_mentions and outputs:size(2) == max_num_cand)
225 | print('FWD success!')
226 |
227 | model:backward(inputs, torch.randn(num_mentions, max_num_cand))
228 | print('BKWD success!')
229 |
230 | parameters,gradParameters = model:getParameters()
231 | print(parameters:size())
232 | print(gradParameters:size())
233 | end
234 |
--------------------------------------------------------------------------------
/ed/test/check_coref.lua:
--------------------------------------------------------------------------------
1 | -- Runs our trivial coreference resolution method and outputs the new set of
2 | -- entity candidates. Used for debugging the coreference resolution method.
3 |
4 | if not opt then
5 | cmd = torch.CmdLine()
6 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.')
7 | cmd:text()
8 | opt = cmd:parse(arg or {})
9 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.')
10 | end
11 |
12 |
13 | require 'torch'
14 | dofile 'utils/utils.lua'
15 |
16 | tds = tds or require 'tds'
17 |
18 | dofile 'entities/ent_name2id_freq/ent_name_id.lua'
19 | dofile 'ed/test/coref_persons.lua'
20 |
21 | file = opt.root_data_dir .. 'generated/test_train_data/aida_testB.csv'
22 |
23 | opt = {}
24 | opt.coref = true
25 |
26 | it, _ = io.open(file)
27 | local all_doc_lines = tds.Hash()
28 | local line = it:read()
29 | while line do
30 | local parts = split(line, '\t')
31 | local doc_name = parts[1]
32 | if not all_doc_lines[doc_name] then
33 | all_doc_lines[doc_name] = tds.Hash()
34 | end
35 | all_doc_lines[doc_name][1 + #all_doc_lines[doc_name]] = line
36 | line = it:read()
37 | end
38 | -- Gather coreferent mentions to increase accuracy.
39 | build_coreference_dataset(all_doc_lines, 'aida-B')
40 |
--------------------------------------------------------------------------------
/ed/test/coref_persons.lua:
--------------------------------------------------------------------------------
1 | -- Given a dataset, try to retrieve better entity candidates
2 | -- for ambiguous mentions of persons. For example, suppose a document
3 | -- contains a mention of a person called 'Peter Such' that can be easily solved with
4 | -- the current system. Now suppose that, in the same document, there
5 | -- exists a mention 'Such' referring to the same person. For this
7 | -- second, highly ambiguous mention, retrieving the correct entity in the
8 | -- top K candidates would be very hard. We adopt a simple heuristic strategy of
9 | -- searching, in the same document, for all potentially coreferent mentions that strictly contain
9 | -- the given mention as a substring. If such mentions exist and they refer to
10 | -- persons (contain at least one person candidate entity), then the ambiguous
11 | -- shorter mention gets as candidates the candidates of the longer mention.
12 |
13 | tds = tds or require 'tds'
14 | assert(get_ent_wikiid_from_name)
15 | print('==> Loading index of Wiki entities that represent persons.')
16 |
17 | local persons_ent_wikiids = tds.Hash()
18 | for line in io.lines(opt.root_data_dir .. 'basic_data/p_e_m_data/persons.txt') do
19 | local ent_wikiid = get_ent_wikiid_from_name(line, true)
20 | if ent_wikiid ~= unk_ent_wikiid then
21 | persons_ent_wikiids[ent_wikiid] = 1
22 | end
23 | end
24 |
25 | function is_person(ent_wikiid)
26 | return persons_ent_wikiids[ent_wikiid]
27 | end
28 |
29 | print(' Done loading persons index. Size = ' .. #persons_ent_wikiids)
30 |
31 |
32 | local function mention_refers_to_person(m, mention_ent_cand)
33 | local top_p = 0
34 | local top_ent = -1
35 | for e_wikiid, p_e_m in pairs(mention_ent_cand[m]) do
36 | if p_e_m > top_p then
37 | top_ent = e_wikiid
38 | top_p = p_e_m
39 | end
40 | end
41 | return is_person(top_ent)
42 | end
43 |
44 |
45 | function build_coreference_dataset(dataset_lines, banner)
46 | if (not opt.coref) then
47 | return dataset_lines
48 | else
49 |
50 | -- Create new entity candidates
51 | local coref_dataset_lines = tds.Hash()
52 | for doc_id, lines_map in pairs(dataset_lines) do
53 |
54 | coref_dataset_lines[doc_id] = tds.Hash()
55 |
56 | -- Collect entity candidates for each mention.
57 | local mention_ent_cand = {}
58 | for _,sample_line in pairs(lines_map) do
59 | local parts = split(sample_line, "\t")
60 | assert(doc_id == parts[1])
61 | local mention = parts[3]:lower()
62 | if not mention_ent_cand[mention] then
63 | mention_ent_cand[mention] = {}
64 | assert(parts[6] == 'CANDIDATES')
65 | if parts[7] ~= 'EMPTYCAND' then
66 | local num_cand = 1
67 | while parts[6 + num_cand] ~= 'GT:' do
68 | local cand_parts = split(parts[6 + num_cand], ',')
69 | local cand_ent_wikiid = tonumber(cand_parts[1])
70 | local cand_p_e_m = tonumber(cand_parts[2])
71 | assert(cand_p_e_m >= 0, cand_p_e_m)
72 | assert(cand_ent_wikiid)
73 | mention_ent_cand[mention][cand_ent_wikiid] = cand_p_e_m
74 | num_cand = num_cand + 1
75 | end
76 | end
77 | end
78 | end
79 |
80 | -- Find coreferent mentions
81 | for _,sample_line in pairs(lines_map) do
82 | local parts = split(sample_line, "\t")
83 | assert(doc_id == parts[1])
84 | local mention = parts[3]:lower()
85 |
86 | assert(mention_ent_cand[mention])
87 | assert(parts[table_len(parts) - 1] == 'GT:')
88 |
89 | -- Ground truth info
90 | local grd_trth_parts = split(parts[table_len(parts)], ',')
91 | local grd_trth_idx = tonumber(grd_trth_parts[1])
92 | assert(grd_trth_idx == -1 or table_len(grd_trth_parts) >= 4, sample_line)
93 | local grd_trth_entwikiid = -1
94 | if table_len(grd_trth_parts) >= 3 then
95 | grd_trth_entwikiid = tonumber(grd_trth_parts[2])
96 | end
97 | assert(grd_trth_entwikiid)
98 |
99 | -- Merge lists of entity candidates
100 | local added_list = {}
101 | local num_added_mentions = 0
102 | for m,_ in pairs(mention_ent_cand) do
103 | local stupid_lua_pattern = string.gsub(mention, '%.', '%%%.')
104 | stupid_lua_pattern = string.gsub(stupid_lua_pattern, '%-', '%%%-')
105 | if m ~= mention and (string.find(m, ' ' .. stupid_lua_pattern) or string.find(m, stupid_lua_pattern .. ' ')) and mention_refers_to_person(m, mention_ent_cand) then
106 |
107 | if banner == 'aida-B' then
108 | print(blue('coref mention = ' .. m .. ' replaces original mention = ' .. mention) ..
109 | ' ; DOC = ' .. doc_id)
110 | end
111 |
112 | num_added_mentions = num_added_mentions + 1
113 | for e_wikiid, p_e_m in pairs(mention_ent_cand[m]) do
114 | if not added_list[e_wikiid] then
115 | added_list[e_wikiid] = 0.0
116 | end
117 | added_list[e_wikiid] = added_list[e_wikiid] + p_e_m
118 | end
119 | end
120 | end
121 |
122 | -- Average:
123 | for e_wikiid, _ in pairs(added_list) do
124 | added_list[e_wikiid] = added_list[e_wikiid] / num_added_mentions
125 | end
126 |
127 | -- Merge the two lists
128 | local merged_list = mention_ent_cand[mention]
129 | if num_added_mentions > 0 then
130 | merged_list = added_list
131 | end
132 |
133 | local sorted_list = {}
134 | for ent_wikiid,p in pairs(merged_list) do
135 | table.insert(sorted_list, {ent_wikiid = ent_wikiid, p = p})
136 | end
137 | table.sort(sorted_list, function(a,b) return a.p > b.p end)
138 |
139 | -- Write the new line
140 | local str = parts[1] .. '\t' .. parts[2] .. '\t' .. parts[3] .. '\t' .. parts[4] .. '\t'
141 | .. parts[5] .. '\t' .. parts[6] .. '\t'
142 |
143 | if table_len(sorted_list) == 0 then
144 | str = str .. 'EMPTYCAND\tGT:\t-1'
145 | if grd_trth_entwikiid ~= unk_ent_wikiid then
146 | str = str .. ',' .. grd_trth_entwikiid .. ',' ..
147 | get_ent_name_from_wikiid(grd_trth_entwikiid)
148 | end
149 | else
150 | local candidates = {}
151 | local gt_pos = -1
152 | for pos,e in pairs(sorted_list) do
153 | if pos <= 100 then
154 | table.insert(candidates, e.ent_wikiid .. ',' ..
155 | string.format("%.3f", e.p) .. ',' .. get_ent_name_from_wikiid(e.ent_wikiid))
156 | if e.ent_wikiid == grd_trth_entwikiid then
157 | gt_pos = pos
158 | end
159 | else
160 | break
161 | end
162 | end
163 | str = str .. table.concat(candidates, '\t') .. '\tGT:\t'
164 |
165 | if gt_pos > 0 then
166 | str = str .. gt_pos .. ',' .. candidates[gt_pos]
167 | else
168 | str = str .. '-1'
169 | if grd_trth_entwikiid ~= unk_ent_wikiid then
170 | str = str .. ',' .. grd_trth_entwikiid .. ',' ..
171 | get_ent_name_from_wikiid(grd_trth_entwikiid)
172 | end
173 | end
174 | end
175 |
176 | coref_dataset_lines[doc_id][1 + #coref_dataset_lines[doc_id]] = str
177 | end
178 | end
179 |
180 | assert(#dataset_lines == #coref_dataset_lines)
181 | for doc_id, lines_map in pairs(dataset_lines) do
182 | assert(table_len(dataset_lines[doc_id]) == table_len(coref_dataset_lines[doc_id]))
183 | end
184 |
185 | return coref_dataset_lines
186 | end
187 | end
--------------------------------------------------------------------------------
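A toy sketch of the containment test used above: the short, ambiguous mention 'such' inherits the candidates of the longer person mention 'peter such' from the same document, because the latter strictly contains the former as a separate word:

```
-- Dots and dashes are escaped before the mention is used as a Lua pattern, as in the loop above.
local mention = 'such'
local longer  = 'peter such'
local pat = string.gsub(mention, '%.', '%%%.')
pat = string.gsub(pat, '%-', '%%%-')
local is_coreferent = (longer ~= mention) and
    ((string.find(longer, ' ' .. pat) or string.find(longer, pat .. ' ')) ~= nil)
print(is_coreferent)   -- true; the candidates of 'peter such' would then replace those of 'such'
```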
/ed/test/ent_freq_stats_test.lua:
--------------------------------------------------------------------------------
1 | -- Statistics of annotated entities based on their frequency in the Wikipedia corpus
2 | -- Table 6 (left) from our paper
3 | local function ent_freq_to_key(f)
4 | if f == 0 then
5 | return '0'
6 | elseif f == 1 then
7 | return '1'
8 | elseif f <= 5 then
9 | return '2-5'
10 | elseif f <= 10 then
11 | return '6-10'
12 | elseif f <= 20 then
13 | return '11-20'
14 | elseif f <= 50 then
15 | return '21-50'
16 | else
17 | return '50+'
18 | end
19 | end
20 |
21 |
22 | function new_ent_freq_map()
23 | local m = {}
24 | m['0'] = 0.0
25 | m['1'] = 0.0
26 | m['2-5'] = 0.0
27 | m['6-10'] = 0.0
28 | m['11-20'] = 0.0
29 | m['21-50'] = 0.0
30 | m['50+'] = 0.0
31 | return m
32 | end
33 |
34 | function add_freq_to_ent_freq_map(m, f)
35 | m[ent_freq_to_key(f)] = m[ent_freq_to_key(f)] + 1
36 | end
37 |
38 | function print_ent_freq_maps_stats(smallm, bigm)
39 | print(' ===> entity frequency stats :')
40 | for k,_ in pairs(smallm) do
41 | local perc = 0
42 | if bigm[k] > 0 then
43 | perc = 100.0 * smallm[k] / bigm[k]
44 | end
45 | assert(perc <= 100)
46 | print('freq = ' .. k .. ' : num = ' .. bigm[k] ..
47 | ' ; correctly classified = ' .. smallm[k] ..
48 | ' ; perc = ' .. string.format("%.2f", perc))
49 | end
50 | print('')
51 | end
52 |
53 |
54 |
--------------------------------------------------------------------------------
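A usage sketch of the frequency-bucket counters above, assuming this file has been loaded; the frequency value is made up:

```
-- Count one gold entity that appears 7 times in Wikipedia and was classified correctly.
local all_ents     = new_ent_freq_map()
local correct_ents = new_ent_freq_map()
add_freq_to_ent_freq_map(all_ents, 7)        -- falls in the '6-10' bucket
add_freq_to_ent_freq_map(correct_ents, 7)
print_ent_freq_maps_stats(correct_ents, all_ents)
```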
/ed/test/ent_p_e_m_stats_test.lua:
--------------------------------------------------------------------------------
1 | -- Statistics of annotated entities based on their p(e|m) prior
2 | -- Table 6 (right) from our paper
3 |
4 | local function ent_prior_to_key(f)
5 | if f <= 0.001 then
6 | return '<=0.001'
7 | elseif f <= 0.003 then
8 | return '0.001-0.003'
9 | elseif f <= 0.01 then
10 | return '0.003-0.01'
11 | elseif f <= 0.03 then
12 | return '0.01-0.03'
13 | elseif f <= 0.1 then
14 | return '0.03-0.1'
15 | elseif f <= 0.3 then
16 | return '0.1-0.3'
17 | else
18 | return '0.3+'
19 | end
20 | end
21 |
22 |
23 | function new_ent_prior_map()
24 | local m = {}
25 | m['<=0.001'] = 0.0
26 | m['0.001-0.003'] = 0.0
27 | m['0.003-0.01'] = 0.0
28 | m['0.01-0.03'] = 0.0
29 | m['0.03-0.1'] = 0.0
30 | m['0.1-0.3'] = 0.0
31 | m['0.3+'] = 0.0
32 | return m
33 | end
34 |
35 | function add_prior_to_ent_prior_map(m, f)
36 | m[ent_prior_to_key(f)] = m[ent_prior_to_key(f)] + 1
37 | end
38 |
39 | function print_ent_prior_maps_stats(smallm, bigm)
40 | print(' ===> entity p(e|m) stats :')
41 | for k,_ in pairs(smallm) do
42 | local perc = 0
43 | if bigm[k] > 0 then
44 | perc = 100.0 * smallm[k] / bigm[k]
45 | end
46 | assert(perc <= 100)
47 | print('p(e|m) = ' .. k .. ' : num = ' .. bigm[k] ..
48 | ' ; correctly classified = ' .. smallm[k] ..
49 | ' ; perc = ' .. string.format("%.2f", perc))
50 | end
51 | end
52 |
53 |
54 |
--------------------------------------------------------------------------------
/ed/test/test_one_loaded_model.lua:
--------------------------------------------------------------------------------
1 | -- Test one single ED model trained using ed/ed.lua
2 |
3 | -- Run: CUDA_VISIBLE_DEVICES=0 th ed/test/test_one_loaded_model.lua -root_data_dir $DATA_PATH -model global -ent_vecs_filename $ENTITY_VECS -test_one_model_file $ED_MODEL_FILENAME
4 | require 'optim'
5 | require 'torch'
6 | require 'gnuplot'
7 | require 'nn'
8 | require 'xlua'
9 | tds = tds or require 'tds'
10 |
11 | dofile 'utils/utils.lua'
12 | print('\n' .. green('==========> Test a pre-trained ED neural model <==========') .. '\n')
13 |
14 | dofile 'ed/args.lua'
15 |
16 | print('===> RUN TYPE: ' .. opt.type)
17 | torch.setdefaulttensortype('torch.FloatTensor')
18 | if string.find(opt.type, 'cuda') then
19 | print('==> switching to CUDA (GPU)')
20 | require 'cunn'
21 | require 'cutorch'
22 | require 'cudnn'
23 | cudnn.benchmark = true
24 | cudnn.fastest = true
25 | else
26 | print('==> running on CPU')
27 | end
28 |
29 | extract_args_from_model_title(opt.test_one_model_file)
30 |
31 | dofile 'utils/logger.lua'
32 | dofile 'entities/relatedness/relatedness.lua'
33 | dofile 'entities/ent_name2id_freq/ent_name_id.lua'
34 | dofile 'entities/ent_name2id_freq/e_freq_index.lua'
35 | dofile 'words/load_w_freq_and_vecs.lua'
36 | dofile 'words/w2v/w2v.lua'
37 | dofile 'entities/pretrained_e2v/e2v.lua'
38 | dofile 'ed/minibatch/build_minibatch.lua'
39 | dofile 'ed/models/model.lua'
40 | dofile 'ed/test/test.lua'
41 |
42 | local saved_linears = torch.load(opt.root_data_dir .. 'generated/ed_models/' .. opt.test_one_model_file)
43 | unpack_saveable_weights(saved_linears)
44 |
45 | test()
46 |
--------------------------------------------------------------------------------
/ed/train.lua:
--------------------------------------------------------------------------------
1 | -- Training of ED models.
2 |
3 | if opt.opt == 'SGD' then
4 | optimState = {
5 | learningRate = opt.lr,
6 | momentum = 0.9,
7 | learningRateDecay = 5e-7
8 | }
9 | optimMethod = optim.sgd
10 |
11 | elseif opt.opt == 'ADAM' then -- See: http://cs231n.github.io/neural-networks-3/#update
12 | optimState = {
13 | learningRate = opt.lr,
14 | }
15 | optimMethod = optim.adam
16 |
17 | elseif opt.opt == 'ADADELTA' then -- See: http://cs231n.github.io/neural-networks-3/#update
18 | -- Run with default parameters, no need for learning rate and other stuff
19 | optimState = {}
20 | optimConfig = {}
21 | optimMethod = optim.adadelta
22 |
23 | elseif opt.opt == 'ADAGRAD' then -- See: http://cs231n.github.io/neural-networks-3/#update
24 | optimMethod = optim.adagrad
25 | optimState = {
26 | learningRate = opt.lr
27 | }
28 |
29 | else
30 | error('unknown optimization method')
31 | end
32 | ----------------------------------------------------------------------
33 |
34 | -- Each batch is one document, so we test/validate/save the current model after each set of
35 | -- 5000 documents. Since aida-train contains 946 documents, this corresponds to roughly 5.3 full passes (5000/946) over aida-train.
36 | num_batches_per_epoch = 5000
37 |
38 | function train_and_test()
39 |
40 | print('\nDone testing for ' .. banner)
41 | print('Params serialized = ' .. params_serialized)
42 |
43 | -- epoch tracker
44 | epoch = 1
45 |
46 | local processed_so_far = 0
47 |
48 | local f_bs = 0
49 | local gradParameters_bs = nil
50 |
51 | while true do
52 | local time = sys.clock()
53 | print('\n')
54 | print('One epoch = ' .. (num_batches_per_epoch / 1000) .. ' full passes over AIDA-TRAIN in our case.')
55 | print(green('==> TRAINING EPOCH #' .. epoch .. ' <=='))
56 |
57 | print_net_weights()
58 |
59 | local processed_mentions = 0
60 | for batch_index = 1,num_batches_per_epoch do
61 | -- Read one mini-batch from one data_thread:
62 | local inputs, targets = get_minibatch()
63 |
64 | local num_mentions = targets:size(1)
65 | processed_mentions = processed_mentions + num_mentions
66 |
67 | local model, _ = get_model(num_mentions)
68 | model:training()
69 |
70 | -- Retrieve parameters and gradients:
71 | -- extracts and flattens all model's parameters into a 1-dim vector
72 | parameters,gradParameters = model:getParameters()
73 | gradParameters:zero()
74 |
75 | -- Just in case:
76 | collectgarbage()
77 | collectgarbage()
78 |
79 | -- Reset gradients
80 | gradParameters:zero()
81 |
82 | -- Evaluate function for complete mini batch
83 |
84 | local outputs = model:forward(inputs)
85 | assert(outputs:size(1) == num_mentions and outputs:size(2) == max_num_cand)
86 | local f = criterion:forward(outputs, targets)
87 |
88 | -- Estimate df/dW
89 | local df_do = criterion:backward(outputs, targets)
90 |
91 | model:backward(inputs, df_do)
92 |
93 | if opt.batch_size == 1 or batch_index % opt.batch_size == 1 then
94 | gradParameters_bs = gradParameters:clone():zero()
95 | f_bs = 0
96 | end
97 |
98 | gradParameters_bs:add(gradParameters)
99 | f_bs = f_bs + f
100 |
101 | if opt.batch_size == 1 or batch_index % opt.batch_size == 0 then
102 |
103 | gradParameters_bs:div(opt.batch_size)
104 | f_bs = f_bs / opt.batch_size
105 |
106 | -- Create closure to evaluate f(X) and df/dX
107 | local feval = function(x)
108 | return f_bs, gradParameters_bs
109 | end
110 |
111 | -- Optimize on current mini-batch
112 | optimState.learningRate = opt.lr
113 | optimMethod(feval, parameters, optimState)
114 |
115 | -- Regularize the f_network with projected SGD.
116 | regularize_f_network()
117 | end
118 |
119 | -- Display progress
120 | processed_so_far = processed_so_far + num_mentions
121 | if processed_so_far > 100000000 then
122 | processed_so_far = processed_so_far - 100000000
123 | end
124 | xlua.progress(processed_so_far, 100000000)
125 | end
126 |
127 | -- Measure time taken
128 | time = sys.clock() - time
129 | time = time / processed_mentions
130 |
131 | print("\n==> time to learn 1 sample = " .. (time*1000) .. 'ms')
132 |
133 | -- Test:
134 | test(epoch)
135 | print('\nDone testing for ' .. banner)
136 | print('Params serialized = ' .. params_serialized)
137 |
138 | -- Save the current model:
139 | if opt.save then
140 | local filename = opt.root_data_dir .. 'generated/ed_models/' .. params_serialized .. '|ep=' .. epoch
141 | print('==> saving model to '..filename)
142 | torch.save(filename, pack_saveable_weights())
143 | end
144 |
145 | -- Next epoch
146 | epoch = epoch + 1
147 | end
148 | end
149 |
--------------------------------------------------------------------------------
/entities/ent_name2id_freq/e_freq_gen.lua:
--------------------------------------------------------------------------------
1 | -- Creates a file that contains entity frequencies.
2 |
3 | if not opt then
4 | cmd = torch.CmdLine()
5 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.')
6 | cmd:text()
7 | opt = cmd:parse(arg or {})
8 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.')
9 | end
10 |
11 |
12 | require 'torch'
13 | tds = tds or require 'tds'
14 |
15 | dofile 'utils/utils.lua'
16 | dofile 'entities/ent_name2id_freq/ent_name_id.lua'
17 |
18 | entity_freqs = tds.Hash()
19 |
20 | local num_lines = 0
21 | it, _ = io.open(opt.root_data_dir .. 'generated/crosswikis_wikipedia_p_e_m.txt')
22 | line = it:read()
23 |
24 | while (line) do
25 | num_lines = num_lines + 1
26 | if num_lines % 2000000 == 0 then
27 | print('Processed ' .. num_lines .. ' lines. ')
28 | end
29 |
30 | local parts = split(line , '\t')
31 | local num_parts = table_len(parts)
32 | for i = 3, num_parts do
33 | local ent_str = split(parts[i], ',')
34 | local ent_wikiid = tonumber(ent_str[1])
35 | local freq = tonumber(ent_str[2])
36 | assert(ent_wikiid)
37 | assert(freq)
38 |
39 | if (not entity_freqs[ent_wikiid]) then
40 | entity_freqs[ent_wikiid] = 0
41 | end
42 | entity_freqs[ent_wikiid] = entity_freqs[ent_wikiid] + freq
43 | end
44 | line = it:read()
45 | end
46 |
47 |
48 | -- Writing word frequencies
49 | print('Sorting and writing')
50 | sorted_ent_freq = {}
51 | for ent_wikiid,freq in pairs(entity_freqs) do
52 | if freq >= 10 then
53 | table.insert(sorted_ent_freq, {ent_wikiid = ent_wikiid, freq = freq})
54 | end
55 | end
56 |
57 | table.sort(sorted_ent_freq, function(a,b) return a.freq > b.freq end)
58 |
59 | out_file = opt.root_data_dir .. 'generated/ent_wiki_freq.txt'
60 | ouf = assert(io.open(out_file, "w"))
61 | total_freq = 0
62 | for _,x in pairs(sorted_ent_freq) do
63 | ouf:write(x.ent_wikiid .. '\t' .. get_ent_name_from_wikiid(x.ent_wikiid) .. '\t' .. x.freq .. '\n')
64 | total_freq = total_freq + x.freq
65 | end
66 | ouf:flush()
67 | io.close(ouf)
68 |
69 | print('Total freq = ' .. total_freq .. '\n')
70 |
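
For orientation, this script only relies on the following structure of each crosswikis_wikipedia_p_e_m.txt line: a mention, a second field that the script skips, and then one `wikiid,count,...` field per candidate entity. A hedged sketch with a made-up line (ids, counts and names are illustrative only):

```
-- Made-up p(e|m) line: mention \t <field skipped here> \t wikiid,count,name \t ...
local line = 'Obama\t500\t111,450,Barack_Obama\t222,50,Obama_(surname)'
local parts = split(line, '\t')              -- split / table_len come from utils/utils.lua
for i = 3, table_len(parts) do
  local ent_str = split(parts[i], ',')
  -- wikiid and count are what gets accumulated into entity_freqs above:
  print(tonumber(ent_str[1]), tonumber(ent_str[2]))
end
-- The output file ent_wiki_freq.txt then contains one 'wikiid \t name \t total_freq'
-- line per entity with total frequency >= 10, sorted by decreasing frequency.
```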
--------------------------------------------------------------------------------
/entities/ent_name2id_freq/e_freq_index.lua:
--------------------------------------------------------------------------------
1 | -- Loads an index containing entity -> frequency pairs.
2 | -- TODO: rewrite this file in a simpler way (it is complicated because of some past experiments).
3 | tds = tds or require 'tds'
4 |
5 | print('==> Loading entity freq map')
6 |
7 | local ent_freq_file = opt.root_data_dir .. 'generated/ent_wiki_freq.txt'
8 |
9 | min_freq = 1
10 | e_freq = tds.Hash()
11 | e_freq.ent_f_start = tds.Hash()
12 | e_freq.ent_f_end = tds.Hash()
13 | e_freq.total_freq = 0
14 | e_freq.sorted = tds.Hash()
15 |
16 | cur_start = 1
17 | cnt = 0
18 | for line in io.lines(ent_freq_file) do
19 | local parts = split(line, '\t')
20 | local ent_wikiid = tonumber(parts[1])
21 | local ent_f = tonumber(parts[3])
22 | assert(ent_wikiid)
23 | assert(ent_f)
24 |
25 | if (not rewtr) or rewtr.reltd_ents_wikiid_to_rltdid[ent_wikiid] then
26 | e_freq.ent_f_start[ent_wikiid] = cur_start
27 | e_freq.ent_f_end[ent_wikiid] = cur_start + ent_f - 1
28 | e_freq.sorted[cnt] = ent_wikiid
29 | cur_start = cur_start + ent_f
30 | cnt = cnt + 1
31 | end
32 | end
33 |
34 | e_freq.total_freq = cur_start - 1
35 |
36 | print(' Done loading entity freq index. Size = ' .. cnt)
37 |
38 | function get_ent_freq(ent_wikiid)
39 | if e_freq.ent_f_start[ent_wikiid] then
40 | return e_freq.ent_f_end[ent_wikiid] - e_freq.ent_f_start[ent_wikiid] + 1
41 | end
42 | return 0
43 | end
44 |
45 |
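
The interval encoding above (ent_f_start / ent_f_end over one cumulative range) carries the same information as raw counts, but it also supports drawing entities proportionally to their frequency. The sketch below is an assumption about that intended use, not code from the repository (linear scan kept for clarity):

```
-- Hypothetical helper: draw an entity wikiid with probability proportional to
-- its Wikipedia frequency, using the cumulative intervals loaded above.
local function sample_ent_by_freq()
  local p = math.random(1, e_freq.total_freq)
  for _, ent_wikiid in pairs(e_freq.sorted) do
    if e_freq.ent_f_start[ent_wikiid] <= p and p <= e_freq.ent_f_end[ent_wikiid] then
      return ent_wikiid
    end
  end
end

-- get_ent_freq simply inverts the encoding: freq = ent_f_end - ent_f_start + 1.
```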
--------------------------------------------------------------------------------
/entities/ent_name2id_freq/ent_name_id.lua:
--------------------------------------------------------------------------------
1 | ------------------ Load entity name-id mappings ------------------
2 | -- Each entity has:
3 | -- a) a Wikipedia URL, referred to as 'name' here
4 | -- b) a Wikipedia ID, referred to as 'ent_wikiid' or 'wikiid' here
5 | -- c) an ID that will be used in the entity embeddings lookup table, referred to as 'ent_thid' or 'thid' here.
6 |
7 | tds = tds or require 'tds' -- saves lots of memory for ent_name_id.lua. Mem overflow with normal {}
8 | local rltd_only = false
9 | if opt and opt.entities and opt.entities ~= 'ALL' then
10 | assert(rewtr.reltd_ents_wikiid_to_rltdid, 'Import relatedness.lua before ent_name_id.lua')
11 | rltd_only = true
12 | end
13 |
14 | -- Unk entity wikid:
15 | unk_ent_wikiid = 1
16 |
17 | local entity_wiki_txtfilename = opt.root_data_dir .. 'basic_data/wiki_name_id_map.txt'
18 | local entity_wiki_t7filename = opt.root_data_dir .. 'generated/ent_name_id_map.t7'
19 | if rltd_only then
20 | entity_wiki_t7filename = opt.root_data_dir .. 'generated/ent_name_id_map_RLTD.t7'
21 | end
22 |
23 | print('==> Loading entity wikiid - name map')
24 |
25 | local e_id_name = nil
26 |
27 | if paths.filep(entity_wiki_t7filename) then
28 | print(' ---> from t7 file: ' .. entity_wiki_t7filename)
29 | e_id_name = torch.load(entity_wiki_t7filename)
30 |
31 | else
32 | print(' ---> t7 file NOT found. Loading from disk (slower). Out f = ' .. entity_wiki_t7filename)
33 | dofile 'data_gen/indexes/wiki_disambiguation_pages_index.lua'
34 | print(' Still loading entity wikiid - name map ...')
35 |
36 | e_id_name = tds.Hash()
37 |
38 | -- map for entity name to entity wiki id
39 | e_id_name.ent_wikiid2name = tds.Hash()
40 | e_id_name.ent_name2wikiid = tds.Hash()
41 |
42 | -- map for entity wiki id to tensor id. Size = 4.4M
43 | if not rltd_only then
44 | e_id_name.ent_wikiid2thid = tds.Hash()
45 | e_id_name.ent_thid2wikiid = tds.Hash()
46 | end
47 |
48 | local cnt = 0
49 | local cnt_freq = 0
50 | for line in io.lines(entity_wiki_txtfilename) do
51 | local parts = split(line, '\t')
52 | local ent_name = parts[1]
53 | local ent_wikiid = tonumber(parts[2])
54 |
55 | if (not wiki_disambiguation_index[ent_wikiid]) then
56 | if (not rltd_only) or rewtr.reltd_ents_wikiid_to_rltdid[ent_wikiid] then
57 | e_id_name.ent_wikiid2name[ent_wikiid] = ent_name
58 | e_id_name.ent_name2wikiid[ent_name] = ent_wikiid
59 | end
60 | if not rltd_only then
61 | cnt = cnt + 1
62 | e_id_name.ent_wikiid2thid[ent_wikiid] = cnt
63 | e_id_name.ent_thid2wikiid[cnt] = ent_wikiid
64 | end
65 | end
66 | end
67 |
68 | if not rltd_only then
69 | cnt = cnt + 1
70 | e_id_name.ent_wikiid2thid[unk_ent_wikiid] = cnt
71 | e_id_name.ent_thid2wikiid[cnt] = unk_ent_wikiid
72 | end
73 | e_id_name.ent_wikiid2name[unk_ent_wikiid] = 'UNK_ENT'
74 | e_id_name.ent_name2wikiid['UNK_ENT'] = unk_ent_wikiid
75 |
76 | torch.save(entity_wiki_t7filename, e_id_name)
77 | end
78 |
79 | if not rltd_only then
80 | unk_ent_thid = e_id_name.ent_wikiid2thid[unk_ent_wikiid]
81 | else
82 | unk_ent_thid = rewtr.reltd_ents_wikiid_to_rltdid[unk_ent_wikiid]
83 | end
84 |
85 | ------------------------ Functions for wikiids and names-----------------
86 | function get_map_all_valid_ents()
87 | local m = tds.Hash()
88 | for ent_wikiid, _ in pairs(e_id_name.ent_wikiid2name) do
89 | m[ent_wikiid] = 1
90 | end
91 | return m
92 | end
93 |
94 | is_valid_ent = function(ent_wikiid)
95 | if e_id_name.ent_wikiid2name[ent_wikiid] then
96 | return true
97 | end
98 | return false
99 | end
100 |
101 |
102 | get_ent_name_from_wikiid = function(ent_wikiid)
103 | local ent_name = e_id_name.ent_wikiid2name[ent_wikiid]
104 | if (not ent_wikiid) or (not ent_name) then
105 | return 'NIL'
106 | end
107 | return ent_name
108 | end
109 |
110 | preprocess_ent_name = function(ent_name)
111 | ent_name = trim1(ent_name)
112 |   ent_name = string.gsub(ent_name, '&amp;', '&')
113 |   ent_name = string.gsub(ent_name, '&quot;', '"')
114 | ent_name = ent_name:gsub('_', ' ')
115 | ent_name = first_letter_to_uppercase(ent_name)
116 | if get_redirected_ent_title then
117 | ent_name = get_redirected_ent_title(ent_name)
118 | end
119 | return ent_name
120 | end
121 |
122 | get_ent_wikiid_from_name = function(ent_name, not_verbose)
123 | local verbose = (not not_verbose)
124 | ent_name = preprocess_ent_name(ent_name)
125 | local ent_wikiid = e_id_name.ent_name2wikiid[ent_name]
126 | if (not ent_wikiid) or (not ent_name) then
127 | if verbose then
128 | print(red('Entity ' .. ent_name .. ' not found. Redirects file needs to be loaded for better performance.'))
129 | end
130 | return unk_ent_wikiid
131 | end
132 | return ent_wikiid
133 | end
134 |
135 | ------------------------ Functions for thids and wikiids -----------------
136 | -- ent wiki id -> thid
137 | get_thid = function (ent_wikiid)
138 | if rltd_only then
139 | ent_thid = rewtr.reltd_ents_wikiid_to_rltdid[ent_wikiid]
140 | else
141 | ent_thid = e_id_name.ent_wikiid2thid[ent_wikiid]
142 | end
143 | if (not ent_wikiid) or (not ent_thid) then
144 | return unk_ent_thid
145 | end
146 | return ent_thid
147 | end
148 |
149 | contains_thid = function (ent_wikiid)
150 | if rltd_only then
151 | ent_thid = rewtr.reltd_ents_wikiid_to_rltdid[ent_wikiid]
152 | else
153 | ent_thid = e_id_name.ent_wikiid2thid[ent_wikiid]
154 | end
155 | if ent_wikiid == nil or ent_thid == nil then
156 | return false
157 | end
158 | return true
159 | end
160 |
161 | get_total_num_ents = function()
162 | if rltd_only then
163 | assert(table_len(rewtr.reltd_ents_wikiid_to_rltdid) == rewtr.num_rltd_ents)
164 | return table_len(rewtr.reltd_ents_wikiid_to_rltdid)
165 | else
166 | return #e_id_name.ent_thid2wikiid
167 | end
168 | end
169 |
170 | get_wikiid_from_thid = function(ent_thid)
171 | if rltd_only then
172 | ent_wikiid = rewtr.reltd_ents_rltdid_to_wikiid[ent_thid]
173 | else
174 | ent_wikiid = e_id_name.ent_thid2wikiid[ent_thid]
175 | end
176 | if ent_wikiid == nil or ent_thid == nil then
177 | return unk_ent_wikiid
178 | end
179 | return ent_wikiid
180 | end
181 |
182 | -- tensor of ent wiki ids --> tensor of thids
183 | get_ent_thids = function (ent_wikiids_tensor)
184 | local ent_thid_tensor = ent_wikiids_tensor:clone()
185 | if ent_wikiids_tensor:dim() == 2 then
186 | for i = 1,ent_thid_tensor:size(1) do
187 | for j = 1,ent_thid_tensor:size(2) do
188 | ent_thid_tensor[i][j] = get_thid(ent_wikiids_tensor[i][j])
189 | end
190 | end
191 | elseif ent_wikiids_tensor:dim() == 1 then
192 | for i = 1,ent_thid_tensor:size(1) do
193 | ent_thid_tensor[i] = get_thid(ent_wikiids_tensor[i])
194 | end
195 | else
196 |     print('Tensor with > 2 dimensions not supported')
197 | os.exit()
198 | end
199 | return ent_thid_tensor
200 | end
201 |
202 | print(' Done loading entity name - wikiid. Size thid index = ' .. get_total_num_ents())
203 |
--------------------------------------------------------------------------------
/entities/learn_e2v/batch_dataset_a.lua:
--------------------------------------------------------------------------------
1 | dofile 'utils/utils.lua'
2 |
3 | if opt.entities == 'ALL' then
4 | wiki_words_train_file = opt.root_data_dir .. 'generated/wiki_canonical_words.txt'
5 | wiki_hyp_train_file = opt.root_data_dir .. 'generated/wiki_hyperlink_contexts.csv'
6 | else
7 | wiki_words_train_file = opt.root_data_dir .. 'generated/wiki_canonical_words_RLTD.txt'
8 | wiki_hyp_train_file = opt.root_data_dir .. 'generated/wiki_hyperlink_contexts_RLTD.csv'
9 | end
10 |
11 | wiki_words_it, _ = io.open(wiki_words_train_file)
12 | wiki_hyp_it, _ = io.open(wiki_hyp_train_file)
13 |
14 | assert(opt.num_passes_wiki_words)
15 | local train_data_source = 'wiki-canonical'
16 | local num_passes_wiki_words = 1
17 |
18 | local function read_one_line()
19 | if train_data_source == 'wiki-canonical' then
20 | line = wiki_words_it:read()
21 | else
22 | assert(train_data_source == 'wiki-canonical-hyperlinks')
23 | line = wiki_hyp_it:read()
24 | end
25 | if (not line) then
26 | if num_passes_wiki_words == opt.num_passes_wiki_words then
27 | train_data_source = 'wiki-canonical-hyperlinks'
28 | print('\n\n' .. 'Start training on Wiki Hyperlinks' .. '\n\n')
29 | end
30 | print('Training file is done. Num passes = ' .. num_passes_wiki_words .. '. Reopening.')
31 | num_passes_wiki_words = num_passes_wiki_words + 1
32 | if train_data_source == 'wiki-canonical' then
33 | wiki_words_it, _ = io.open(wiki_words_train_file)
34 | line = wiki_words_it:read()
35 | else
36 | wiki_hyp_it, _ = io.open(wiki_hyp_train_file)
37 | line = wiki_hyp_it:read()
38 | end
39 | end
40 | return line
41 | end
42 |
43 | local line = nil
44 |
45 | local function patch_of_lines(num)
46 | local lines = {}
47 | local cnt = 0
48 | assert(num > 0)
49 |
50 | while cnt < num do
51 | line = read_one_line()
52 | cnt = cnt + 1
53 | table.insert(lines, line)
54 | end
55 |
56 | assert(table_len(lines) == num)
57 | return lines
58 | end
59 |
60 |
61 | function get_minibatch()
62 | -- Create empty mini batch:
63 | local lines = patch_of_lines(opt.batch_size)
64 | local inputs = empty_minibatch()
65 | local targets = correct_type(torch.ones(opt.batch_size, opt.num_words_per_ent))
66 |
67 | -- Fill in each example:
68 | for i = 1,opt.batch_size do
69 | local sample_line = lines[i] -- load new example line
70 | local target = process_one_line(sample_line, inputs, i)
71 | targets[i]:copy(target)
72 | end
73 |
74 | --- Minibatch post processing:
75 | postprocess_minibatch(inputs, targets)
76 | targets = targets:view(opt.batch_size * opt.num_words_per_ent)
77 |
78 | -- Special target for the NEG and NCE losses
79 | if opt.loss == 'neg' or opt.loss == 'nce' then
80 | nce_targets = torch.ones(opt.batch_size * opt.num_words_per_ent, opt.num_neg_words):mul(-1)
81 | for j = 1,opt.batch_size * opt.num_words_per_ent do
82 | nce_targets[j][targets[j]] = 1
83 | end
84 | targets = nce_targets
85 | end
86 |
87 | return inputs, targets
88 | end
89 |
--------------------------------------------------------------------------------
/entities/learn_e2v/e2v_a.lua:
--------------------------------------------------------------------------------
1 | -- Entity embeddings utilities
2 |
3 | assert(opt.num_words_per_ent)
4 |
5 | -- Word lookup:
6 | geom_w2v_M = w2vutils.M:float()
7 |
8 | -- Stats:
9 | local num_invalid_ent_wikiids = 0
10 | local total_ent_wiki_vec_requests = 0
11 | local last_wrote = 0
12 | local function invalid_ent_wikiids_stats(ent_thid)
13 | total_ent_wiki_vec_requests = total_ent_wiki_vec_requests + 1
14 | if ent_thid == unk_ent_thid then
15 | num_invalid_ent_wikiids = num_invalid_ent_wikiids + 1
16 | end
17 | if (num_invalid_ent_wikiids % 15000 == 0 and num_invalid_ent_wikiids ~= last_wrote) then
18 | last_wrote = num_invalid_ent_wikiids
19 | local perc = 100.0 * num_invalid_ent_wikiids / total_ent_wiki_vec_requests
20 | print(red('*** Perc invalid ent wikiids = ' .. perc .. ' . Absolute num = ' .. num_invalid_ent_wikiids))
21 | end
22 | end
23 |
24 | -- ent id -> vec
25 | function geom_entwikiid2vec(ent_wikiid)
26 | local ent_thid = get_thid(ent_wikiid)
27 | assert(ent_thid)
28 | invalid_ent_wikiids_stats(ent_thid)
29 | local ent_vec = nn.Normalize(2):forward(lookup_ent_vecs.weight[ent_thid]:float())
30 | return ent_vec
31 | end
32 |
33 | -- ent name -> vec
34 | local function geom_entname2vec(ent_name)
35 | assert(ent_name)
36 | return geom_entwikiid2vec(get_ent_wikiid_from_name(ent_name))
37 | end
38 |
39 | function entity_similarity(e1_wikiid, e2_wikiid)
40 | local e1_vec = geom_entwikiid2vec(e1_wikiid)
41 | local e2_vec = geom_entwikiid2vec(e2_wikiid)
42 | return e1_vec * e2_vec
43 | end
44 |
45 |
46 | local function geom_top_k_closest_words(ent_name, ent_vec, k)
47 | local tf_map = ent_wiki_words_4EX[ent_name]
48 | local w_not_found = {}
49 | for w,_ in pairs(tf_map) do
50 | if tf_map[w] >= 10 then
51 | w_not_found[w] = tf_map[w]
52 | end
53 | end
54 |
55 | distances = geom_w2v_M * ent_vec
56 |
57 | local best_scores, best_word_ids = topk(distances, k)
58 | local returnwords = {}
59 | local returndistances = {}
60 | for i = 1,k do
61 | local w = get_word_from_id(best_word_ids[i])
62 | if is_stop_word_or_number(w) then
63 | table.insert(returnwords, red(w))
64 | elseif tf_map[w] then
65 | if tf_map[w] >= 15 then
66 | table.insert(returnwords, yellow(w .. '{' .. tf_map[w] .. '}'))
67 | else
68 | table.insert(returnwords, skyblue(w .. '{' .. tf_map[w] .. '}'))
69 | end
70 | w_not_found[w] = nil
71 | else
72 | table.insert(returnwords, w)
73 | end
74 | assert(best_scores[i] == distances[best_word_ids[i]], best_scores[i] .. ' ' .. distances[best_word_ids[i]])
75 | table.insert(returndistances, distances[best_word_ids[i]])
76 | end
77 | return returnwords, returndistances, w_not_found
78 | end
79 |
80 |
81 | local function geom_most_similar_words_to_ent(ent_name, k)
82 | local ent_wikiid = get_ent_wikiid_from_name(ent_name)
83 | local k = k or 1
84 | local ent_vec = geom_entname2vec(ent_name)
85 | assert(math.abs(1 - ent_vec:norm()) < 0.01 or ent_vec:norm() == 0, ':::: ' .. ent_vec:norm())
86 |
87 | print('\nTo entity: ' .. blue(ent_name) .. '; vec norm = ' .. ent_vec:norm() .. ':')
88 | neighbors, scores, w_not_found = geom_top_k_closest_words(ent_name, ent_vec, k)
89 | print(green('WORDS MODEL: ') .. list_with_scores_to_str(neighbors, scores))
90 |
91 | local str = yellow('WORDS NOT FOUND: ')
92 | for w,tf in pairs(w_not_found) do
93 | if tf >= 20 then
94 | str = str .. yellow(w .. '{' .. tf .. '}; ')
95 | else
96 | str = str .. w .. '{' .. tf .. '}; '
97 | end
98 | end
99 | print('\n' .. str)
100 | print('============================================================================')
101 | end
102 |
103 |
104 | -- Unit tests :
105 | function geom_unit_tests()
106 | print('\n' .. yellow('Words to Entity Similarity test:'))
107 | for i=1,table_len(ent_names_4EX) do
108 | geom_most_similar_words_to_ent(ent_names_4EX[i], 200)
109 | end
110 | end
111 |
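
Since geom_entwikiid2vec L2-normalizes every vector it returns, entity_similarity is a plain cosine similarity. A tiny usage sketch (the two wikiids are placeholders; unknown ids map to the zero unk vector and give similarity 0):

```
-- Cosine similarity in [-1, 1] between two entities (placeholder ids):
local sim = entity_similarity(1111, 2222)
print('cosine(e1, e2) = ' .. sim)
```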
--------------------------------------------------------------------------------
/entities/learn_e2v/learn_a.lua:
--------------------------------------------------------------------------------
1 | -- Training of entity embeddings.
2 |
3 | -- To run:
4 | -- i) delete all _RLTD files
5 | -- ii) th entities/relatedness/filter_wiki_canonical_words_RLTD.lua ; th entities/relatedness/filter_wiki_hyperlink_contexts_RLTD.lua
6 | -- iii) th entities/learn_e2v/learn_a.lua -root_data_dir /path/to/your/ed_data/files/
7 |
8 | -- Training of entity vectors
9 | require 'optim'
10 | require 'torch'
11 | require 'gnuplot'
12 | require 'nn'
13 | require 'xlua'
14 |
15 | dofile 'utils/utils.lua'
16 |
17 | cmd = torch.CmdLine()
18 | cmd:text()
19 | cmd:text('Learning entity vectors')
20 | cmd:text()
21 | cmd:text('Options:')
22 |
23 | ---------------- runtime options:
24 | cmd:option('-type', 'cudacudnn', 'Type: double | float | cuda | cudacudnn')
25 |
26 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.')
27 |
28 | cmd:option('-optimization', 'ADAGRAD', 'Optimization method: RMSPROP | ADAGRAD | ADAM | SGD')
29 |
30 | cmd:option('-lr', 0.3, 'Learning rate')
31 |
32 | cmd:option('-batch_size', 500, 'Mini-batch size (1 = pure stochastic)')
33 |
34 | cmd:option('-word_vecs', 'w2v', '300d word vectors type: glove | w2v')
35 |
36 | cmd:option('-num_words_per_ent', 20, 'Num positive words sampled for the given entity at ' ..
37 | 'each iteration.')
38 |
39 | cmd:option('-num_neg_words', 5, 'Num negative words sampled for each positive word.')
40 |
41 | cmd:option('-unig_power', 0.6, 'Negative sampling unigram power (0.75 used in Word2Vec).')
42 |
43 | cmd:option('-entities', 'RLTD',
44 | 'Set of entities for which we train embeddings: 4EX (tiny, for debug) | ' ..
45 | 'RLTD (restricted set) | ALL (all Wiki entities, too big to fit on a single GPU)')
46 |
47 | cmd:option('-init_vecs_title_words', true, 'whether the entity embeddings should be initialized with the average of ' ..
48 |            'title word embeddings. Helps to speed up convergence of entity embedding learning.')
49 |
50 | cmd:option('-loss', 'maxm', 'Loss function: nce (noise contrastive estimation) | ' ..
51 | 'neg (negative sampling) | is (importance sampling) | maxm (max-margin)')
52 |
53 | cmd:option('-data', 'wiki-canonical-hyperlinks', 'Training data: wiki-canonical (only) | ' ..
54 | 'wiki-canonical-hyperlinks')
55 |
56 | -- Only when opt.data = wiki-canonical-hyperlinks
57 | cmd:option('-num_passes_wiki_words', 200, 'Num passes (per entity) over Wiki canonical pages before ' ..
58 | 'changing to using Wiki hyperlinks.')
59 |
60 | cmd:option('-hyp_ctxt_len', 10, 'Left and right context window length for hyperlinks.')
61 |
62 | cmd:option('-banner_header', '', 'Banner header')
63 |
64 | cmd:text()
65 | opt = cmd:parse(arg or {})
66 |
67 | banner = '' .. opt.banner_header .. ';obj-' .. opt.loss .. ';' .. opt.data
68 | if opt.data ~= 'wiki-canonical' then
69 | banner = banner .. ';hypCtxtL-' .. opt.hyp_ctxt_len
70 | banner = banner .. ';numWWpass-' .. opt.num_passes_wiki_words
71 | end
72 | banner = banner .. ';WperE-' .. opt.num_words_per_ent
73 | banner = banner .. ';' .. opt.word_vecs .. ';negW-' .. opt.num_neg_words
74 | banner = banner .. ';ents-' .. opt.entities .. ';unigP-' .. opt.unig_power
75 | banner = banner .. ';bs-' .. opt.batch_size .. ';' .. opt.optimization .. '-lr-' .. opt.lr
76 |
77 | print('\n' .. blue('BANNER : ' .. banner))
78 |
79 | print('\n===> RUN TYPE: ' .. opt.type)
80 |
81 | torch.setdefaulttensortype('torch.FloatTensor')
82 | if string.find(opt.type, 'cuda') then
83 | print('==> switching to CUDA (GPU)')
84 | require 'cunn'
85 | require 'cutorch'
86 | require 'cudnn'
87 | cudnn.benchmark = true
88 | cudnn.fastest = true
89 | else
90 | print('==> running on CPU')
91 | end
92 |
93 | dofile 'utils/logger.lua'
94 | dofile 'entities/relatedness/relatedness.lua'
95 | dofile 'entities/ent_name2id_freq/ent_name_id.lua'
96 | dofile 'words/load_w_freq_and_vecs.lua'
97 | dofile 'words/w2v/w2v.lua'
98 | dofile 'entities/learn_e2v/minibatch_a.lua'
99 | dofile 'entities/learn_e2v/model_a.lua'
100 | dofile 'entities/learn_e2v/e2v_a.lua'
101 | dofile 'entities/learn_e2v/batch_dataset_a.lua'
102 |
103 | if opt.loss == 'neg' or opt.loss == 'nce' then
104 | criterion = nn.SoftMarginCriterion()
105 | elseif opt.loss == 'maxm' then
106 | criterion = nn.MultiMarginCriterion(1, torch.ones(opt.num_neg_words), 0.1)
107 | elseif opt.loss == 'is' then
108 | criterion = nn.CrossEntropyCriterion()
109 | end
110 |
111 | if string.find(opt.type, 'cuda') then
112 | criterion = criterion:cuda()
113 | end
114 |
115 |
116 | ----------------------------------------------------------------------
117 | if opt.optimization == 'ADAGRAD' then -- See: http://cs231n.github.io/neural-networks-3/#update
118 | dofile 'utils/optim/adagrad_mem.lua'
119 | optimMethod = adagrad_mem
120 | optimState = {
121 | learningRate = opt.lr
122 | }
123 | elseif opt.optimization == 'RMSPROP' then -- See: cs231n.github.io/neural-networks-3/#update
124 | dofile 'utils/optim/rmsprop_mem.lua'
125 | optimMethod = rmsprop_mem
126 | optimState = {
127 | learningRate = opt.lr
128 | }
129 | elseif opt.optimization == 'SGD' then -- See: http://cs231n.github.io/neural-networks-3/#update
130 | optimState = {
131 | learningRate = opt.lr,
132 | learningRateDecay = 5e-7
133 | }
134 | optimMethod = optim.sgd
135 | elseif opt.optimization == 'ADAM' then -- See: http://cs231n.github.io/neural-networks-3/#update
136 | optimState = {
137 | learningRate = opt.lr,
138 | }
139 | optimMethod = optim.adam
140 | else
141 | error('unknown optimization method')
142 | end
143 |
144 |
145 | ----------------------------------------------------------------------
146 | function train_ent_vecs()
147 | print('Training entity vectors w/ params: ' .. banner)
148 |
149 | -- Retrieve parameters and gradients:
150 | -- extracts and flattens all model's parameters into a 1-dim vector
151 | parameters,gradParameters = model:getParameters()
152 | gradParameters:zero()
153 |
154 | local processed_so_far = 0
155 | if opt.entities == 'ALL' then
156 | num_batches_per_epoch = 4000
157 | elseif opt.entities == 'RLTD' then
158 | num_batches_per_epoch = 2000
159 | elseif opt.entities == '4EX' then
160 | num_batches_per_epoch = 400
161 | end
162 | local test_every_num_epochs = 1
163 | local save_every_num_epochs = 3
164 |
165 | -- epoch tracker
166 | epoch = 1
167 |
168 | -- Initial testing:
169 | geom_unit_tests() -- Show some examples
170 |
171 | print('Training params: ' .. banner)
172 |
173 | -- do one epoch
174 | print('\n==> doing epoch on training data:')
175 | print("==> online epoch # " .. epoch .. ' [batch size = ' .. opt.batch_size .. ']')
176 |
177 | while true do
178 |
179 | local time = sys.clock()
180 | print(green('\n===> TRAINING EPOCH #' .. epoch .. '; num batches ' .. num_batches_per_epoch .. ' <==='))
181 |
182 | local avg_loss_before_opt_per_epoch = 0.0
183 | local avg_loss_after_opt_per_epoch = 0.0
184 |
185 | for batch_index = 1,num_batches_per_epoch do
186 | -- Read one mini-batch from one data_thread:
187 | inputs, targets = get_minibatch()
188 |
189 | -- Move data to GPU:
190 | minibatch_to_correct_type(inputs)
191 | targets = correct_type(targets)
192 |
193 | -- create closure to evaluate f(X) and df/dX
194 | local feval = function(x)
195 | -- get new parameters
196 | if x ~= parameters then
197 | parameters:copy(x)
198 | end
199 |
200 | -- reset gradients
201 | gradParameters:zero()
202 |
203 | -- evaluate function for complete mini batch
204 | local outputs = model:forward(inputs)
205 | assert(outputs:size(1) == opt.batch_size * opt.num_words_per_ent and
206 | outputs:size(2) == opt.num_neg_words)
207 |
208 | local f = criterion:forward(outputs, targets)
209 |
210 | -- estimate df/dW
211 | local df_do = criterion:backward(outputs, targets)
212 | local gradInput = model:backward(inputs, df_do)
213 |
214 | -- return f and df/dX
215 | return f,gradParameters
216 | end
217 |
218 | -- Debug info:
219 | local loss_before_opt = criterion:forward(model:forward(inputs), targets)
220 | avg_loss_before_opt_per_epoch = avg_loss_before_opt_per_epoch + loss_before_opt
221 |
222 | -- Optimize on current mini-batch
223 | optimMethod(feval, parameters, optimState)
224 |
225 | local loss_after_opt = criterion:forward(model:forward(inputs), targets)
226 | avg_loss_after_opt_per_epoch = avg_loss_after_opt_per_epoch + loss_after_opt
227 | if loss_after_opt > loss_before_opt then
228 | print(red('!!!!!! LOSS INCREASED: ' .. loss_before_opt .. ' --> ' .. loss_after_opt))
229 | end
230 |
231 | -- Display progress
232 | train_size = 17000000 ---------- 4 passes over the Wiki entity set
233 | processed_so_far = processed_so_far + opt.batch_size
234 | if processed_so_far > train_size then
235 | processed_so_far = processed_so_far - train_size
236 | end
237 | xlua.progress(processed_so_far, train_size)
238 | end
239 |
240 | avg_loss_before_opt_per_epoch = avg_loss_before_opt_per_epoch / num_batches_per_epoch
241 | avg_loss_after_opt_per_epoch = avg_loss_after_opt_per_epoch / num_batches_per_epoch
242 | print(yellow('\nAvg loss before opt = ' .. avg_loss_before_opt_per_epoch ..
243 | '; Avg loss after opt = ' .. avg_loss_after_opt_per_epoch))
244 |
245 | -- time taken
246 | time = sys.clock() - time
247 | time = time / (num_batches_per_epoch * opt.batch_size)
248 | print("==> time to learn 1 full entity = " .. (time*1000) .. 'ms')
249 |
250 | geom_unit_tests() -- Show some entity examples
251 |
252 | -- Various testing measures:
253 | if (epoch % test_every_num_epochs == 0) then
254 | if opt.entities ~= '4EX' then
255 | compute_relatedness_metrics(entity_similarity)
256 | end
257 | end
258 |
259 | -- Save model:
260 | if (epoch % save_every_num_epochs == 0) then
261 | print('==> saving model to ' .. opt.root_data_dir .. 'generated/ent_vecs/ent_vecs__ep_' .. epoch .. '.t7')
262 | torch.save(opt.root_data_dir .. 'generated/ent_vecs/ent_vecs__ep_' .. epoch .. '.t7', nn.Normalize(2):forward(lookup_ent_vecs.weight:float()))
263 | end
264 |
265 | print('Training params: ' .. banner)
266 |
267 | -- next epoch
268 | epoch = epoch + 1
269 | end
270 | end
271 |
272 | train_ent_vecs()
273 |
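
For the default maxm objective wired up above, each training example is a row of opt.num_neg_words cosine scores with the positive word planted at the target index (see batch_dataset_a.lua), and nn.MultiMarginCriterion pushes the positive score above every negative score by the 0.1 margin. A small self-contained illustration with made-up scores:

```
require 'nn'

-- One row of made-up cosine scores over 5 candidate slots (the default
-- opt.num_neg_words); the positive word sits at index 1.
local crit = nn.MultiMarginCriterion(1, torch.ones(5), 0.1)
local scores = torch.Tensor({0.9, 0.2, 0.1, 0.0, -0.3})
print(crit:forward(scores, 1))   -- 0: the positive beats every negative by > 0.1
print(crit:forward(scores, 4))   -- > 0: slot 4 violates the margin against the others
```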
--------------------------------------------------------------------------------
/entities/learn_e2v/minibatch_a.lua:
--------------------------------------------------------------------------------
1 | assert(opt.entities == '4EX' or opt.entities == 'ALL' or opt.entities == 'RLTD', opt.entities)
2 |
3 | function empty_minibatch()
4 | local ctxt_word_ids = torch.ones(opt.batch_size, opt.num_words_per_ent, opt.num_neg_words):mul(unk_w_id)
5 | local ent_component_words = torch.ones(opt.batch_size, opt.num_words_per_ent):int()
6 | local ent_wikiids = torch.ones(opt.batch_size):int()
7 | local ent_thids = torch.ones(opt.batch_size):int()
8 | return {{ctxt_word_ids}, {ent_component_words}, {ent_thids, ent_wikiids}}
9 | end
10 |
11 | -- Get functions:
12 | function get_pos_and_neg_w_ids(minibatch)
13 | return minibatch[1][1]
14 | end
15 | function get_pos_and_neg_w_vecs(minibatch)
16 | return minibatch[1][2]
17 | end
18 | function get_pos_and_neg_w_unig_at_power(minibatch)
19 | return minibatch[1][3]
20 | end
21 | function get_ent_wiki_w_ids(minibatch)
22 | return minibatch[2][1]
23 | end
24 | function get_ent_wiki_w_vecs(minibatch)
25 | return minibatch[2][2]
26 | end
27 | function get_ent_thids_batch(minibatch)
28 | return minibatch[3][1]
29 | end
30 | function get_ent_wikiids(minibatch)
31 | return minibatch[3][2]
32 | end
33 |
34 |
35 | -- Fills in the minibatch and returns the grd truth word index for each example.
36 | -- An example in our case is an entity, a positive word sampled from \hat{p}(w|e)
37 | -- and several negative words sampled from \hat{p}(w)^\alpha.
38 | function process_one_line(line, minibatch, mb_index)
39 | if opt.entities == '4EX' then
40 | line = ent_lines_4EX[ent_names_4EX[math.random(1, table_len(ent_names_4EX))]]
41 | end
42 |
43 | local parts = split(line, '\t')
44 | local num_parts = table_len(parts)
45 |
46 | if num_parts == 3 then ---------> Words from the Wikipedia canonical page
47 | assert(table_len(parts) == 3, line)
48 | ent_wikiid = tonumber(parts[1])
49 | words_plus_stop_words = split(parts[3], ' ')
50 |
51 | else --------> Words from Wikipedia hyperlinks
52 | assert(num_parts >= 9, line .. ' --> ' .. num_parts)
53 | assert(parts[6] == 'CANDIDATES', line)
54 |
55 | local last_part = parts[num_parts]
56 | local ent_str = split(last_part, ',')
57 | ent_wikiid = tonumber(ent_str[2])
58 |
59 | words_plus_stop_words = {}
60 | local left_ctxt_w = split(parts[4], ' ')
61 | local left_ctxt_w_num = table_len(left_ctxt_w)
62 | for i = math.max(1, left_ctxt_w_num - opt.hyp_ctxt_len + 1), left_ctxt_w_num do
63 | table.insert(words_plus_stop_words, left_ctxt_w[i])
64 | end
65 | local right_ctxt_w = split(parts[5], ' ')
66 | local right_ctxt_w_num = table_len(right_ctxt_w)
67 | for i = 1, math.min(right_ctxt_w_num, opt.hyp_ctxt_len) do
68 | table.insert(words_plus_stop_words, right_ctxt_w[i])
69 | end
70 | end
71 |
72 | assert(ent_wikiid)
73 | local ent_thid = get_thid(ent_wikiid)
74 | assert(get_wikiid_from_thid(ent_thid) == ent_wikiid)
75 | get_ent_thids_batch(minibatch)[mb_index] = ent_thid
76 | assert(get_ent_thids_batch(minibatch)[mb_index] == ent_thid)
77 |
78 | get_ent_wikiids(minibatch)[mb_index] = ent_wikiid
79 |
80 |
81 | -- Remove stop words from entity wiki words representations.
82 | local positive_words_in_this_iter = {}
83 | local num_positive_words_this_iter = 0
84 | for _,w in pairs(words_plus_stop_words) do
85 | if contains_w(w) then
86 | table.insert(positive_words_in_this_iter, w)
87 | num_positive_words_this_iter = num_positive_words_this_iter + 1
88 | end
89 | end
90 |
91 | -- Try getting some words from the entity title if the canonical page is empty.
92 | if num_positive_words_this_iter == 0 then
93 | local ent_name = parts[2]
94 | words_plus_stop_words = split_in_words(ent_name)
95 | for _,w in pairs(words_plus_stop_words) do
96 | if contains_w(w) then
97 | table.insert(positive_words_in_this_iter, w)
98 | num_positive_words_this_iter = num_positive_words_this_iter + 1
99 | end
100 | end
101 |
102 | -- Still empty ? Get some random words then.
103 | if num_positive_words_this_iter == 0 then
104 | table.insert(positive_words_in_this_iter, get_word_from_id(random_unigram_at_unig_power_w_id()))
105 | end
106 | end
107 |
108 | local targets = torch.zeros(opt.num_words_per_ent):int()
109 |
110 | -- Sample some negative words:
111 | get_pos_and_neg_w_ids(minibatch)[mb_index]:apply(
112 | function(x)
113 |       -- Random negative words sampled from \hat{p}(w)^\alpha.
114 | return random_unigram_at_unig_power_w_id()
115 | end
116 | )
117 |
118 | -- Sample some positive words:
119 | for i = 1,opt.num_words_per_ent do
120 | local positive_w = positive_words_in_this_iter[math.random(1, num_positive_words_this_iter)]
121 | local positive_w_id = get_id_from_word(positive_w)
122 |
123 | -- Set the positive word in a random position. Remember that index (used in training).
124 | local grd_trth = math.random(1, opt.num_neg_words)
125 | get_ent_wiki_w_ids(minibatch)[mb_index][i] = positive_w_id
126 | assert(get_ent_wiki_w_ids(minibatch)[mb_index][i] == positive_w_id)
127 | targets[i] = grd_trth
128 | get_pos_and_neg_w_ids(minibatch)[mb_index][i][grd_trth] = positive_w_id
129 | end
130 |
131 | return targets
132 | end
133 |
134 |
135 | -- Fill minibatch with word and entity vectors:
136 | function postprocess_minibatch(minibatch, targets)
137 |
138 | minibatch[1][1] = get_pos_and_neg_w_ids(minibatch):view(opt.batch_size * opt.num_words_per_ent * opt.num_neg_words)
139 | minibatch[2][1] = get_ent_wiki_w_ids(minibatch):view(opt.batch_size * opt.num_words_per_ent)
140 |
141 | -- ctxt word vecs
142 | minibatch[1][2] = w2vutils:lookup_w_vecs(get_pos_and_neg_w_ids(minibatch))
143 |
144 | minibatch[1][3] = torch.zeros(opt.batch_size * opt.num_words_per_ent * opt.num_neg_words)
145 | minibatch[1][3]:map(minibatch[1][1], function(_,w_id) return get_w_unnorm_unigram_at_power(w_id) end)
146 | end
147 |
148 |
149 | -- Convert mini batch to correct type (e.g. move data to GPU):
150 | function minibatch_to_correct_type(minibatch)
151 | minibatch[1][1] = correct_type(minibatch[1][1])
152 | minibatch[2][1] = correct_type(minibatch[2][1])
153 | minibatch[1][2] = correct_type(minibatch[1][2])
154 | minibatch[1][3] = correct_type(minibatch[1][3])
155 | minibatch[3][1] = correct_type(minibatch[3][1])
156 | end
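
To summarize the structure built by the functions above (with B = opt.batch_size, W = opt.num_words_per_ent, N = opt.num_neg_words), a post-processed minibatch looks as follows; the 300 assumes the 300-dimensional Word2Vec/GloVe word vectors used in this repo:

```
local mb = empty_minibatch()
-- After process_one_line (per example) and postprocess_minibatch:
--   mb[1][1] : B*W*N flat word ids (negative samples, one positive planted per group of N)
--   mb[1][2] : (B*W*N) x 300 word vectors looked up from w2vutils
--   mb[1][3] : B*W*N unigram probabilities \hat{p}(w)^\alpha (used by the nce / is losses)
--   mb[2][1] : B*W positive word ids
--   mb[3][1] : B entity thids (rows of lookup_ent_vecs)
--   mb[3][2] : B entity wikiids
```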
--------------------------------------------------------------------------------
/entities/learn_e2v/model_a.lua:
--------------------------------------------------------------------------------
1 | -- Definition of the neural network used to learn entity embeddings.
2 | -- To run a simple unit test that checks the forward and backward passes, just run :
3 | -- th entities/learn_e2v/model_a.lua
4 |
5 | if not opt then -- unit tests
6 | unit_tests = true
7 | dofile 'utils/utils.lua'
8 | require 'nn'
9 | cmd = torch.CmdLine()
10 | cmd:option('-type', 'double', 'type: double | float | cuda | cudacudnn')
11 | cmd:option('-batch_size', 7, 'mini-batch size (1 = pure stochastic)')
12 | cmd:option('-num_words_per_ent', 100, 'num positive words per entity per iteration.')
13 | cmd:option('-num_neg_words', 25, 'num negative words in the partition function.')
14 | cmd:option('-loss', 'nce', 'nce | neg | is | maxm')
15 |   cmd:option('-init_vecs_title_words', false, 'whether the entity embeddings should be initialized with the average of title word embeddings. Helps to speed up convergence of entity embedding learning.')
16 | opt = cmd:parse(arg or {})
17 | word_vecs_size = 5
18 | ent_vecs_size = word_vecs_size
19 | lookup_ent_vecs = nn.LookupTable(100, ent_vecs_size)
20 | end -- end unit tests
21 |
22 |
23 | if not unit_tests then
24 | ent_vecs_size = word_vecs_size
25 |
26 | -- Init ents vectors
27 | print('\n==> Init entity embeddings matrix. Num ents = ' .. get_total_num_ents())
28 | lookup_ent_vecs = nn.LookupTable(get_total_num_ents(), ent_vecs_size)
29 |
30 | -- Zero out unk_ent_thid vector for unknown entities.
31 | lookup_ent_vecs.weight[unk_ent_thid]:copy(torch.zeros(ent_vecs_size))
32 |
33 | -- Init entity vectors with average of title word embeddings.
34 |   -- This helps speed up training.
35 | if opt.init_vecs_title_words then
36 | print('Init entity embeddings with average of title word vectors to speed up learning.')
37 | for ent_thid = 1,get_total_num_ents() do
38 | local init_ent_vec = torch.zeros(ent_vecs_size)
39 | local ent_name = get_ent_name_from_wikiid(get_wikiid_from_thid(ent_thid))
40 | words_plus_stop_words = split_in_words(ent_name)
41 | local num_words_title = 0
42 | for _,w in pairs(words_plus_stop_words) do
43 | if contains_w(w) then -- Remove stop words.
44 | init_ent_vec:add(w2vutils.M[get_id_from_word(w)]:float())
45 | num_words_title = num_words_title + 1
46 | end
47 | end
48 |
49 | if num_words_title > 0 then
50 | if num_words_title > 3 then
51 | assert(init_ent_vec:norm() > 0, ent_name)
52 | end
53 | init_ent_vec:div(num_words_title)
54 | end
55 |
56 | if init_ent_vec:norm() > 0 then
57 | lookup_ent_vecs.weight[ent_thid]:copy(init_ent_vec)
58 | end
59 | end
60 | end
61 |
62 | collectgarbage(); collectgarbage();
63 | print(' Done init.')
64 | end
65 |
66 | ---------------- Model Definition --------------------------------
67 | cosine_words_ents = nn.Sequential()
68 | :add(nn.ConcatTable()
69 | :add(nn.Sequential()
70 | :add(nn.SelectTable(1))
71 | :add(nn.SelectTable(2)) -- ctxt words vectors
72 | :add(nn.Normalize(2))
73 | :add(nn.View(opt.batch_size, opt.num_words_per_ent * opt.num_neg_words, ent_vecs_size)))
74 | :add(nn.Sequential()
75 | :add(nn.SelectTable(3))
76 | :add(nn.SelectTable(1))
77 | :add(lookup_ent_vecs) -- entity vectors
78 | :add(nn.Normalize(2))
79 | :add(nn.View(opt.batch_size, 1, ent_vecs_size))))
80 | :add(nn.MM(false, true))
81 | :add(nn.View(opt.batch_size * opt.num_words_per_ent, opt.num_neg_words))
82 |
83 | model = nn.Sequential()
84 | :add(cosine_words_ents)
85 | :add(nn.View(opt.batch_size * opt.num_words_per_ent, opt.num_neg_words))
86 |
87 |
88 | if opt.loss == 'is' then
89 | model = nn.Sequential()
90 | :add(nn.ConcatTable()
91 | :add(model)
92 | :add(nn.Sequential()
93 | :add(nn.SelectTable(1))
94 | :add(nn.SelectTable(3)) -- unigram distributions at power
95 | :add(nn.Log())
96 | :add(nn.View(opt.batch_size * opt.num_words_per_ent, opt.num_neg_words))))
97 | :add(nn.CSubTable())
98 |
99 | elseif opt.loss == 'nce' then
100 | model = nn.Sequential()
101 | :add(nn.ConcatTable()
102 | :add(model)
103 | :add(nn.Sequential()
104 | :add(nn.SelectTable(1))
105 | :add(nn.SelectTable(3)) -- unigram distributions at power
106 | :add(nn.MulConstant(opt.num_neg_words - 1))
107 | :add(nn.Log())
108 | :add(nn.View(opt.batch_size * opt.num_words_per_ent, opt.num_neg_words))))
109 | :add(nn.CSubTable())
110 | end
111 |
112 | ---------------------------------------------------------------------------------------------
113 |
114 | ------- Cuda conversions:
115 | if string.find(opt.type, 'cuda') then
116 | model = model:cuda() --- This has to be called always before cudnn.convert
117 | end
118 |
119 | if string.find(opt.type, 'cudacudnn') then
120 | cudnn.convert(model, cudnn)
121 | end
122 |
123 |
124 | --- Unit tests
125 | if unit_tests then
126 | print('Network model unit tests:')
127 | local inputs = {}
128 |
129 | inputs[1] = {}
130 | inputs[1][1] = correct_type(torch.ones(opt.batch_size * opt.num_words_per_ent * opt.num_neg_words)) -- ctxt words
131 |
132 | inputs[2] = {}
133 | inputs[2][1] = correct_type(torch.ones(opt.batch_size * opt.num_words_per_ent)) -- ent wiki words
134 |
135 | inputs[3] = {}
136 | inputs[3][1] = correct_type(torch.ones(opt.batch_size)) -- ent th ids
137 | inputs[3][2] = torch.ones(opt.batch_size) -- ent wikiids
138 |
139 | -- ctxt word vecs
140 | inputs[1][2] = correct_type(torch.ones(opt.batch_size * opt.num_words_per_ent * opt.num_neg_words, word_vecs_size))
141 |
142 | inputs[1][3] = correct_type(torch.randn(opt.batch_size * opt.num_words_per_ent * opt.num_neg_words))
143 |
144 | local outputs = model:forward(inputs)
145 |
146 | assert(outputs:size(1) == opt.batch_size * opt.num_words_per_ent and
147 | outputs:size(2) == opt.num_neg_words)
148 | print('FWD success!')
149 |
150 | model:backward(inputs, correct_type(torch.randn(opt.batch_size * opt.num_words_per_ent, opt.num_neg_words)))
151 | print('BKWD success!')
152 | end
153 |
--------------------------------------------------------------------------------
/entities/pretrained_e2v/check_ents.lua:
--------------------------------------------------------------------------------
1 | if not opt then
2 | cmd = torch.CmdLine()
3 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.')
4 | cmd:option('-ent_vecs_filename', 'ent_vecs__ep_228.t7', 'File name containing entity vectors generated with entities/learn_e2v/learn_a.lua.')
5 | cmd:text()
6 | opt = cmd:parse(arg or {})
7 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.')
8 | end
9 |
10 |
11 | require 'optim'
12 | require 'torch'
13 | require 'gnuplot'
14 | require 'nn'
15 | require 'xlua'
16 |
17 | dofile 'utils/utils.lua'
18 | dofile 'ed/args.lua'
19 |
20 | tds = require 'tds'
21 |
22 |
23 | print('===> RUN TYPE: ' .. opt.type)
24 | torch.setdefaulttensortype('torch.FloatTensor')
25 | if string.find(opt.type, 'cuda') then
26 | print('==> switching to CUDA (GPU)')
27 | require 'cunn'
28 | require 'cutorch'
29 | require 'cudnn'
30 | cudnn.benchmark = true
31 | cudnn.fastest = true
32 | else
33 | print('==> running on CPU')
34 | end
35 |
36 | dofile 'utils/logger.lua'
37 | dofile 'entities/relatedness/relatedness.lua'
38 | dofile 'entities/ent_name2id_freq/ent_name_id.lua'
39 | dofile 'entities/ent_name2id_freq/e_freq_index.lua'
40 | dofile 'words/load_w_freq_and_vecs.lua'
41 | dofile 'entities/pretrained_e2v/e2v.lua'
--------------------------------------------------------------------------------
/entities/pretrained_e2v/e2v.lua:
--------------------------------------------------------------------------------
1 | -- Loads pre-trained entity vectors trained using the file entities/learn_e2v/learn_a.lua
2 |
3 | assert(opt.ent_vecs_filename)
4 | print('==> Loading pre-trained entity vectors: e2v from file ' .. opt.ent_vecs_filename)
5 |
6 | assert(opt.entities == 'RLTD', 'Only RLTD entities are currently supported. ALL entities would blow the GPU memory.')
7 |
8 | -- Defining variables:
9 | ent_vecs_size = 300
10 |
11 | geom_w2v_M = w2vutils.M:float()
12 |
13 | e2vutils = {}
14 |
15 | -- Lookup table: ids -> tensor of vecs
16 | e2vutils.lookup = torch.load(opt.root_data_dir .. 'generated/ent_vecs/' .. opt.ent_vecs_filename)
17 | e2vutils.lookup = nn.Normalize(2):forward(e2vutils.lookup) -- Needs to be normalized to have norm 1.
18 |
19 | assert(e2vutils.lookup:size(1) == get_total_num_ents() and
20 | e2vutils.lookup:size(2) == ent_vecs_size, e2vutils.lookup:size(1) .. ' ' .. get_total_num_ents())
21 | assert(e2vutils.lookup[unk_ent_thid]:norm() == 0, e2vutils.lookup[unk_ent_thid]:norm())
22 |
23 | -- ent wikiid -> vec
24 | e2vutils.entwikiid2vec = function(self, ent_wikiid)
25 | local thid = get_thid(ent_wikiid)
26 | return self.lookup[thid]:float()
27 | end
28 | assert(torch.norm(e2vutils:entwikiid2vec(unk_ent_wikiid)) == 0)
29 |
30 |
31 | e2vutils.entname2vec = function (self,ent_name)
32 | assert(ent_name)
33 | return e2vutils:entwikiid2vec(get_ent_wikiid_from_name(ent_name))
34 | end
35 |
36 | -- Entity similarity based on cosine distance (note that entity vectors are normalized).
37 | function entity_similarity(e1_wikiid, e2_wikiid)
38 | local e1_vec = e2vutils:entwikiid2vec(e1_wikiid)
39 | local e2_vec = e2vutils:entwikiid2vec(e2_wikiid)
40 | return e1_vec * e2_vec
41 | end
42 |
43 |
44 | -----------------------------------------------------------------------
45 | ---- Some unit tests to understand the quality of these embeddings ----
46 | -----------------------------------------------------------------------
47 | local function geom_top_k_closest_words(ent_name, ent_vec, k)
48 | local tf_map = ent_wiki_words_4EX[ent_name]
49 | local w_not_found = {}
50 | for w,_ in pairs(tf_map) do
51 | if tf_map[w] >= 10 then
52 | w_not_found[w] = tf_map[w]
53 | end
54 | end
55 |
56 | distances = geom_w2v_M * ent_vec
57 |
58 | local best_scores, best_word_ids = topk(distances, k)
59 | local returnwords = {}
60 | local returndistances = {}
61 | for i = 1,k do
62 | local w = get_word_from_id(best_word_ids[i])
63 | if get_w_id_freq(best_word_ids[i]) >= 200 then
64 | local w_freq_str = '[fr=' .. get_w_id_freq(best_word_ids[i]) .. ']'
65 | if is_stop_word_or_number(w) then
66 | table.insert(returnwords, red(w .. w_freq_str))
67 | elseif tf_map[w] then
68 | if tf_map[w] >= 15 then
69 | table.insert(returnwords, yellow(w .. w_freq_str .. '{tf=' .. tf_map[w] .. '}'))
70 | else
71 | table.insert(returnwords, skyblue(w .. w_freq_str .. '{tf=' .. tf_map[w] .. '}'))
72 | end
73 | w_not_found[w] = nil
74 | else
75 | table.insert(returnwords, w .. w_freq_str)
76 | end
77 | assert(best_scores[i] == distances[best_word_ids[i]], best_scores[i] .. ' ' .. distances[best_word_ids[i]])
78 | table.insert(returndistances, distances[best_word_ids[i]])
79 | end
80 | end
81 | return returnwords, returndistances, w_not_found
82 | end
83 |
84 |
85 | local function geom_most_similar_words_to_ent(ent_name, k)
86 | local ent_wikiid = get_ent_wikiid_from_name(ent_name)
87 | local k = k or 1
88 | local ent_vec = e2vutils:entname2vec(ent_name)
89 | assert(math.abs(1 - ent_vec:norm()) < 0.01 or ent_vec:norm() == 0, ':::: ' .. ent_vec:norm())
90 |
91 | print('\nTo entity: ' .. blue(ent_name) .. '; vec norm = ' .. ent_vec:norm() .. ':')
92 | neighbors, scores, w_not_found = geom_top_k_closest_words(ent_name, ent_vec, k)
93 | print(green('TOP CLOSEST WORDS: ') .. list_with_scores_to_str(neighbors, scores))
94 |
95 | local str = yellow('WORDS NOT FOUND: ')
96 | for w,tf in pairs(w_not_found) do
97 | if tf >= 20 then
98 | str = str .. yellow(w .. '{' .. tf .. '}; ')
99 | else
100 | str = str .. w .. '{' .. tf .. '}; '
101 | end
102 | end
103 | print('\n' .. str)
104 | print('============================================================================')
105 | end
106 |
107 | function geom_unit_tests()
108 | print('\n' .. yellow('TOP CLOSEST WORDS to a given entity based on cosine distance:'))
109 | print('For each word, we show the unigram frequency [fr] and the cosine similarity.')
110 | print('Infrequent words [fr < 500] have noisy embeddings, thus should be trusted less.')
111 | print('WORDS NOT FOUND contains frequent words from the Wikipedia canonical page that are not found in the TOP CLOSEST WORDS list.')
112 |
113 | for i=1,table_len(ent_names_4EX) do
114 | geom_most_similar_words_to_ent(ent_names_4EX[i], 300)
115 | end
116 | end
117 |
118 | print(' Done reading e2v data. Entity vocab size = ' .. e2vutils.lookup:size(1))
119 |
--------------------------------------------------------------------------------
/entities/pretrained_e2v/e2v_txt_reader.lua:
--------------------------------------------------------------------------------
1 | print('==> loading e2v')
2 |
3 | local V = torch.ones(get_total_num_ents(), ent_vecs_size):mul(1e-10) -- not zero because of cosine_distance layer
4 |
5 | local cnt = 0
6 | for line in io.lines(e2v_txtfilename) do
7 | cnt = cnt + 1
8 | if cnt % 1000000 == 0 then
9 | print('=======> processed ' .. cnt .. ' lines')
10 | end
11 |
12 | local parts = split(line, ' ')
13 | assert(table_len(parts) == ent_vecs_size + 1)
14 | local ent_wikiid = tonumber(parts[1])
15 | local vec = torch.zeros(ent_vecs_size)
16 | for i=1,ent_vecs_size do
17 | vec[i] = tonumber(parts[i + 1])
18 | end
19 |
20 | if (contains_thid(ent_wikiid)) then
21 | V[get_thid(ent_wikiid)] = vec
22 | else
23 | print('Ent id = ' .. ent_wikiid .. ' does not have a vector. ')
24 | end
25 | end
26 |
27 | print(' Done loading entity vectors. Size = ' .. cnt .. '\n')
28 |
29 | print('Writing t7 File for future usage. Next time Ent2Vec will load faster!')
30 | torch.save(e2v_t7filename, V)
31 | print(' Done saving.\n')
32 |
33 | return V
--------------------------------------------------------------------------------
/entities/relatedness/filter_wiki_canonical_words_RLTD.lua:
--------------------------------------------------------------------------------
1 |
2 | if not opt then
3 | cmd = torch.CmdLine()
4 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.')
5 | cmd:text()
6 | opt = cmd:parse(arg or {})
7 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.')
8 | end
9 |
10 |
11 | dofile 'utils/utils.lua'
12 | dofile 'entities/relatedness/relatedness.lua'
13 |
14 | input = opt.root_data_dir .. 'generated/wiki_canonical_words.txt'
15 |
16 | output = opt.root_data_dir .. 'generated/wiki_canonical_words_RLTD.txt'
17 | ouf = assert(io.open(output, "w"))
18 |
19 | print('\nStarting dataset filtering.')
20 |
21 | local cnt = 0
22 | for line in io.lines(input) do
23 | cnt = cnt + 1
24 | if cnt % 500000 == 0 then
25 | print(' =======> processed ' .. cnt .. ' lines')
26 | end
27 |
28 | local parts = split(line, '\t')
29 | assert(table_len(parts) == 3)
30 |
31 | local ent_wikiid = tonumber(parts[1])
32 | local ent_name = parts[2]
33 | assert(ent_wikiid)
34 |
35 | if rewtr.reltd_ents_wikiid_to_rltdid[ent_wikiid] then
36 | ouf:write(line .. '\n')
37 | end
38 | end
39 |
40 | ouf:flush()
41 | io.close(ouf)
42 |
--------------------------------------------------------------------------------
/entities/relatedness/filter_wiki_hyperlink_contexts_RLTD.lua:
--------------------------------------------------------------------------------
1 | -- Filter all training data s.t. only candidate entities and ground truth entities for which
2 | -- we have a valid entity embedding are kept.
3 |
4 | if not opt then
5 | cmd = torch.CmdLine()
6 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.')
7 | cmd:text()
8 | opt = cmd:parse(arg or {})
9 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.')
10 | end
11 |
12 |
13 | dofile 'utils/utils.lua'
14 | dofile 'entities/relatedness/relatedness.lua'
15 |
16 | input = opt.root_data_dir .. 'generated/wiki_hyperlink_contexts.csv'
17 |
18 | output = opt.root_data_dir .. 'generated/wiki_hyperlink_contexts_RLTD.csv'
19 | ouf = assert(io.open(output, "w"))
20 |
21 | print('\nStarting dataset filtering.')
22 | local cnt = 0
23 | for line in io.lines(input) do
24 | cnt = cnt + 1
25 | if cnt % 50000 == 0 then
26 | print(' =======> processed ' .. cnt .. ' lines')
27 | end
28 |
29 | local parts = split(line, '\t')
30 | local grd_str = parts[table_len(parts)]
31 | assert(parts[table_len(parts) - 1] == 'GT:')
32 | local grd_str_parts = split(grd_str, ',')
33 |
34 | local grd_pos = tonumber(grd_str_parts[1])
35 | assert(grd_pos)
36 |
37 | local grd_ent_wikiid = tonumber(grd_str_parts[2])
38 | assert(grd_ent_wikiid)
39 |
40 | if rewtr.reltd_ents_wikiid_to_rltdid[grd_ent_wikiid] then
41 | assert(parts[6] == 'CANDIDATES')
42 |
43 | local output_line = parts[1] .. '\t' .. parts[2] .. '\t' .. parts[3] .. '\t' .. parts[4] .. '\t' .. parts[5] .. '\t' .. parts[6] .. '\t'
44 |
45 | local new_grd_pos = -1
46 | local new_grd_str_without_idx = nil
47 |
48 | local i = 1
49 | local added_ents = 0
50 | while (parts[6 + i] ~= 'GT:') do
51 | local str = parts[6 + i]
52 | local str_parts = split(str, ',')
53 | local ent_wikiid = tonumber(str_parts[1])
54 | if rewtr.reltd_ents_wikiid_to_rltdid[ent_wikiid] then
55 | added_ents = added_ents + 1
56 | output_line = output_line .. str .. '\t'
57 | end
58 | if (i == grd_pos) then
59 | assert(ent_wikiid == grd_ent_wikiid, 'Error for: ' .. line)
60 | new_grd_pos = added_ents
61 | new_grd_str_without_idx = str
62 | end
63 |
64 | i = i + 1
65 | end
66 |
67 | assert(new_grd_pos > 0)
68 | output_line = output_line .. 'GT:\t' .. new_grd_pos .. ',' .. new_grd_str_without_idx
69 |
70 | ouf:write(output_line .. '\n')
71 | end
72 | end
73 |
74 | ouf:flush()
75 | io.close(ouf)
76 |
--------------------------------------------------------------------------------
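The only subtle step above is re-indexing the ground truth: once candidates without an entity embedding are dropped, the `GT:` field must point to the ground truth's rank among the surviving candidates rather than to its original position. A minimal sketch of that re-indexing on hypothetical data (plain Lua, no Torch needed):

```
-- Hypothetical candidate list by wiki id; only ids present in `kept` survive.
local candidates = {11, 22, 33, 44}
local kept = {[11] = true, [33] = true, [44] = true}
local grd_pos = 3                 -- ground truth is entity 33, third originally

local new_grd_pos, added_ents = -1, 0
for i, ent in ipairs(candidates) do
  if kept[ent] then
    added_ents = added_ents + 1
    if i == grd_pos then new_grd_pos = added_ents end
  end
end
print(new_grd_pos)                -- 2: entity 33 is the second surviving candidate
```
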
/utils/logger.lua:
--------------------------------------------------------------------------------
1 | --[[ Logger: a simple class to log symbols during training,
2 | and automate plot generation
3 | Example:
4 | logger = Logger('somefile.log') -- file to save stuff
5 | for i = 1,N do -- log some symbols during
6 | train_error = ... -- training/testing
7 | test_error = ...
8 | logger:add{['training error'] = train_error,
9 | ['test error'] = test_error}
10 | end
11 | logger:style{['training error'] = '-', -- define styles for plots
12 | ['test error'] = '-'}
13 | logger:plot() -- and plot
14 | ---- OR ---
15 | logger = Logger('somefile.log') -- file to save stuff
16 | logger:setNames{'training error', 'test error'}
17 | for i = 1,N do -- log some symbols during
18 | train_error = ... -- training/testing
19 | test_error = ...
20 | logger:add{train_error, test_error}
21 | end
22 | logger:style{'-', '-'} -- define styles for plots
23 | logger:plot() -- and plot
24 | ]]
25 | require 'xlua'
26 | local Logger = torch.class('Logger')
27 |
28 | function Logger:__init(filename, timestamp)
29 | if filename then
30 | self.name = filename
31 | os.execute('mkdir ' .. (sys.uname() ~= 'windows' and '-p ' or '') .. ' "' .. paths.dirname(filename) .. '"')
32 | if timestamp then
33 | -- append timestamp to create unique log file
34 | filename = filename .. '-'..os.date("%Y_%m_%d_%X")
35 | end
36 | self.file = io.open(filename,'w')
37 | self.epsfile = self.name .. '.eps'
38 | else
39 | self.file = io.stdout
40 | self.name = 'stdout'
41 | print(' warning: no path provided, logging to std out')
42 | end
43 | self.empty = true
44 | self.symbols = {}
45 | self.styles = {}
46 | self.names = {}
47 | self.idx = {}
48 | self.figure = nil
49 | self.showPlot = false
50 | self.plotRawCmd = nil
51 | self.defaultStyle = '+'
52 | end
53 |
54 | function Logger:setNames(names)
55 | self.names = names
56 | self.empty = false
57 | self.nsymbols = #names
58 | for k,key in pairs(names) do
59 | self.file:write(key .. '\t')
60 | self.symbols[k] = {}
61 | self.styles[k] = {self.defaultStyle}
62 | self.idx[key] = k
63 | end
64 | self.file:write('\n')
65 | self.file:flush()
66 | end
67 |
68 | function Logger:add(symbols)
69 | -- (1) first time ? print symbols' names on first row
70 | if self.empty then
71 | self.empty = false
72 | self.nsymbols = #symbols
73 | for k,val in pairs(symbols) do
74 | self.file:write(k .. '\t')
75 | self.symbols[k] = {}
76 | self.styles[k] = {self.defaultStyle}
77 | self.names[k] = k
78 | end
79 | self.idx = self.names
80 | self.file:write('\n')
81 | end
82 | -- (2) print all symbols on one row
83 | for k,val in pairs(symbols) do
84 | if type(val) == 'number' then
85 | self.file:write(string.format('%11.4e',val) .. '\t')
86 | elseif type(val) == 'string' then
87 | self.file:write(val .. '\t')
88 | else
89 | xlua.error('can only log numbers and strings', 'Logger')
90 | end
91 | end
92 | self.file:write('\n')
93 | self.file:flush()
94 | -- (3) save symbols in internal table
95 | for k,val in pairs(symbols) do
96 | table.insert(self.symbols[k], val)
97 | end
98 | end
99 |
100 | function Logger:style(symbols)
101 | for name,style in pairs(symbols) do
102 | if type(style) == 'string' then
103 | self.styles[name] = {style}
104 | elseif type(style) == 'table' then
105 | self.styles[name] = style
106 | else
107 | xlua.error('style should be a string or a table of strings','Logger')
108 | end
109 | end
110 | end
111 |
112 | function Logger:plot(ylabel, xlabel)
113 | if not xlua.require('gnuplot') then
114 | if not self.warned then
115 | print(' warning: cannot plot with this version of Torch')
116 | self.warned = true
117 | end
118 | return
119 | end
120 | local plotit = false
121 | local plots = {}
122 | local plotsymbol =
123 | function(name,list)
124 | if #list > 1 then
125 | local nelts = #list
126 | local plot_y = torch.Tensor(nelts)
127 | for i = 1,nelts do
128 | plot_y[i] = list[i]
129 | end
130 | for _,style in ipairs(self.styles[name]) do
131 | table.insert(plots, {self.names[name], plot_y, style})
132 | end
133 | plotit = true
134 | end
135 | end
136 |
137 | -- plot all symbols
138 | for name,list in pairs(self.symbols) do
139 | plotsymbol(name,list)
140 | end
141 |
142 | if plotit then
143 | if self.showPlot then
144 | self.figure = gnuplot.figure(self.figure)
145 | gnuplot.plot(plots)
146 | gnuplot.xlabel(xlabel)
147 | gnuplot.ylabel(ylabel)
148 | if self.plotRawCmd then gnuplot.raw(self.plotRawCmd) end
149 | gnuplot.grid('on')
150 | gnuplot.title(banner)
151 | end
152 |
153 | if self.epsfile then
154 | os.execute('rm -f "' .. self.epsfile .. '"')
155 | local epsfig = gnuplot.epsfigure(self.epsfile)
156 | gnuplot.plot(plots)
157 | gnuplot.xlabel(xlabel)
158 | gnuplot.ylabel(ylabel)
159 | if self.plotRawCmd then gnuplot.raw(self.plotRawCmd) end
160 | gnuplot.grid('on')
161 | gnuplot.title(banner)
162 | gnuplot.plotflush()
163 | gnuplot.close(epsfig)
164 | end
165 | end
166 | end
--------------------------------------------------------------------------------
/utils/optim/adadelta_mem.lua:
--------------------------------------------------------------------------------
1 | -- Memory optimized implementation of optim.adadelta
2 |
3 | --[[ ADADELTA implementation for SGD http://arxiv.org/abs/1212.5701
4 | ARGS:
5 | - `opfunc` : a function that takes a single input (X), the point of
6 | evaluation, and returns f(X) and df/dX
7 | - `x` : the initial point
8 | - `config` : a table of hyper-parameters
9 | - `config.rho` : interpolation parameter
10 | - `config.eps` : for numerical stability
11 | - `config.weightDecay` : weight decay
12 | - `state` : a table describing the state of the optimizer; after each
13 | call the state is modified
14 | - `state.paramVariance` : vector of temporal variances of parameters
15 | - `state.accDelta` : vector of accumulated delta of gradients
16 | RETURN:
17 | - `x` : the new x vector
18 | - `f(x)` : the function, evaluated before the update
19 | ]]
20 | function adadelta_mem(opfunc, x, config, state)
21 | -- (0) get/update state
22 | if config == nil and state == nil then
23 | print('no state table, ADADELTA initializing')
24 | end
25 | local config = config or {}
26 | local state = state or config
27 | local rho = config.rho or 0.9
28 | local eps = config.eps or 1e-6
29 | local wd = config.weightDecay or 0
30 | state.evalCounter = state.evalCounter or 0
31 | -- (1) evaluate f(x) and df/dx
32 | local fx,dfdx = opfunc(x)
33 |
34 | -- (2) weight decay
35 | if wd ~= 0 then
36 | dfdx:add(wd, x)
37 | end
38 |
39 | -- (3) parameter update
40 | if not state.paramVariance then
41 | state.paramVariance = torch.Tensor():typeAs(x):resizeAs(dfdx):zero()
42 | state.delta = torch.Tensor():typeAs(x):resizeAs(dfdx):zero()
43 | state.accDelta = torch.Tensor():typeAs(x):resizeAs(dfdx):zero()
44 | end
45 | state.paramVariance:mul(rho):addcmul(1-rho,dfdx,dfdx)
46 | state.paramVariance:add(eps):sqrt() -- std
47 | state.delta:copy(state.accDelta):add(eps):sqrt():cdiv(state.paramVariance):cmul(dfdx)
48 | x:add(-1, state.delta)
49 | state.accDelta:mul(rho):addcmul(1-rho, state.delta, state.delta)
50 | state.paramVariance:cmul(state.paramVariance):csub(eps) -- recompute variance again
51 |
52 | -- (4) update evaluation counter
53 | state.evalCounter = state.evalCounter + 1
54 |
55 | -- return x*, f(x) before optimization
56 | return x,{fx}
57 | end
--------------------------------------------------------------------------------
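A minimal usage sketch for `adadelta_mem` on a toy quadratic f(x) = 0.5 * ||x||^2 (a hypothetical example, assumed to be run from the repository root; `adagrad_mem` and `rmsprop_mem` below share the same calling convention):

```
require 'torch'
dofile 'utils/optim/adadelta_mem.lua'

local x = torch.randn(10)
-- opfunc returns f(x) and df/dx; for 0.5*||x||^2 the gradient is x itself.
local opfunc = function(x)
  return 0.5 * x:dot(x), x:clone()
end

local config = {rho = 0.9, eps = 1e-6}
local state = {}
for i = 1, 100 do
  local _, fx = adadelta_mem(opfunc, x, config, state)
  if i % 25 == 0 then
    print(string.format('iter %3d   f(x) = %.6f', i, fx[1]))
  end
end
```
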
/utils/optim/adagrad_mem.lua:
--------------------------------------------------------------------------------
1 | -- Memory optimized implementation of optim.adagrad
2 |
3 | --[[ ADAGRAD implementation for SGD
4 | ARGS:
5 | - `opfunc` : a function that takes a single input (X), the point of
6 | evaluation, and returns f(X) and df/dX
7 | - `x` : the initial point
8 | - `state` : a table describing the state of the optimizer; after each
9 | call the state is modified
10 | - `state.learningRate` : learning rate
11 | - `state.paramVariance` : vector of temporal variances of parameters
12 | - `state.weightDecay` : scalar that controls weight decay
13 | RETURN:
14 | - `x` : the new x vector
15 | - `f(x)` : the function, evaluated before the update
16 | ]]
17 |
18 | function adagrad_mem(opfunc, x, config, state)
19 | -- (0) get/update state
20 | if config == nil and state == nil then
21 | print('no state table, ADAGRAD initializing')
22 | end
23 | local config = config or {}
24 | local state = state or config
25 | local lr = config.learningRate or 1e-3
26 | local lrd = config.learningRateDecay or 0
27 | local wd = config.weightDecay or 0
28 | state.evalCounter = state.evalCounter or 0
29 | local nevals = state.evalCounter
30 |
31 | -- (1) evaluate f(x) and df/dx
32 | local fx,dfdx = opfunc(x)
33 |
34 | -- (2) weight decay with a single parameter
35 | if wd ~= 0 then
36 | dfdx:add(wd, x)
37 | end
38 |
39 | -- (3) learning rate decay (annealing)
40 | local clr = lr / (1 + nevals*lrd)
41 |
42 | -- (4) parameter update with single or individual learning rates
43 | if not state.paramVariance then
44 | state.paramVariance = torch.Tensor():typeAs(x):resizeAs(dfdx):zero()
45 | end
46 |
47 | state.paramVariance:addcmul(1,dfdx,dfdx)
48 | state.paramVariance:add(1e-10)
49 | state.paramVariance:sqrt() -- Keeps the std
50 | x:addcdiv(-clr, dfdx, state.paramVariance)
51 | state.paramVariance:cmul(state.paramVariance) -- Keeps the variance again
52 | state.paramVariance:add(-1e-10)
53 |
54 |
55 | -- (5) update evaluation counter
56 | state.evalCounter = state.evalCounter + 1
57 |
58 | -- return x*, f(x) before optimization
59 | return x,{fx}
60 | end
--------------------------------------------------------------------------------
/utils/optim/rmsprop_mem.lua:
--------------------------------------------------------------------------------
1 | -- Memory optimized implementation of optim.rmsprop
2 |
3 | --[[ An implementation of RMSprop
4 | ARGS:
5 | - 'opfunc' : a function that takes a single input (X), the point
6 | of evaluation, and returns f(X) and df/dX
7 | - 'x' : the initial point
8 | - 'config' : a table with configuration parameters for the optimizer
9 | - 'config.learningRate' : learning rate
10 | - 'config.alpha' : smoothing constant
11 | - 'config.epsilon' : small constant added for numerical stability
12 | - 'config.weightDecay' : weight decay
13 | - 'state' : a table describing the state of the optimizer;
14 | after each call the state is modified
15 | - 'state.m' : leaky sum of squares of parameter gradients (also reused
16 | in place for its square root, so no separate 'state.tmp' buffer is kept)
17 | RETURN:
18 | - `x` : the new x vector
19 | - `f(x)` : the function, evaluated before the update
20 | ]]
21 |
22 | function rmsprop_mem(opfunc, x, config, state)
23 | -- (0) get/update state
24 | local config = config or {}
25 | local state = state or config
26 | local lr = config.learningRate or 1e-2
27 | local alpha = config.alpha or 0.99
28 | local epsilon = config.epsilon or 1e-8
29 | local wd = config.weightDecay or 0
30 | local mfill = config.initialMean or 0
31 |
32 | -- (1) evaluate f(x) and df/dx
33 | local fx, dfdx = opfunc(x)
34 |
35 | -- (2) weight decay
36 | if wd ~= 0 then
37 | dfdx:add(wd, x)
38 | end
39 |
40 | -- (3) initialize mean square values and square gradient storage
41 | if not state.m then
42 | state.m = torch.Tensor():typeAs(x):resizeAs(dfdx):fill(mfill)
43 | end
44 |
45 | -- (4) calculate new (leaky) mean squared values
46 | state.m:mul(alpha)
47 | state.m:addcmul(1.0-alpha, dfdx, dfdx)
48 |
49 | -- (5) perform update
50 | state.m:add(epsilon)
51 | state.m:sqrt()
52 | x:addcdiv(-lr, dfdx, state.m)
53 | state.m:cmul(state.m)
54 | state.m:add(-epsilon)
55 |
56 | -- return x*, f(x) before optimization
57 | return x, {fx}
58 | end
--------------------------------------------------------------------------------
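What makes these `_mem` variants memory friendly is that the running second moment is turned into a standard deviation in place, used for the update, and then squared back, instead of allocating an extra parameter-sized buffer for the std. An illustrative sketch of that trick on a tiny tensor with hypothetical values:

```
require 'torch'

local m = torch.Tensor({4, 9})    -- running mean of squared gradients
local g = torch.Tensor({1, 1})    -- current gradient
local x = torch.zeros(2)          -- parameters

m:add(1e-8):sqrt()                -- m temporarily holds the std
x:addcdiv(-0.01, g, m)            -- update: x <- x - lr * g / std
m:cmul(m):add(-1e-8)              -- square back: m holds the variance again
print(m)                          -- ~ {4, 9}, recovered without a temporary tensor
```
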
/utils/utils.lua:
--------------------------------------------------------------------------------
1 | function topk(one_dim_tensor, k)
2 | local bestk, indices = torch.topk(one_dim_tensor, k, true)
3 | local sorted, newindices = torch.sort(bestk, true)
4 | local oldindices = torch.LongTensor(k)
5 | for i = 1,k do
6 | oldindices[i] = indices[newindices[i]]
7 | end
8 | return sorted, oldindices
9 | end
10 |
11 |
12 | function list_with_scores_to_str(list, scores)
13 | local str = ''
14 | for i,v in pairs(list) do
15 | str = str .. list[i] .. '[' .. string.format("%.2f", scores[i]) .. ']; '
16 | end
17 | return str
18 | end
19 |
20 | function table_len(t)
21 | local count = 0
22 | for _ in pairs(t) do count = count + 1 end
23 | return count
24 | end
25 |
26 |
27 | function split(inputstr, sep)
28 | if sep == nil then
29 | sep = "%s"
30 | end
31 | local t = {} ; local i = 1
32 | for str in string.gmatch(inputstr, "([^"..sep.."]+)") do
33 | t[i] = str
34 | i = i + 1
35 | end
36 | return t
37 | end
38 | -- Unit test:
39 | assert(6 == #split('aa_bb cc__dd ee _ _ __ff' , '_ '))
40 |
41 |
42 |
43 | function correct_type(data)
44 | if opt.type == 'float' then return data:float()
45 | elseif opt.type == 'double' then return data:double()
46 | elseif string.find(opt.type, 'cuda') then return data:cuda()
47 | else print('Unsupported type')
48 | end
49 | end
50 |
51 | -- color fonts:
52 | function red(s)
53 | return '\27[31m' .. s .. '\27[39m'
54 | end
55 |
56 | function green(s)
57 | return '\27[32m' .. s .. '\27[39m'
58 | end
59 |
60 | function yellow(s)
61 | return '\27[33m' .. s .. '\27[39m'
62 | end
63 |
64 | function blue(s)
65 | return '\27[34m' .. s .. '\27[39m'
66 | end
67 |
68 | function violet(s)
69 | return '\27[35m' .. s .. '\27[39m'
70 | end
71 |
72 | function skyblue(s)
73 | return '\27[36m' .. s .. '\27[39m'
74 | end
75 |
76 |
77 |
78 | function split_in_words(inputstr)
79 | local words = {}
80 | for word in inputstr:gmatch("%w+") do table.insert(words, word) end
81 | return words
82 | end
83 |
84 |
85 | function first_letter_to_uppercase(s)
86 | return s:sub(1,1):upper() .. s:sub(2)
87 | end
88 |
89 | function modify_uppercase_phrase(s)
90 | if (s == s:upper()) then
91 | local words = split_in_words(s:lower())
92 | local res = {}
93 | for _,w in pairs(words) do
94 | table.insert(res, first_letter_to_uppercase(w))
95 | end
96 | return table.concat(res, ' ')
97 | else
98 | return s
99 | end
100 | end
101 |
102 | function blue_num_str(n)
103 | return blue(string.format("%.3f", n))
104 | end
105 |
106 |
107 |
108 | function string_starts(s, m)
109 | return string.sub(s,1,string.len(m)) == m
110 | end
111 |
112 | -- trim:
113 | function trim1(s)
114 | return (s:gsub("^%s*(.-)%s*$", "%1"))
115 | end
116 |
117 |
118 | function nice_print_red_green(a,b)
119 | local s = string.format("%.3f", a) .. ':' .. string.format("%.3f", b) .. '['
120 | if a > b then
121 | return s .. red(string.format("%.3f", a-b)) .. ']'
122 | elseif a < b then
123 | return s .. green(string.format("%.3f", b-a)) .. ']'
124 | else
125 | return s .. '0]'
126 | end
127 | end
128 |
--------------------------------------------------------------------------------
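A quick illustration of the `topk` helper above (hypothetical values): it returns the k largest entries in descending order together with their indices in the original tensor, which is what `top_k_closest_words` in `words/w2v/w2v.lua` relies on:

```
require 'torch'
dofile 'utils/utils.lua'

local scores = torch.Tensor({0.1, 0.9, 0.4, 0.7})
local vals, ids = topk(scores, 2)
print(vals)   -- expected: 0.9, 0.7
print(ids)    -- expected: 2, 4 (their positions in `scores`)
```
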
/words/load_w_freq_and_vecs.lua:
--------------------------------------------------------------------------------
1 | -- Loads all common words in both Wikipedia and Word2Vec/GloVe, their unigram frequencies and their pre-trained word embeddings.
2 |
3 | -- To load this as a standalone file do:
4 | -- th> opt = {word_vecs = 'w2v', root_data_dir = '$DATA_PATH'}
5 | -- th> dofile 'words/load_w_freq_and_vecs.lua'
6 |
7 | if not opt then
8 | cmd = torch.CmdLine()
9 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.')
10 | cmd:text()
11 | opt = cmd:parse(arg or {})
12 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.')
13 | end
14 |
15 |
16 | default_path = opt.root_data_dir .. 'basic_data/wordEmbeddings/'
17 |
18 | tds = tds or require 'tds'
19 | torch.setdefaulttensortype('torch.FloatTensor')
20 | if not list_with_scores_to_str then
21 | dofile 'utils/utils.lua'
22 | end
23 | if not is_stop_word_or_number then
24 | dofile 'words/stop_words.lua'
25 | end
26 |
27 | assert(opt, 'Define opt')
28 | assert(opt.word_vecs, 'Define opt.word_vecs')
29 |
30 | print('==> Loading common w2v + top freq list of words')
31 |
32 | local output_t7filename = opt.root_data_dir .. 'generated/common_top_words_freq_vectors_' .. opt.word_vecs .. '.t7'
33 |
34 | if paths.filep(output_t7filename) then
35 | print(' ---> from t7 file.')
36 | common_w2v_freq_words = torch.load(output_t7filename)
37 |
38 | else
39 | print(' ---> t7 file NOT found. Loading from disk instead (slower). Out file = ' .. output_t7filename)
40 | local freq_words = tds.Hash()
41 |
42 | print(' word freq index ...')
43 | local num_freq_words = 1
44 | local w_freq_file = opt.root_data_dir .. 'generated/word_wiki_freq.txt'
45 | for line in io.lines(w_freq_file) do
46 | local parts = split(line, '\t')
47 | local w = parts[1]
48 | local w_f = tonumber(parts[2])
49 | if not is_stop_word_or_number(w) then
50 | freq_words[w] = w_f
51 | num_freq_words = num_freq_words + 1
52 | end
53 | end
54 |
55 | common_w2v_freq_words = tds.Hash()
56 |
57 | print(' word vectors index ...')
58 | if opt.word_vecs == 'glove' then
59 | w2v_txtfilename = default_path .. 'Glove/glove.840B.300d.txt'
60 | local line_num = 0
61 | for line in io.lines(w2v_txtfilename) do
62 | line_num = line_num + 1
63 | if line_num % 200000 == 0 then
64 | print(' Processed ' .. line_num)
65 | end
66 | local parts = split(line, ' ')
67 | local w = parts[1]
68 | if freq_words[w] then
69 | common_w2v_freq_words[w] = 1
70 | end
71 | end
72 |
73 | else
74 | assert(opt.word_vecs == 'w2v')
75 | w2v_binfilename = default_path .. 'Word2Vec/GoogleNews-vectors-negative300.bin'
76 | local word_vecs_size = 300
77 | file = torch.DiskFile(w2v_binfilename,'r')
78 | file:ascii()
79 | local vocab_size = file:readInt()
80 | local size = file:readInt()
81 | assert(size == word_vecs_size, 'Wrong size : ' .. size .. ' vs ' .. word_vecs_size)
82 |
83 | function read_string_w2v(file)
84 | local str = {}
85 | while true do
86 | local char = file:readChar()
87 | if char == 32 or char == 10 or char == 0 then
88 | break
89 | else
90 | str[#str+1] = char
91 | end
92 | end
93 | str = torch.CharStorage(str)
94 | return str:string()
95 | end
96 |
97 | --Reading Contents
98 | file:binary()
99 | local line_num = 0
100 | for i = 1,vocab_size do
101 | line_num = line_num + 1
102 | if line_num % 200000 == 0 then
103 | print('Processed ' .. line_num)
104 | end
105 | local w = read_string_w2v(file)
106 | local v = torch.FloatTensor(file:readFloat(word_vecs_size))
107 | if freq_words[w] then
108 | common_w2v_freq_words[w] = 1
109 | end
110 | end
111 | end
112 |
113 | print('Writing t7 File for future usage. Next time loading will be faster!')
114 | torch.save(output_t7filename, common_w2v_freq_words)
115 | end
116 |
117 | -- Now load the freq and w2v indexes
118 | dofile 'words/w_freq/w_freq_index.lua'
119 |
--------------------------------------------------------------------------------
/words/stop_words.lua:
--------------------------------------------------------------------------------
1 | all_stop_words = { ['a'] = 1, ['about'] = 1, ['above'] = 1, ['across'] = 1, ['after'] = 1, ['afterwards'] = 1, ['again'] = 1, ['against'] = 1, ['all'] = 1, ['almost'] = 1, ['alone'] = 1, ['along'] = 1, ['already'] = 1, ['also'] = 1, ['although'] = 1, ['always'] = 1, ['am'] = 1, ['among'] = 1, ['amongst'] = 1, ['amoungst'] = 1, ['amount'] = 1, ['an'] = 1, ['and'] = 1, ['another'] = 1, ['any'] = 1, ['anyhow'] = 1, ['anyone'] = 1, ['anything'] = 1, ['anyway'] = 1, ['anywhere'] = 1, ['are'] = 1, ['around'] = 1, ['as'] = 1, ['at'] = 1, ['back'] = 1, ['be'] = 1, ['became'] = 1, ['because'] = 1, ['become'] = 1, ['becomes'] = 1, ['becoming'] = 1, ['been'] = 1, ['before'] = 1, ['beforehand'] = 1, ['behind'] = 1, ['being'] = 1, ['below'] = 1, ['beside'] = 1, ['besides'] = 1, ['between'] = 1, ['beyond'] = 1, ['both'] = 1, ['bottom'] = 1, ['but'] = 1, ['by'] = 1, ['call'] = 1, ['can'] = 1, ['cannot'] = 1, ['cant'] = 1, ['dont'] = 1, ['co'] = 1, ['con'] = 1, ['could'] = 1, ['couldnt'] = 1, ['cry'] = 1, ['de'] = 1, ['describe'] = 1, ['detail'] = 1, ['do'] = 1, ['done'] = 1, ['down'] = 1, ['due'] = 1, ['during'] = 1, ['each'] = 1, ['eg'] = 1, ['eight'] = 1, ['either'] = 1, ['eleven'] = 1, ['else'] = 1, ['elsewhere'] = 1, ['empty'] = 1, ['enough'] = 1, ['etc'] = 1, ['even'] = 1, ['ever'] = 1, ['every'] = 1, ['everyone'] = 1, ['everything'] = 1, ['everywhere'] = 1, ['except'] = 1, ['few'] = 1, ['fifteen'] = 1, ['fify'] = 1, ['fill'] = 1, ['find'] = 1, ['fire'] = 1, ['first'] = 1, ['five'] = 1, ['for'] = 1, ['former'] = 1, ['formerly'] = 1, ['forty'] = 1, ['found'] = 1, ['four'] = 1, ['from'] = 1, ['front'] = 1, ['full'] = 1, ['further'] = 1, ['get'] = 1, ['give'] = 1, ['go'] = 1, ['had'] = 1, ['has'] = 1, ['hasnt'] = 1, ['have'] = 1, ['he'] = 1, ['hence'] = 1, ['her'] = 1, ['here'] = 1, ['hereafter'] = 1, ['hereby'] = 1, ['herein'] = 1, ['hereupon'] = 1, ['hers'] = 1, ['herself'] = 1, ['him'] = 1, ['himself'] = 1, ['his'] = 1, ['how'] = 1, ['however'] = 1, ['hundred'] = 1, ['i'] = 1, ['ie'] = 1, ['if'] = 1, ['in'] = 1, ['inc'] = 1, ['indeed'] = 1, ['interest'] = 1, ['into'] = 1, ['is'] = 1, ['it'] = 1, ['its'] = 1, ['itself'] = 1, ['keep'] = 1, ['last'] = 1, ['latter'] = 1, ['latterly'] = 1, ['least'] = 1, ['less'] = 1, ['ltd'] = 1, ['made'] = 1, ['many'] = 1, ['may'] = 1, ['me'] = 1, ['meanwhile'] = 1, ['might'] = 1, ['mill'] = 1, ['mine'] = 1, ['more'] = 1, ['moreover'] = 1, ['most'] = 1, ['mostly'] = 1, ['move'] = 1, ['much'] = 1, ['must'] = 1, ['my'] = 1, ['myself'] = 1, ['name'] = 1, ['namely'] = 1, ['neither'] = 1, ['never'] = 1, ['nevertheless'] = 1, ['next'] = 1, ['nine'] = 1, ['no'] = 1, ['nobody'] = 1, ['none'] = 1, ['noone'] = 1, ['nor'] = 1, ['not'] = 1, ['nothing'] = 1, ['now'] = 1, ['nowhere'] = 1, ['of'] = 1, ['off'] = 1, ['often'] = 1, ['on'] = 1, ['once'] = 1, ['one'] = 1, ['only'] = 1, ['onto'] = 1, ['or'] = 1, ['other'] = 1, ['others'] = 1, ['otherwise'] = 1, ['our'] = 1, ['ours'] = 1, ['ourselves'] = 1, ['out'] = 1, ['over'] = 1, ['own'] = 1, ['part'] = 1, ['per'] = 1, ['perhaps'] = 1, ['please'] = 1, ['put'] = 1, ['rather'] = 1, ['re'] = 1, ['same'] = 1, ['see'] = 1, ['seem'] = 1, ['seemed'] = 1, ['seeming'] = 1, ['seems'] = 1, ['serious'] = 1, ['several'] = 1, ['she'] = 1, ['should'] = 1, ['show'] = 1, ['side'] = 1, ['since'] = 1, ['sincere'] = 1, ['six'] = 1, ['sixty'] = 1, ['so'] = 1, ['some'] = 1, ['somehow'] = 1, ['someone'] = 1, ['something'] = 1, ['sometime'] = 1, ['sometimes'] = 1, ['somewhere'] = 1, ['still'] = 1, ['such'] = 1, ['system'] = 1, ['take'] = 1, ['ten'] = 1, 
['than'] = 1, ['that'] = 1, ['the'] = 1, ['their'] = 1, ['them'] = 1, ['themselves'] = 1, ['then'] = 1, ['thence'] = 1, ['there'] = 1, ['thereafter'] = 1, ['thereby'] = 1, ['therefore'] = 1, ['therein'] = 1, ['thereupon'] = 1, ['these'] = 1, ['they'] = 1, ['thick'] = 1, ['thin'] = 1, ['third'] = 1, ['this'] = 1, ['those'] = 1, ['though'] = 1, ['three'] = 1, ['through'] = 1, ['throughout'] = 1, ['thru'] = 1, ['thus'] = 1, ['to'] = 1, ['together'] = 1, ['too'] = 1, ['top'] = 1, ['toward'] = 1, ['towards'] = 1, ['twelve'] = 1, ['twenty'] = 1, ['two'] = 1, ['un'] = 1, ['under'] = 1, ['until'] = 1, ['up'] = 1, ['upon'] = 1, ['us'] = 1, ['very'] = 1, ['via'] = 1, ['was'] = 1, ['we'] = 1, ['well'] = 1, ['were'] = 1, ['what'] = 1, ['whatever'] = 1, ['when'] = 1, ['whence'] = 1, ['whenever'] = 1, ['where'] = 1, ['whereafter'] = 1, ['whereas'] = 1, ['whereby'] = 1, ['wherein'] = 1, ['whereupon'] = 1, ['wherever'] = 1, ['whether'] = 1, ['which'] = 1, ['while'] = 1, ['whither'] = 1, ['who'] = 1, ['whoever'] = 1, ['whole'] = 1, ['whom'] = 1, ['whose'] = 1, ['why'] = 1, ['will'] = 1, ['with'] = 1, ['within'] = 1, ['without'] = 1, ['would'] = 1, ['yet'] = 1, ['you'] = 1, ['your'] = 1, ['yours'] = 1, ['yourself'] = 1, ['yourselves'] = 1, ['st'] = 1, ['years'] = 1, ['yourselves'] = 1, ['new'] = 1, ['used'] = 1, ['known'] = 1, ['year'] = 1, ['later'] = 1, ['including'] = 1, ['used'] = 1, ['end'] = 1, ['did'] = 1, ['just'] = 1, ['best'] = 1, ['using'] = 1}
2 |
3 | function is_stop_word_or_number(w)
4 | assert(type(w) == 'string', w)
5 | if (all_stop_words[w:lower()]) or tonumber(w) or w:len() <= 1 then
6 | return true
7 | else
8 | return false
9 | end
10 | end
--------------------------------------------------------------------------------
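A small sanity check for `is_stop_word_or_number` with hypothetical inputs: matching is case-insensitive, and numbers and single characters are also filtered out:

```
dofile 'words/stop_words.lua'

assert(is_stop_word_or_number('The'))      -- stop word, matched case-insensitively
assert(is_stop_word_or_number('1984'))     -- a number
assert(is_stop_word_or_number('a'))        -- a single character
assert(not is_stop_word_or_number('zurich'))
print('is_stop_word_or_number: OK')
```
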
/words/w2v/glove_reader.lua:
--------------------------------------------------------------------------------
1 | local M = torch.zeros(total_num_words(), word_vecs_size):float()
2 |
3 | --Reading Contents
4 | for line in io.lines(w2v_txtfilename) do
5 | local parts = split(line, ' ')
6 | local w = parts[1]
7 | local w_id = get_id_from_word(w)
8 | if w_id ~= unk_w_id then
9 | for i=2, #parts do
10 | M[w_id][i-1] = tonumber(parts[i])
11 | end
12 | end
13 | end
14 |
15 | return M
16 |
17 |
18 |
--------------------------------------------------------------------------------
/words/w2v/w2v.lua:
--------------------------------------------------------------------------------
1 | -- Loads pre-trained word embeddings from either Word2Vec or Glove
2 |
3 | assert(get_id_from_word)
4 | assert(common_w2v_freq_words)
5 | assert(total_num_words)
6 |
7 | word_vecs_size = 300
8 |
9 | -- Loads pre-trained glove or word2vec embeddings:
10 | if opt.word_vecs == 'glove' then
11 | -- Glove downloaded from: http://nlp.stanford.edu/projects/glove/
12 | w2v_txtfilename = default_path .. 'Glove/glove.840B.300d.txt'
13 | w2v_t7filename = opt.root_data_dir .. 'generated/glove.840B.300d.t7'
14 | w2v_reader = 'words/w2v/glove_reader.lua'
15 | elseif opt.word_vecs == 'w2v' then
16 | -- Word2Vec downloaded from: https://code.google.com/archive/p/word2vec/
17 | w2v_binfilename = default_path .. 'Word2Vec/GoogleNews-vectors-negative300.bin'
18 | w2v_t7filename = opt.root_data_dir .. 'generated/GoogleNews-vectors-negative300.t7'
19 | w2v_reader = 'words/w2v/word2vec_reader.lua'
20 | end
21 |
22 | ---------------------- Code: -----------------------
23 | w2vutils = {}
24 |
25 | print('==> Loading ' .. opt.word_vecs .. ' vectors')
26 | if not paths.filep(w2v_t7filename) then
27 | print(' ---> t7 file NOT found. Loading w2v from the bin/txt file instead (slower).')
28 | w2vutils.M = require(w2v_reader)
29 | print('Writing t7 File for future usage. Next time Word2Vec loading will be faster!')
30 | torch.save(w2v_t7filename, w2vutils.M)
31 | else
32 | print(' ---> from t7 file.')
33 | w2vutils.M = torch.load(w2v_t7filename)
34 | end
35 |
36 | -- Move the word embedding matrix onto the GPU if we are doing any training.
37 | -- This makes word embedding lookups much faster.
38 | if opt and string.find(opt.type, 'cuda') then
39 | w2vutils.M = w2vutils.M:cuda()
40 | end
41 |
42 | ---------- Define additional functions -----------------
43 | -- word -> vec
44 | w2vutils.get_w_vec = function (self,word)
45 | local w_id = get_id_from_word(word)
46 | return w2vutils.M[w_id]:clone()
47 | end
48 |
49 | -- word_id -> vec
50 | w2vutils.get_w_vec_from_id = function (self,w_id)
51 | return w2vutils.M[w_id]:clone()
52 | end
53 |
54 | w2vutils.lookup_w_vecs = function (self,word_id_tensor)
55 | assert(word_id_tensor:dim() <= 2, 'Only word id tensors w/ 1 or 2 dimensions are supported.')
56 | local output = torch.FloatTensor()
57 | local word_ids = word_id_tensor:long()
58 | if opt and string.find(opt.type, 'cuda') then
59 | output = output:cuda()
60 | word_ids = word_ids:cuda()
61 | end
62 |
63 | if word_ids:dim() == 2 then
64 | output:index(w2vutils.M, 1, word_ids:view(-1))
65 | output = output:view(word_ids:size(1), word_ids:size(2), w2vutils.M:size(2))
66 | elseif word_ids:dim() == 1 then
67 | output:index(w2vutils.M, 1, word_ids)
68 | output = output:view(word_ids:size(1), w2vutils.M:size(2))
69 | end
70 |
71 | return output
72 | end
73 |
74 | -- Normalize word vectors to have norm 1.
75 | w2vutils.renormalize = function (self)
76 | w2vutils.M[unk_w_id]:mul(0)
77 | w2vutils.M[unk_w_id]:add(1)
78 | w2vutils.M:cdiv(w2vutils.M:norm(2,2):expand(w2vutils.M:size()))
79 | local x = w2vutils.M:norm(2,2):view(-1) - 1
80 | assert(x:norm() < 0.1, x:norm())
81 | assert(w2vutils.M[100]:norm() < 1.001 and w2vutils.M[100]:norm() > 0.99)
82 | w2vutils.M[unk_w_id]:mul(0)
83 | end
84 |
85 | w2vutils:renormalize()
86 |
87 | print(' Done reading w2v data. Word vocab size = ' .. w2vutils.M:size(1))
88 |
89 | -- Phrase embedding using average of vectors of words in the phrase
90 | w2vutils.phrase_avg_vec = function(self, phrase)
91 | local words = split_in_words(phrase)
92 | local num_words = table_len(words)
93 | local num_existent_words = 0
94 | local vec = torch.zeros(word_vecs_size)
95 | for i = 1,num_words do
96 | local w = words[i]
97 | local w_id = get_id_from_word(w)
98 | if w_id ~= unk_w_id then
99 | vec:add(w2vutils:get_w_vec_from_id(w_id))
100 | num_existent_words = num_existent_words + 1
101 | end
102 | end
103 | if (num_existent_words > 0) then
104 | vec:div(num_existent_words)
105 | end
106 | return vec
107 | end
108 |
109 | w2vutils.top_k_closest_words = function (self,vec, k, mat)
110 | local k = k or 1
111 | vec = vec:float()
112 | local distances = torch.mv(mat, vec)
113 | local best_scores, best_word_ids = topk(distances, k)
114 | local returnwords = {}
115 | local returndistances = {}
116 | for i = 1,k do
117 | local w = get_word_from_id(best_word_ids[i])
118 | if is_stop_word_or_number(w) then
119 | table.insert(returnwords, red(w))
120 | else
121 | table.insert(returnwords, w)
122 | end
123 | assert(best_scores[i] == distances[best_word_ids[i]], best_scores[i] .. ' ' .. distances[best_word_ids[i]])
124 | table.insert(returndistances, distances[best_word_ids[i]])
125 | end
126 | return returnwords, returndistances
127 | end
128 |
129 | w2vutils.most_similar2word = function(self, word, k)
130 | local k = k or 1
131 | local v = w2vutils:get_w_vec(word)
132 | neighbors, scores = w2vutils:top_k_closest_words(v, k, w2vutils.M)
133 | print('To word ' .. skyblue(word) .. ' : ' .. list_with_scores_to_str(neighbors, scores))
134 | end
135 |
136 | w2vutils.most_similar2vec = function(self, vec, k)
137 | local k = k or 1
138 | neighbors, scores = w2vutils:top_k_closest_words(vec, k, w2vutils.M)
139 | print(list_with_scores_to_str(neighbors, scores))
140 | end
141 |
142 |
143 | --------------------- Unit tests ----------------------------------------
144 | local unit_tests = opt.unit_tests or false
145 | if (unit_tests) then
146 | print('\nWord to word similarity test:')
147 | w2vutils:most_similar2word('nice', 5)
148 | w2vutils:most_similar2word('france', 5)
149 | w2vutils:most_similar2word('hello', 5)
150 | end
151 |
152 | -- Computes, for each of the first j words w, the total correlation \sum_v <w,v> with all word vectors, and prints the top and bottom k words by this score.
153 | w2vutils.total_word_correlation = function(self, k, j)
154 | local exp_Z = torch.zeros(w2vutils.M:narrow(1, 1, j):size(1))
155 |
156 | local sum_t = w2vutils.M:narrow(1, 1, j):sum(1) -- 1 x d
157 | local sum_Z = (w2vutils.M:narrow(1, 1, j) * sum_t:t()):view(-1) -- num_w
158 |
159 | print(red('Top words by sum_Z:'))
160 | best_sum_Z, best_word_ids = topk(sum_Z, k)
161 | for i = 1,k do
162 | local w = get_word_from_id(best_word_ids[i])
163 | assert(best_sum_Z[i] == sum_Z[best_word_ids[i]])
164 | print(w .. ' [' .. best_sum_Z[i] .. ']; ')
165 | end
166 |
167 | print('\n' .. red('Bottom words by sum_Z:'))
168 | best_sum_Z, best_word_ids = topk(- sum_Z, k)
169 | for i = 1,k do
170 | local w = get_word_from_id(best_word_ids[i])
171 | assert(best_sum_Z[i] == - sum_Z[best_word_ids[i]])
172 | print(w .. ' [' .. sum_Z[best_word_ids[i]] .. ']; ')
173 | end
174 | end
175 |
176 |
177 | -- Plot with gnuplot:
178 | -- set palette model RGB defined ( 0 'white', 1 'pink', 2 'green' , 3 'blue', 4 'red' )
179 | -- plot 'tsne-w2v-vecs.txt_1000' using 1:2:3 with labels offset 0,1, '' using 1:2:4 w points pt 7 ps 2 palette
180 | w2vutils.tsne = function(self, num_rand_words)
181 | local topic1 = {'japan', 'china', 'france', 'switzerland', 'romania', 'india', 'australia', 'country', 'city', 'tokyo', 'nation', 'capital', 'continent', 'europe', 'asia', 'earth', 'america'}
182 | local topic2 = {'football', 'striker', 'goalkeeper', 'basketball', 'coach', 'championship', 'cup',
183 | 'soccer', 'player', 'captain', 'qualifier', 'goal', 'under-21', 'halftime', 'standings', 'basketball',
184 | 'games', 'league', 'rugby', 'hockey', 'fifa', 'fans', 'maradona', 'mutu', 'hagi', 'beckham', 'injury', 'game',
185 | 'kick', 'penalty'}
186 | local topic_avg = {'japan national football team', 'germany national football team',
187 | 'china national football team', 'brazil soccer', 'japan soccer', 'germany soccer', 'china soccer',
188 | 'fc barcelona', 'real madrid'}
189 |
190 | local stop_words_array = {}
191 | for w,_ in pairs(stop_words) do
192 | table.insert(stop_words_array, w)
193 | end
194 |
195 | local topic1_len = table_len(topic1)
196 | local topic2_len = table_len(topic2)
197 | local topic_avg_len = table_len(topic_avg)
198 | local stop_words_len = table_len(stop_words_array)
199 |
200 | torch.setdefaulttensortype('torch.DoubleTensor')
201 | w2vutils.M = w2vutils.M:double()
202 |
203 | local tensor = torch.zeros(num_rand_words + stop_words_len + topic1_len + topic2_len + topic_avg_len, word_vecs_size)
204 | local tensor_w_ids = torch.zeros(num_rand_words)
205 | local tensor_colors = torch.zeros(tensor:size(1))
206 |
207 | for i = 1,num_rand_words do
208 | tensor_w_ids[i] = math.random(1,25000)
209 | tensor_colors[i] = 0
210 | tensor[i]:copy(w2vutils.M[tensor_w_ids[i]])
211 | end
212 |
213 | for i = 1, stop_words_len do
214 | tensor_colors[num_rand_words + i] = 1
215 | tensor[num_rand_words + i]:copy(w2vutils:phrase_avg_vec(stop_words_array[i]))
216 | end
217 |
218 | for i = 1, topic1_len do
219 | tensor_colors[num_rand_words + stop_words_len + i] = 2
220 | tensor[num_rand_words + stop_words_len + i]:copy(w2vutils:phrase_avg_vec(topic1[i]))
221 | end
222 |
223 | for i = 1, topic2_len do
224 | tensor_colors[num_rand_words + stop_words_len + topic1_len + i] = 3
225 | tensor[num_rand_words + stop_words_len + topic1_len + i]:copy(w2vutils:phrase_avg_vec(topic2[i]))
226 | end
227 |
228 | for i = 1, topic_avg_len do
229 | tensor_colors[num_rand_words + stop_words_len + topic1_len + topic2_len + i] = 4
230 | tensor[num_rand_words + stop_words_len + topic1_len + topic2_len + i]:copy(w2vutils:phrase_avg_vec(topic_avg[i]))
231 | end
232 |
233 | local manifold = require 'manifold'
234 | opts = {ndims = 2, perplexity = 30, pca = 50, use_bh = false}
235 | mapped_x1 = manifold.embedding.tsne(tensor, opts)
236 | assert(mapped_x1:size(1) == tensor:size(1) and mapped_x1:size(2) == 2)
237 | ouf_vecs = assert(io.open('tsne-w2v-vecs.txt_' .. num_rand_words, "w"))
238 | for i = 1,mapped_x1:size(1) do
239 | local w = nil
240 | if tensor_colors[i] == 0 then
241 | w = get_word_from_id(tensor_w_ids[i])
242 | elseif tensor_colors[i] == 1 then
243 | w = stop_words_array[i - num_rand_words]:gsub(' ', '-')
244 | elseif tensor_colors[i] == 2 then
245 | w = topic1[i - num_rand_words - stop_words_len]:gsub(' ', '-')
246 | elseif tensor_colors[i] == 3 then
247 | w = topic2[i - num_rand_words - stop_words_len - topic1_len]:gsub(' ', '-')
248 | elseif tensor_colors[i] == 4 then
249 | w = topic_avg[i - num_rand_words - stop_words_len - topic1_len - topic2_len]:gsub(' ', '-')
250 | end
251 | assert(w)
252 |
253 | local v = mapped_x1[i]
254 | for j = 1,2 do
255 | ouf_vecs:write(v[j] .. ' ')
256 | end
257 | ouf_vecs:write(w .. ' ' .. tensor_colors[i] .. '\n')
258 | end
259 | io.close(ouf_vecs)
260 | print(' DONE')
261 | end
262 |
--------------------------------------------------------------------------------
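The shape semantics of `w2vutils:lookup_w_vecs` above: a 1D tensor of word ids returns an `n x 300` matrix, and a 2D batch of ids returns a `batch x n x 300` tensor. A minimal sketch, assuming the word loading pipeline (`words/load_w_freq_and_vecs.lua` followed by `words/w2v/w2v.lua`) has already been run and that the ids below are hypothetical:

```
-- Batch of 2 sequences, 3 word ids each (hypothetical ids).
local ids = torch.LongTensor({{2, 3, 4}, {5, 6, 7}})
local vecs = w2vutils:lookup_w_vecs(ids)
print(vecs:size())          -- expected: 2 x 3 x 300

-- Single word lookup; vectors are renormalized to unit norm above.
print(w2vutils:get_w_vec('france'):norm())   -- expected ~1 if 'france' is in the vocabulary
```
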
/words/w2v/word2vec_reader.lua:
--------------------------------------------------------------------------------
1 | -- Adapted from https://github.com/rotmanmi/word2vec.torch
2 | function read_string_w2v(file)
3 | local str = {}
4 | while true do
5 | local char = file:readChar()
6 | if char == 32 or char == 10 or char == 0 then
7 | break
8 | else
9 | str[#str+1] = char
10 | end
11 | end
12 | str = torch.CharStorage(str)
13 | return str:string()
14 | end
15 |
16 |
17 | file = torch.DiskFile(w2v_binfilename,'r')
18 |
19 | --Reading Header
20 | file:ascii()
21 | local vocab_size = file:readInt()
22 | local size = file:readInt()
23 | assert(size == word_vecs_size, 'Wrong size : ' .. size .. ' vs ' .. word_vecs_size)
24 |
25 | local M = torch.zeros(total_num_words(), word_vecs_size):float()
26 |
27 | --Reading Contents
28 | file:binary()
29 | local num_phrases = 0
30 | for i = 1,vocab_size do
31 | local w = read_string_w2v(file)
32 | local v = torch.FloatTensor(file:readFloat(word_vecs_size))
33 | local w_id = get_id_from_word(w)
34 | if w_id ~= unk_w_id then
35 | M[w_id]:copy(v)
36 | end
37 | end
38 |
39 | print('Num words = ' .. total_num_words() .. '. Num phrases = ' .. num_phrases)
40 |
41 | return M
42 |
--------------------------------------------------------------------------------
/words/w_freq/w_freq_gen.lua:
--------------------------------------------------------------------------------
1 | -- Computes the unigram frequency of each word in the Wikipedia corpus.
2 |
3 | if not opt then
4 | cmd = torch.CmdLine()
5 | cmd:option('-root_data_dir', '', 'Root path of the data, $DATA_PATH.')
6 | cmd:text()
7 | opt = cmd:parse(arg or {})
8 | assert(opt.root_data_dir ~= '', 'Specify a valid root_data_dir path argument.')
9 | end
10 |
11 |
12 | require 'torch'
13 | dofile 'utils/utils.lua'
14 |
15 | tds = tds or require 'tds'
16 |
17 | word_freqs = tds.Hash()
18 |
19 | local num_lines = 0
20 | it, _ = io.open(opt.root_data_dir .. 'generated/wiki_canonical_words.txt')
21 | line = it:read()
22 |
23 | while (line) do
24 | num_lines = num_lines + 1
25 | if num_lines % 100000 == 0 then
26 | print('Processed ' .. num_lines .. ' lines. ')
27 | end
28 |
29 | local parts = split(line , '\t')
30 | local words = split(parts[3], ' ')
31 | for _,w in pairs(words) do
32 | if (not word_freqs[w]) then
33 | word_freqs[w] = 0
34 | end
35 | word_freqs[w] = word_freqs[w] + 1
36 | end
37 | line = it:read()
38 | end
39 |
40 |
41 | -- Writing word frequencies
42 | print('Sorting and writing')
43 | sorted_word_freq = {}
44 | for w,freq in pairs(word_freqs) do
45 | if freq >= 10 then
46 | table.insert(sorted_word_freq, {w = w, freq = freq})
47 | end
48 | end
49 |
50 | table.sort(sorted_word_freq, function(a,b) return a.freq > b.freq end)
51 |
52 | out_file = opt.root_data_dir .. 'generated/word_wiki_freq.txt'
53 | ouf = assert(io.open(out_file, "w"))
54 | total_freq = 0
55 | for _,x in pairs(sorted_word_freq) do
56 | ouf:write(x.w .. '\t' .. x.freq .. '\n')
57 | total_freq = total_freq + x.freq
58 | end
59 | ouf:flush()
60 | io.close(ouf)
61 |
62 | print('Total freq = ' .. total_freq .. '\n')
63 |
--------------------------------------------------------------------------------
/words/w_freq/w_freq_index.lua:
--------------------------------------------------------------------------------
1 | -- Loads all words and their frequencies and IDs from a dictionary.
2 | assert(common_w2v_freq_words)
3 | if not opt.unig_power then
4 | opt.unig_power = 0.6
5 | end
6 |
7 | print('==> Loading word freq map with unig power ' .. red(opt.unig_power))
8 | local w_freq_file = opt.root_data_dir .. 'generated/word_wiki_freq.txt'
9 |
10 | local w_freq = {}
11 | w_freq.id2word = tds.Hash()
12 | w_freq.word2id = tds.Hash()
13 |
14 | w_freq.w_f_start = tds.Hash()
15 | w_freq.w_f_end = tds.Hash()
16 | w_freq.total_freq = 0.0
17 |
18 | w_freq.w_f_at_unig_power_start = tds.Hash()
19 | w_freq.w_f_at_unig_power_end = tds.Hash()
20 | w_freq.total_freq_at_unig_power = 0.0
21 |
22 | -- UNK word id
23 | unk_w_id = 1
24 | w_freq.word2id['UNK_W'] = unk_w_id
25 | w_freq.id2word[unk_w_id] = 'UNK_W'
26 |
27 | local tmp_wid = 1
28 | for line in io.lines(w_freq_file) do
29 | local parts = split(line, '\t')
30 | local w = parts[1]
31 | if common_w2v_freq_words[w] then
32 | tmp_wid = tmp_wid + 1
33 | local w_id = tmp_wid
34 | w_freq.id2word[w_id] = w
35 | w_freq.word2id[w] = w_id
36 |
37 | local w_f = tonumber(parts[2])
38 | assert(w_f)
39 | if w_f < 100 then
40 | w_f = 100
41 | end
42 | w_freq.w_f_start[w_id] = w_freq.total_freq
43 | w_freq.total_freq = w_freq.total_freq + w_f
44 | w_freq.w_f_end[w_id] = w_freq.total_freq
45 |
46 | w_freq.w_f_at_unig_power_start[w_id] = w_freq.total_freq_at_unig_power
47 | w_freq.total_freq_at_unig_power = w_freq.total_freq_at_unig_power + math.pow(w_f, opt.unig_power)
48 | w_freq.w_f_at_unig_power_end[w_id] = w_freq.total_freq_at_unig_power
49 | end
50 | end
51 |
52 | w_freq.total_num_words = tmp_wid
53 |
54 | print(' Done loading word freq index. Num words = ' .. w_freq.total_num_words .. '; total freq = ' .. w_freq.total_freq)
55 |
56 | --------------------------------------------
57 | total_num_words = function()
58 | return w_freq.total_num_words
59 | end
60 |
61 | contains_w_id = function(w_id)
62 | assert(w_id >= 1 and w_id <= total_num_words(), w_id)
63 | return (w_id ~= unk_w_id)
64 | end
65 |
66 | -- id -> word
67 | get_word_from_id = function(w_id)
68 | assert(w_id >= 1 and w_id <= total_num_words(), w_id)
69 | return w_freq.id2word[w_id]
70 | end
71 |
72 | -- word -> id
73 | get_id_from_word = function(w)
74 | local w_id = w_freq.word2id[w]
75 | if w_id == nil then
76 | return unk_w_id
77 | end
78 | return w_id
79 | end
80 |
81 | contains_w = function(w)
82 | return contains_w_id(get_id_from_word(w))
83 | end
84 |
85 | -- word frequency:
86 | function get_w_id_freq(w_id)
87 | assert(contains_w_id(w_id), w_id)
88 | return w_freq.w_f_end[w_id] - w_freq.w_f_start[w_id] + 1
89 | end
90 |
91 | -- p(w) prior:
92 | function get_w_id_unigram(w_id)
93 | return get_w_id_freq(w_id) / w_freq.total_freq
94 | end
95 |
96 | function get_w_tensor_log_unigram(vec_w_ids)
97 | assert(vec_w_ids:dim() == 2)
98 | local v = torch.zeros(vec_w_ids:size(1), vec_w_ids:size(2))
99 | for i= 1,vec_w_ids:size(1) do
100 | for j = 1,vec_w_ids:size(2) do
101 | v[i][j] = math.log(get_w_id_unigram(vec_w_ids[i][j]))
102 | end
103 | end
104 | return v
105 | end
106 |
107 |
108 | if (opt.unit_tests) then
109 | print(get_w_id_unigram(get_id_from_word('the')))
110 | print(get_w_id_unigram(get_id_from_word('of')))
111 | print(get_w_id_unigram(get_id_from_word('and')))
112 | print(get_w_id_unigram(get_id_from_word('romania')))
113 | end
114 |
115 | -- Generates a random word id sampled according to the word unigram frequencies, via binary search over the cumulative frequency intervals.
116 | local function random_w_id(total_freq, w_f_start, w_f_end)
117 | local j = math.random() * total_freq
118 | local i_start = 2
119 | local i_end = total_num_words()
120 |
121 | while i_start <= i_end do
122 | local i_mid = math.floor((i_start + i_end) / 2)
123 | local w_id_mid = i_mid
124 | if w_f_start[w_id_mid] <= j and j <= w_f_end[w_id_mid] then
125 | return w_id_mid
126 | elseif (w_f_start[w_id_mid] > j) then
127 | i_end = i_mid - 1
128 | elseif (w_f_end[w_id_mid] < j) then
129 | i_start = i_mid + 1
130 | end
131 | end
132 | print(red('Binary search error !!'))
133 | end
134 |
135 | -- Samples word ids from the unigram distribution raised to the power opt.unig_power, as used for negative sampling in the Word2Vec paper.
136 | function random_unigram_at_unig_power_w_id()
137 | return random_w_id(w_freq.total_freq_at_unig_power, w_freq.w_f_at_unig_power_start, w_freq.w_f_at_unig_power_end)
138 | end
139 |
140 |
141 | function get_w_unnorm_unigram_at_power(w_id)
142 | return math.pow(get_w_id_unigram(w_id), opt.unig_power)
143 | end
144 |
145 |
146 | function unit_test_random_unigram_at_unig_power_w_id(k_samples)
147 | local empirical_dist = {}
148 | for i=1,k_samples do
149 | local w_id = random_unigram_at_unig_power_w_id()
150 | assert(w_id ~= unk_w_id)
151 | if not empirical_dist[w_id] then
152 | empirical_dist[w_id] = 0
153 | end
154 | empirical_dist[w_id] = empirical_dist[w_id] + 1
155 | end
156 | print('Now sorting ..')
157 | local sorted_empirical_dist = {}
158 | for k,v in pairs(empirical_dist) do
159 | table.insert(sorted_empirical_dist, {w_id = k, f = v})
160 | end
161 | table.sort(sorted_empirical_dist, function(a,b) return a.f > b.f end)
162 |
163 | local str = ''
164 | for i = 1,math.min(100, table_len(sorted_empirical_dist)) do
165 | str = str .. get_word_from_id(sorted_empirical_dist[i].w_id) .. '{' .. sorted_empirical_dist[i].f .. '}; '
166 | end
167 | print('Unit test random sampling: ' .. str)
168 | end
169 |
--------------------------------------------------------------------------------
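The sampler in `random_w_id` above relies on the cumulative intervals built while reading `word_wiki_freq.txt`: word id i owns the interval `[w_f_start[i], w_f_end[i]]` on a segment of total length `total_freq`, and a uniform draw is mapped back to a word id (the real code does this with binary search). A toy illustration of the interval construction with hypothetical frequencies (plain Lua, linear scan for brevity):

```
local freqs = {0, 5, 3, 2}        -- index 1 is the UNK slot and is never sampled
local w_f_start, w_f_end, total_freq = {}, {}, 0
for i = 2, #freqs do
  w_f_start[i] = total_freq
  total_freq = total_freq + freqs[i]
  w_f_end[i] = total_freq
end
-- total_freq = 10; word 2 owns [0,5], word 3 owns [5,8], word 4 owns [8,10]

local j = math.random() * total_freq
for i = 2, #freqs do
  if w_f_start[i] <= j and j <= w_f_end[i] then
    print('sampled word id ' .. i)   -- word 2 with prob. 0.5, word 3 with 0.3, word 4 with 0.2
    break
  end
end
```
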