├── .gitignore
├── README.md
├── bleu.py
├── eval.py
├── generate.py
├── prep_ada.jl
├── prep_embedding_matrix.py
├── prep_vocab.py
├── prep_w2v.py
├── source
│   ├── __init__.py
│   ├── attention_skipgram.py
│   ├── constants.py
│   ├── datasets.py
│   ├── layers.py
│   ├── model.py
│   ├── pipeline.py
│   └── utils.py
├── train.py
└── train_attention_skipgram.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | data/
3 | nohup.out
4 | .ipynb_checkpoints
5 | sftp-config.json
6 | .DS_Store
7 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Conditional Generators of Words Definitions
2 |
3 | This repo contains code for our paper [Conditional Generators of Words Definitions](https://arxiv.org/abs/1806.10090).
4 |
5 | __Abstract__
6 |
7 | We explore the recently introduced definition modeling technique, which provides a tool for evaluating different distributed
8 | vector representations of words by modeling dictionary definitions of words. In this work, we study the problem of word ambiguity in definition modeling and propose a possible solution by employing latent variable modeling and soft attention mechanisms. Our quantitative and qualitative evaluation and analysis of the model show that taking word ambiguity and polysemy into account leads to a performance improvement.
9 |
10 | # Citation
11 |
12 | ```
13 | @InProceedings{P18-2043,
14 | author = "Gadetsky, Artyom and Yakubovskiy, Ilya and Vetrov, Dmitry",
15 | title = "Conditional Generators of Words Definitions",
16 | booktitle = "Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)",
17 | year = "2018",
18 | publisher = "Association for Computational Linguistics",
19 | pages = "266--271",
20 | location = "Melbourne, Australia",
21 | url = "http://aclweb.org/anthology/P18-2043"
22 | }
23 | ```
24 |
25 | # Environment requirements and Data Preparation
26 |
27 | * Install conda environment with the following packages:
28 |
29 | ```
30 | Python 3.6
31 | Pytorch 0.4
32 | Numpy 1.14
33 | Tqdm 4.23
34 | Gensim 3.4
35 | ```
36 |
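One possible way to set up such an environment (the environment name, channel and exact package versions below are illustrative, not prescribed by the paper):
```
conda create -n pytorch-definitions python=3.6
source activate pytorch-definitions
conda install pytorch=0.4.0 -c pytorch
pip install numpy==1.14.5 tqdm==4.23.4 gensim==3.4.0
```
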
37 | * To use Adaptive conditioning, install the AdaGram software:
38 |
39 | Download Julia 0.6 binaries from the [official site](https://julialang.org/downloads/) and add an alias in `~/.bashrc`:
40 | ```
41 | alias julia='JULIA_BINARY_PATH/bin/julia'
42 | ```
43 | Use `source ~/.bashrc` to reload ~/.bashrc
44 |
45 | Then start the Julia interpreter using `julia` and install the following packages:
46 | ```
47 | Pkg.clone("https://github.com/mirestrepo/AdaGram.jl")
48 | Pkg.build("AdaGram")
49 | Pkg.add("ArgParse")
50 | Pkg.add("JSON")
51 | Pkg.add("NPZ")
52 | exit()
53 | ```
54 | Then add the following to `~/.bashrc`:
55 | ```
56 | export PATH="JULIA_BINARY_PATH/bin:$PATH"
57 | export LD_LIBRARY_PATH="JULIA_INSTALL_PATH/v0.6/AdaGram/lib:$LD_LIBRARY_PATH"
58 | ```
59 | And finally, apply the exports:
60 | ```
61 | source ~/.bashrc
62 | ```
63 | * To install Mosesdecoder (needed to compute BLEU), follow the instructions on the [official site](http://www.statmt.org/moses/?n=Development.GetStarted)
64 |
65 | * To get data for language model (LM) pretraining:
66 | ```
67 | cd pytorch-definitions
68 | mkdir data
69 | cd data
70 | wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip
71 | unzip wikitext-103-v1.zip
72 | ```
73 | * To get the Google word vectors, use the [official site](https://code.google.com/archive/p/word2vec/). You need the `.bin.gz` file. Don't forget to `gunzip` the downloaded file to extract the binary.
74 |
75 | * Adaptive Skip-gram vectors are available upon request. Alternatively, you can train your own using the instructions in the [official repo](https://github.com/sbos/AdaGram.jl)
76 |
77 | * The Definition Modeling data is available upon request because of the Oxford Dictionaries distribution license. Alternatively, you can collect your own. In that case, prepare 3 data splits: train, val and test. Each split is a Python list with the following format, saved as a JSON file:
78 |
79 | ```
80 | data = [
81 | [
82 | ["word"],
83 | ["word1", "word2", ...],
84 | ["word1", "word2", ...]
85 | ],
86 | ...
87 | ]
88 | So i-th element of the data:
89 | data[i][0][0] - word being defined (string)
90 | data[i][1] - definition (list of strings)
91 | data[i][2] - context to understand word meaning (list of strings)
92 | ```
93 |
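For reference, a minimal Python sketch of loading and inspecting one such split (the path is hypothetical):
```
import json

# load one data split (illustrative path)
with open("data/train.json", "r") as infile:
    data = json.load(infile)

entry = data[0]
word = entry[0][0]       # word being defined (string)
definition = entry[1]    # definition (list of tokens)
context = entry[2]       # context illustrating the word's meaning (list of tokens)
print(word)
print(" ".join(definition))
print(" ".join(context))
```
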
94 | # Usage
95 | First, you need to prepare the vocabularies, vectors, etc. required by the model:
96 |
97 | * To prepare vocabs use `python prep_vocab.py`
98 |
99 | ```
100 | usage: prep_vocab.py [-h] --defs DEFS [DEFS ...] [--lm LM [LM ...]] [--same]
101 | --save SAVE [--save_context SAVE_CONTEXT] --save_chars
102 | SAVE_CHARS
103 |
104 | Prepare vocabularies for model
105 |
106 | optional arguments:
107 | -h, --help show this help message and exit
108 | --defs DEFS [DEFS ...]
109 | location of json file with definitions.
110 | --lm LM [LM ...] location of txt file with text for LM pre-training
111 | --same use same vocab for definitions and contexts
112 |   --save SAVE           where to save prepared vocabulary (for words from
113 | definitions)
114 | --save_context SAVE_CONTEXT
115 | where to save vocabulary (for words from contexts)
116 | --save_chars SAVE_CHARS
117 | where to save char vocabulary (for chars from all
118 | words)
119 | ```
120 |
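For example, to build vocabularies from all three splits plus the WikiText-103 training text (all paths below are illustrative):
```
python prep_vocab.py --defs data/train.json data/val.json data/test.json \
  --lm data/wikitext-103/wiki.train.tokens \
  --save data/vocab.json --save_context data/context_vocab.json \
  --save_chars data/char_vocab.json
```
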
121 | * To prepare w2v vectors use `python prep_w2v.py`
122 | ```
123 | usage: prep_w2v.py [-h] --defs DEFS [DEFS ...] --save SAVE [SAVE ...] --w2v
124 | W2V
125 |
126 | Prepare word vectors for Input conditioning
127 |
128 | optional arguments:
129 | -h, --help show this help message and exit
130 | --defs DEFS [DEFS ...]
131 | location of json file with definitions.
132 | --save SAVE [SAVE ...]
133 | where to save files
134 | --w2v W2V location of binary w2v file
135 | ```
136 |
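For example (paths are illustrative; `prep_w2v.py` saves one `.npy` file per `--defs` file, in the same order):
```
python prep_w2v.py --defs data/train.json data/val.json data/test.json \
  --save data/input_train.npy data/input_val.npy data/input_test.npy \
  --w2v data/GoogleNews-vectors-negative300.bin
```
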
137 | * To prepare Adagram vectors use `julia prep_ada.jl`
138 | ```
139 | usage: prep_ada.jl --defs DEFS [DEFS...] --save SAVE [SAVE...]
140 | --ada ADA [-h]
141 |
142 | Prepare word vectors for Input-Adaptive conditioning
143 |
144 | optional arguments:
145 | --defs DEFS [DEFS...]
146 | location of json file with definitions.
147 | --save SAVE [SAVE...]
148 | where to save files
149 | --ada ADA location of AdaGram file
150 | -h, --help show this help message and exit
151 | ```
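
For example (paths are illustrative; the `--ada` argument points to a trained AdaGram model file):
```
julia prep_ada.jl --defs data/train.json data/val.json data/test.json \
  --save data/input_adaptive_train.npy data/input_adaptive_val.npy data/input_adaptive_test.npy \
  --ada ADAGRAM_MODEL_PATH
```
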
152 | * If you want to initialize the embedding matrix of the model with Google word vectors, prepare it using
153 | `python prep_embedding_matrix.py` and then pass the path to the saved weights as `--w2v_weights` in `train.py`
154 | ```
155 | usage: prep_embedding_matrix.py [-h] --voc VOC --w2v W2V --save SAVE
156 |
157 | Prepare word vectors for embedding layer in the model
158 |
159 | optional arguments:
160 | -h, --help show this help message and exit
161 | --voc VOC location of model vocabulary file
162 | --w2v W2V location of binary w2v file
163 |   --save SAVE  where to save prepared matrix
164 | ```
165 |
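For example (paths are illustrative; the saved file is a PyTorch tensor with one row per vocabulary word):
```
python prep_embedding_matrix.py --voc data/vocab.json \
  --w2v data/GoogleNews-vectors-negative300.bin \
  --save data/embedding_matrix.pt
```
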
166 | Now everything is ready to train and use the model!
167 |
168 | * To train a model, use `python train.py`
169 | ```
170 | usage: train.py [-h] [--pretrain] --voc VOC [--train_defs TRAIN_DEFS]
171 | [--eval_defs EVAL_DEFS] [--test_defs TEST_DEFS]
172 | [--input_train INPUT_TRAIN] [--input_eval INPUT_EVAL]
173 | [--input_test INPUT_TEST]
174 | [--input_adaptive_train INPUT_ADAPTIVE_TRAIN]
175 | [--input_adaptive_eval INPUT_ADAPTIVE_EVAL]
176 | [--input_adaptive_test INPUT_ADAPTIVE_TEST]
177 | [--context_voc CONTEXT_VOC] [--ch_voc CH_VOC]
178 | [--train_lm TRAIN_LM] [--eval_lm EVAL_LM] [--test_lm TEST_LM]
179 | [--bptt BPTT] --nx NX --nlayers NLAYERS --nhid NHID
180 | --rnn_dropout RNN_DROPOUT [--use_seed] [--use_input]
181 | [--use_input_adaptive] [--use_input_attention]
182 | [--n_attn_embsize N_ATTN_EMBSIZE] [--n_attn_hid N_ATTN_HID]
183 | [--attn_dropout ATTN_DROPOUT] [--attn_sparse] [--use_ch]
184 | [--ch_emb_size CH_EMB_SIZE]
185 | [--ch_feature_maps CH_FEATURE_MAPS [CH_FEATURE_MAPS ...]]
186 | [--ch_kernel_sizes CH_KERNEL_SIZES [CH_KERNEL_SIZES ...]]
187 | [--use_hidden] [--use_hidden_adaptive]
188 | [--use_hidden_attention] [--use_gated] [--use_gated_adaptive]
189 | [--use_gated_attention] --lr LR --decay_factor DECAY_FACTOR
190 | --decay_patience DECAY_PATIENCE --num_epochs NUM_EPOCHS
191 | --batch_size BATCH_SIZE --clip CLIP --random_seed RANDOM_SEED
192 | --exp_dir EXP_DIR [--w2v_weights W2V_WEIGHTS]
193 | [--fix_embeddings] [--fix_attn_embeddings] [--lm_ckpt LM_CKPT]
194 | [--attn_ckpt ATTN_CKPT]
195 |
196 | Script to train a model
197 |
198 | optional arguments:
199 | -h, --help show this help message and exit
200 | --pretrain whether to pretrain model on LM dataset or train on
201 | definitions
202 | --voc VOC location of vocabulary file
203 | --train_defs TRAIN_DEFS
204 | location of json file with train definitions.
205 | --eval_defs EVAL_DEFS
206 | location of json file with eval definitions.
207 | --test_defs TEST_DEFS
208 | location of json file with test definitions
209 | --input_train INPUT_TRAIN
210 | location of train vectors for Input conditioning
211 | --input_eval INPUT_EVAL
212 | location of eval vectors for Input conditioning
213 | --input_test INPUT_TEST
214 | location of test vectors for Input conditioning
215 | --input_adaptive_train INPUT_ADAPTIVE_TRAIN
216 | location of train vectors for InputAdaptive
217 | conditioning
218 | --input_adaptive_eval INPUT_ADAPTIVE_EVAL
219 | location of eval vectors for InputAdaptive
220 | conditioning
221 | --input_adaptive_test INPUT_ADAPTIVE_TEST
222 |                         location of test vectors for InputAdaptive conditioning
223 | --context_voc CONTEXT_VOC
224 | location of context vocabulary file
225 | --ch_voc CH_VOC location of CH vocabulary file
226 | --train_lm TRAIN_LM location of txt file train LM data
227 | --eval_lm EVAL_LM location of txt file eval LM data
228 | --test_lm TEST_LM location of txt file test LM data
229 | --bptt BPTT sequence length for BackPropThroughTime in LM
230 | pretraining
231 | --nx NX size of embeddings
232 | --nlayers NLAYERS number of LSTM layers
233 | --nhid NHID size of hidden states
234 | --rnn_dropout RNN_DROPOUT
235 | probability of RNN dropout
236 | --use_seed whether to use Seed conditioning or not
237 | --use_input whether to use Input conditioning or not
238 | --use_input_adaptive whether to use InputAdaptive conditioning or not
239 | --use_input_attention
240 | whether to use InputAttention conditioning or not
241 | --n_attn_embsize N_ATTN_EMBSIZE
242 | size of InputAttention embeddings
243 | --n_attn_hid N_ATTN_HID
244 | size of InputAttention linear layer
245 | --attn_dropout ATTN_DROPOUT
246 | probability of InputAttention dropout
247 | --attn_sparse whether to use sparse embeddings in InputAttention or
248 | not
249 | --use_ch whether to use CH conditioning or not
250 | --ch_emb_size CH_EMB_SIZE
251 | size of embeddings in CH conditioning
252 | --ch_feature_maps CH_FEATURE_MAPS [CH_FEATURE_MAPS ...]
253 | list of feature map sizes in CH conditioning
254 | --ch_kernel_sizes CH_KERNEL_SIZES [CH_KERNEL_SIZES ...]
255 | list of kernel sizes in CH conditioning
256 | --use_hidden whether to use Hidden conditioning or not
257 | --use_hidden_adaptive
258 | whether to use HiddenAdaptive conditioning or not
259 | --use_hidden_attention
260 | whether to use HiddenAttention conditioning or not
261 | --use_gated whether to use Gated conditioning or not
262 | --use_gated_adaptive whether to use GatedAdaptive conditioning or not
263 | --use_gated_attention
264 | whether to use GatedAttention conditioning or not
265 | --lr LR initial lr
266 | --decay_factor DECAY_FACTOR
267 | factor to decay lr
268 | --decay_patience DECAY_PATIENCE
269 | after number of patience epochs - decay lr
270 | --num_epochs NUM_EPOCHS
271 | number of epochs to train
272 | --batch_size BATCH_SIZE
273 | batch size
274 | --clip CLIP value to clip norm of gradients to
275 | --random_seed RANDOM_SEED
276 | random seed
277 | --exp_dir EXP_DIR where to save all stuff about training
278 | --w2v_weights W2V_WEIGHTS
279 | path to pretrained embeddings to init
280 | --fix_embeddings whether to update embedding matrix or not
281 | --fix_attn_embeddings
282 | whether to update attention embedding matrix or not
283 | --lm_ckpt LM_CKPT path to pretrained language model weights
284 | --attn_ckpt ATTN_CKPT
285 | path to pretrained Attention module
286 | ```
287 |
288 | For example, to train a simple language model, use:
289 | ```
290 | python train.py --voc VOC_PATH --nx 300 --nhid 300 --rnn_dropout 0.5 --lr 0.001 --decay_factor 0.1 --decay_patience 0
291 | --num_epochs 1 --batch_size 16 --clip 5 --random_seed 42 --exp_dir DIR_PATH --bptt 30
292 | --pretrain --train_lm PATH_TO_WIKI_103_TRAIN --eval_lm PATH_TO_WIKI_103_EVAL --test_lm PATH_TO_WIKI_103_TEST
293 | ```
294 |
295 | For example, to train the `Seed + Input` model, use:
296 | ```
297 | python train.py --voc VOC_PATH --nx 300 --nhid 300 --rnn_dropout 0.5 --lr 0.001 --decay_factor 0.1 --decay_patience 0
298 | --num_epochs 1 --batch_size 16 --clip 5 --random_seed 42 --exp_dir DIR_PATH
299 | --train_defs TRAIN_SPLIT_PATH --eval_defs EVAL_DEFS_PATH --test_defs TEST_DEFS_PATH --use_seed
300 | --use_input --input_train PREPARED_W2V_TRAIN_VECS --input_eval PREPARED_W2V_EVAL_VECS --input_test PREPARED_W2V_TEST_VECS
301 | ```
302 |
303 | To train the `Seed + Input` model with pretraining as an unconditional LM, provide the path to the pretrained LM weights
as the `--lm_ckpt` argument in `train.py`
304 |
305 | * To generate definitions with a trained model, use `python generate.py`
306 | ```
307 | usage: generate.py [-h] --params PARAMS --ckpt CKPT --tau TAU --n N --length
308 | LENGTH [--prefix PREFIX] [--wordlist WORDLIST]
309 | [--w2v_binary_path W2V_BINARY_PATH]
310 | [--ada_binary_path ADA_BINARY_PATH]
311 | [--prep_ada_path PREP_ADA_PATH]
312 |
313 | Script to generate using model
314 |
315 | optional arguments:
316 | -h, --help show this help message and exit
317 | --params PARAMS path to saved model params
318 | --ckpt CKPT path to saved model weights
319 | --tau TAU temperature to use in sampling
320 | --n N number of samples to generate
321 | --length LENGTH maximum length of generated samples
322 | --prefix PREFIX prefix to read until generation starts
323 | --wordlist WORDLIST path to word list with words and contexts
324 | --w2v_binary_path W2V_BINARY_PATH
325 | path to binary w2v file
326 | --ada_binary_path ADA_BINARY_PATH
327 | path to binary ada file
328 | --prep_ada_path PREP_ADA_PATH
329 | path to prep_ada.jl script
330 | ```
331 |
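The `--wordlist` file for `generate.py` contains one entry per line: the word to define and a context, separated by a tab. An illustrative example (the params and checkpoint file names inside the experiment directory are hypothetical):
```
python generate.py --params EXP_DIR/params.json --ckpt EXP_DIR/weights.pth \
  --tau 0.1 --n 3 --length 30 --wordlist data/wordlist.txt \
  --w2v_binary_path data/GoogleNews-vectors-negative300.bin
```
`--w2v_binary_path` (or `--ada_binary_path` together with `--prep_ada_path`) is only needed for models that use Input (or Input-Adaptive) conditioning.
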
332 | * To evaluate a trained model, use `python eval.py`
333 | ```
334 | usage: eval.py [-h] --params PARAMS --ckpt CKPT --datasplit DATASPLIT --type
335 | TYPE [--wordlist WORDLIST] [--tau TAU] [--n N]
336 | [--length LENGTH]
337 |
338 | Script to evaluate model
339 |
340 | optional arguments:
341 | -h, --help show this help message and exit
342 | --params PARAMS path to saved model params
343 | --ckpt CKPT path to saved model weights
344 | --datasplit DATASPLIT
345 | train, val or test set to evaluate on
346 | --type TYPE compute ppl or bleu
347 | --wordlist WORDLIST word list to evaluate on (by default all data will be
348 | used)
349 | --tau TAU temperature to use in sampling
350 | --n N number of samples to generate
351 | --length LENGTH maximum length of generated samples
352 | ```
353 |
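For example, to compute perplexity on the test split and then generate hypotheses and references for BLEU (the params and checkpoint file names are hypothetical):
```
python eval.py --params EXP_DIR/params.json --ckpt EXP_DIR/weights.pth --datasplit test --type ppl
python eval.py --params EXP_DIR/params.json --ckpt EXP_DIR/weights.pth --datasplit test --type bleu --tau 0.1 --n 3 --length 30
```
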
354 | * To measure BLEU for a trained model, first run `eval.py` with `--type bleu` to produce hypothesis and reference files,
355 | and then compute BLEU using `python bleu.py`
356 | ```
357 | usage: bleu.py [-h] --ref REF --hyp HYP --n N [--with_contexts] --bleu_path
358 | BLEU_PATH --mode MODE
359 |
360 | Script to compute BLEU
361 |
362 | optional arguments:
363 | -h, --help show this help message and exit
364 | --ref REF path to file with references
365 | --hyp HYP path to file with hypotheses
366 | --n N --n argument used to generate --ref file using eval.py
367 |   --with_contexts       whether to consider contexts or not when computing BLEU
368 | --bleu_path BLEU_PATH
369 | path to mosesdecoder sentence-bleu binary
370 | --mode MODE whether to average or take random example per word
371 | ```
372 |
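For example, using the files written by `eval.py` into the experiment directory (the `sentence-bleu` path depends on where you built mosesdecoder; `--n` must match the value used with `eval.py`):
```
python bleu.py --ref EXP_DIR/refs_test.txt \
  --hyp EXP_DIR/generated_test_tau=0.1_n=3_length=30.txt \
  --n 3 --bleu_path mosesdecoder/bin/sentence-bleu --mode average
```
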
373 | * You can also pretrain the Attention module using `python train_attention_skipgram.py` and
374 | then pass the path to the saved weights as the `--attn_ckpt` argument in `train.py`
375 | ```
376 | usage: train_attention_skipgram.py [-h] [--data DATA] --context_voc
377 | CONTEXT_VOC [--prepared] --window WINDOW
378 | --random_seed RANDOM_SEED [--sparse]
379 | --vec_dim VEC_DIM --attn_hid ATTN_HID
380 | --attn_dropout ATTN_DROPOUT --lr LR
381 | --batch_size BATCH_SIZE --num_epochs
382 | NUM_EPOCHS --exp_dir EXP_DIR
383 |
384 | Script to train an AttentionSkipGram model
385 |
386 | optional arguments:
387 | -h, --help show this help message and exit
388 | --data DATA path to data
389 | --context_voc CONTEXT_VOC
390 | path to context voc for DefinitionModelingModel is
391 | necessary to save pretrained attention module,
392 |                         particularly embedding matrix
393 | --prepared whether to prepare data or use already prepared
394 | --window WINDOW window for AttentionSkipGram model
395 | --random_seed RANDOM_SEED
396 | random seed for training
397 | --sparse whether to use sparse embeddings or not
398 | --vec_dim VEC_DIM vector dim to train
399 | --attn_hid ATTN_HID hidden size in attention module
400 | --attn_dropout ATTN_DROPOUT
401 | dropout prob in attention module
402 | --lr LR initial lr to use
403 | --batch_size BATCH_SIZE
404 | batch size to use
405 | --num_epochs NUM_EPOCHS
406 | number of epochs to train
407 | --exp_dir EXP_DIR where to save weights, prepared data and logs
408 | ```
409 |
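For example (all hyperparameter values below are illustrative):
```
python train_attention_skipgram.py --data PATH_TO_TEXT_CORPUS --context_voc data/context_vocab.json \
  --window 5 --random_seed 42 --vec_dim 300 --attn_hid 256 --attn_dropout 0.1 \
  --lr 0.001 --batch_size 512 --num_epochs 5 --exp_dir ATTN_EXP_DIR
```
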
--------------------------------------------------------------------------------
/bleu.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | from subprocess import Popen, PIPE
4 | import os
5 | import sys
6 | from itertools import islice
7 |
8 | parser = argparse.ArgumentParser(description='Script to compute BLEU')
9 | parser.add_argument(
10 | "--ref", type=str, required=True,
11 | help="path to file with references"
12 | )
13 | parser.add_argument(
14 | "--hyp", type=str, required=True,
15 | help="path to file with hypotheses"
16 | )
17 | parser.add_argument(
18 | "--n", type=int, required=True,
19 | help="--n argument used to generate --ref file using eval.py"
20 | )
21 | parser.add_argument(
22 | "--with_contexts", dest="with_contexts", action="store_true",
23 | help="whether to consider contexts or not when compute BLEU"
24 | )
25 | parser.add_argument(
26 | "--bleu_path", type=str, required=True,
27 | help="path to mosesdecoder sentence-bleu binary"
28 | )
29 | parser.add_argument(
30 | "--mode", type=str, required=True,
31 | help="whether to average or take random example per word"
32 | )
33 | args = parser.parse_args()
34 | assert args.mode in ["average", "random"], "--mode must be average or random"
35 |
36 |
37 | def next_n_lines(file_opened, N):
38 | return [x.strip() for x in islice(file_opened, N)]
39 |
40 |
41 | def read_def_file(file, n, with_contexts=False):
42 | defs = {}
43 | while True:
44 | lines = next_n_lines(file, n + 2)
45 | if len(lines) == 0:
46 | break
47 | assert len(lines) == n + 2, "Something bad in hyps file"
48 | word = lines[0].split("Word:")[1].strip()
49 | context = lines[1].split("Context:")[1].strip()
50 | dict_key = word + " " + context if with_contexts else word
51 | if dict_key not in defs:
52 | defs[dict_key] = []
53 | for i in range(2, n + 2):
54 | defs[dict_key].append(lines[i].strip())
55 | return defs
56 |
57 |
58 | def read_ref_file(file, with_contexts=False):
59 | defs = {}
60 | while True:
61 | lines = next_n_lines(file, 3)
62 | if len(lines) == 0:
63 | break
64 | assert len(lines) == 3, "Something bad in refs file"
65 | word = lines[0].split("Word:")[1].strip()
66 | context = lines[1].split("Context:")[1].strip()
67 | definition = lines[2].split("Definition:")[1].strip()
68 | dict_key = word + " " + context if with_contexts else word
69 | if dict_key not in defs:
70 | defs[dict_key] = []
71 | defs[dict_key].append(definition)
72 | return defs
73 |
74 |
75 | def get_bleu_score(bleu_path, all_ref_paths, d, hyp_path):
76 | with open(hyp_path, 'w') as ofp:
77 | ofp.write(d)
78 | read_cmd = ['cat', hyp_path]
79 | bleu_cmd = [bleu_path] + all_ref_paths
80 | rp = Popen(read_cmd, stdout=PIPE)
81 | bp = Popen(bleu_cmd, stdin=rp.stdout, stdout=PIPE, stderr=devnull)
82 | out, err = bp.communicate()
83 | if err is None:
84 | return float(out.strip())
85 | else:
86 | return None
87 |
88 | with open(args.ref) as ifp:
89 | refs = read_ref_file(ifp, args.with_contexts)
90 | with open(args.hyp) as ifp:
91 | hyps = read_def_file(ifp, args.n, args.with_contexts)
92 |
93 | assert len(refs) == len(hyps), "Number of words being defined mismatched!"
94 | tmp_dir = "/tmp"
95 | suffix = str(random.random())
96 | words = refs.keys()
97 | hyp_path = os.path.join(tmp_dir, 'hyp' + suffix)
98 | to_be_deleted = set()
99 | to_be_deleted.add(hyp_path)
100 |
101 | # Computing BLEU
102 | devnull = open(os.devnull, 'w')
103 | score = 0
104 | count = 0
105 | total_refs = 0
106 | total_hyps = 0
107 | for word in words:
108 | if word not in refs or word not in hyps:
109 | continue
110 | wrefs = refs[word]
111 | whyps = hyps[word]
112 | # write out references
113 | all_ref_paths = []
114 | for i, d in enumerate(wrefs):
115 | ref_path = os.path.join(tmp_dir, 'ref' + suffix + str(i))
116 | with open(ref_path, 'w') as ofp:
117 | ofp.write(d)
118 | all_ref_paths.append(ref_path)
119 | to_be_deleted.add(ref_path)
120 | total_refs += len(all_ref_paths)
121 | # score for each output
122 | micro_score = 0
123 | micro_count = 0
124 | if args.mode == "average":
125 | for d in whyps:
126 | rhscore = get_bleu_score(
127 | args.bleu_path, all_ref_paths, d, hyp_path)
128 | if rhscore is not None:
129 | micro_score += rhscore
130 | micro_count += 1
131 | elif args.mode == "random":
132 | d = random.choice(whyps)
133 | rhscore = get_bleu_score(args.bleu_path, all_ref_paths, d, hyp_path)
134 | if rhscore is not None:
135 | micro_score += rhscore
136 | micro_count += 1
137 | total_hyps += micro_count
138 | score += micro_score / micro_count
139 | count += 1
140 | devnull.close()
141 |
142 | # delete tmp files
143 | for f in to_be_deleted:
144 | os.remove(f)
145 | print("BLEU: ", score / count)
146 | print("NUM HYPS USED: ", total_hyps)
147 | print("NUM REFS USED: ", total_refs)
148 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | from source.datasets import LanguageModelingDataset, LanguageModelingCollate
2 | from source.datasets import DefinitionModelingDataset, DefinitionModelingCollate
3 | from source.datasets import Vocabulary
4 | from source.model import DefinitionModelingModel
5 | from source.constants import BOS
6 | from source.pipeline import test
7 | from source.pipeline import generate
8 | from torch.utils.data import DataLoader
9 | from tqdm import tqdm
10 | import argparse
11 | import json
12 | import torch
13 |
14 | parser = argparse.ArgumentParser(description='Script to evaluate model')
15 | parser.add_argument(
16 | "--params", type=str, required=True,
17 | help="path to saved model params"
18 | )
19 | parser.add_argument(
20 | "--ckpt", type=str, required=True,
21 | help="path to saved model weights"
22 | )
23 | parser.add_argument(
24 | "--datasplit", type=str, required=True,
25 | help="train, val or test set to evaluate on"
26 | )
27 | parser.add_argument(
28 | "--type", type=str, required=True,
29 | help="compute ppl or bleu"
30 | )
31 | parser.add_argument(
32 | "--wordlist", type=str, required=False,
33 | help="word list to evaluate on (by default all data will be used)"
34 | )
35 | # params for BLEU
36 | parser.add_argument(
37 | "--tau", type=float, required=False,
38 | help="temperature to use in sampling"
39 | )
40 | parser.add_argument(
41 | "--n", type=int, required=False,
42 | help="number of samples to generate"
43 | )
44 | parser.add_argument(
45 | "--length", type=int, required=False,
46 | help="maximum length of generated samples"
47 | )
48 | args = parser.parse_args()
49 | assert args.datasplit in ["train", "val", "test"], ("--datasplit must be "
50 | "train, val or test")
51 | assert args.type in ["ppl", "bleu"], ("--type must be ppl or bleu")
52 |
53 | with open(args.params, "r") as infile:
54 | model_params = json.load(infile)
55 |
56 | logfile = open(model_params["exp_dir"] + "eval_log", "a")
57 | #import sys
58 | #logfile = sys.stdout
59 |
60 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
61 | model = DefinitionModelingModel(model_params).to(device)
62 | model.load_state_dict(torch.load(args.ckpt)["state_dict"])
63 |
64 | if model.params["pretrain"]:
65 | assert args.type == "ppl", "if --pretrain True => evaluate only ppl mode"
66 | if args.datasplit == "train":
67 | dataset = LanguageModelingDataset(
68 | file=model.params["train_lm"],
69 | vocab_path=model.params["voc"],
70 | bptt=model.params["bptt"],
71 | )
72 | elif args.datasplit == "val":
73 | dataset = LanguageModelingDataset(
74 | file=model.params["eval_lm"],
75 | vocab_path=model.params["voc"],
76 | bptt=model.params["bptt"],
77 | )
78 | elif args.datasplit == "test":
79 | dataset = LanguageModelingDataset(
80 | file=model.params["test_lm"],
81 | vocab_path=model.params["voc"],
82 | bptt=model.params["bptt"],
83 | )
84 | dataloader = DataLoader(
85 | dataset, batch_size=model.params["batch_size"],
86 | collate_fn=LanguageModelingCollate
87 | )
88 | else:
89 | if args.datasplit == "train":
90 | dataset = DefinitionModelingDataset(
91 | file=model.params["train_defs"],
92 | vocab_path=model.params["voc"],
93 | input_vectors_path=model.params["input_train"],
94 | input_adaptive_vectors_path=model.params["input_adaptive_train"],
95 | context_vocab_path=model.params["context_voc"],
96 | ch_vocab_path=model.params["ch_voc"],
97 | use_seed=model.params["use_seed"],
98 | wordlist_path=args.wordlist
99 | )
100 | elif args.datasplit == "val":
101 | dataset = DefinitionModelingDataset(
102 | file=model.params["eval_defs"],
103 | vocab_path=model.params["voc"],
104 | input_vectors_path=model.params["input_eval"],
105 | input_adaptive_vectors_path=model.params["input_adaptive_eval"],
106 | context_vocab_path=model.params["context_voc"],
107 | ch_vocab_path=model.params["ch_voc"],
108 | use_seed=model.params["use_seed"],
109 | wordlist_path=args.wordlist
110 | )
111 | elif args.datasplit == "test":
112 | dataset = DefinitionModelingDataset(
113 | file=model.params["test_defs"],
114 | vocab_path=model.params["voc"],
115 | input_vectors_path=model.params["input_test"],
116 | input_adaptive_vectors_path=model.params["input_adaptive_test"],
117 | context_vocab_path=model.params["context_voc"],
118 | ch_vocab_path=model.params["ch_voc"],
119 | use_seed=model.params["use_seed"],
120 | wordlist_path=args.wordlist
121 | )
122 | dataloader = DataLoader(
123 | dataset,
124 | batch_size=1 if args.type == "bleu" else model.params["batch_size"],
125 | collate_fn=DefinitionModelingCollate
126 | )
127 | if args.type == "ppl":
128 | eval_ppl = test(dataloader, model, device, logfile)
129 | else:
130 | assert args.tau is not None, "--tau is required if --type bleu"
131 | assert args.n is not None, "--n is required if --type bleu"
132 | assert args.length is not None, "--length is required if --type bleu"
133 | defsave = open(
134 | model.params["exp_dir"] + "generated_" +
135 | args.datasplit + "_tau=" +
136 | str(args.tau) + "_n=" + str(args.n) +
137 | "_length=" + str(args.length) + ".txt",
138 | "w"
139 | )
140 | refsave = open(
141 | model.params["exp_dir"] + "refs_" + args.datasplit + ".txt",
142 | "w"
143 | )
144 | #defsave = sys.stdout
145 | voc = Vocabulary()
146 | voc.load(model.params["voc"])
147 | to_input = {
148 | "model": model,
149 | "voc": voc,
150 | "tau": args.tau,
151 | "n": args.n,
152 | "length": args.length,
153 | "device": device,
154 | }
155 | if model.is_attn:
156 | context_voc = Vocabulary()
157 | context_voc.load(model.params["context_voc"])
158 | to_input["context_voc"] = context_voc
159 | if model.params["use_ch"]:
160 | ch_voc = Vocabulary()
161 | ch_voc.load(model.params["ch_voc"])
162 | to_input["ch_voc"] = ch_voc
163 | for i in tqdm(range(len(dataset)), file=logfile):
164 | if model.is_w2v:
165 | to_input["input"] = torch.from_numpy(dataset.input_vectors[i])
166 | if model.is_ada:
167 | to_input["input"] = torch.from_numpy(
168 | dataset.input_adaptive_vectors[i]
169 | )
170 | if model.is_attn:
171 | to_input["word"] = dataset.data[i][0][0]
172 | to_input["context"] = " ".join(dataset.data[i][2])
173 | if model.params["use_ch"]:
174 | to_input["CH_word"] = dataset.data[i][0][0]
175 | if model.params["use_seed"]:
176 | to_input["prefix"] = dataset.data[i][0][0]
177 | else:
178 | to_input["prefix"] = BOS
179 | defsave.write(
180 | "Word: {0}\nContext: {1}\n".format(
181 | dataset.data[i][0][0],
182 | " ".join(dataset.data[i][2])
183 | )
184 | )
185 | defsave.write(generate(**to_input) + "\n")
186 | refsave.write(
187 | "Word: {0}\nContext: {1}\nDefinition: {2}\n".format(
188 | dataset.data[i][0][0],
189 | " ".join(dataset.data[i][2]),
190 | " ".join(dataset.data[i][1])
191 | )
192 | )
193 | defsave.flush()
194 | logfile.flush()
195 | refsave.flush()
196 |
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
1 | from source.model import DefinitionModelingModel
2 | from source.pipeline import generate
3 | from source.datasets import Vocabulary
4 | from source.utils import prepare_ada_vectors_from_python, prepare_w2v_vectors
5 | from source.constants import BOS
6 | import argparse
7 | import torch
8 | import json
9 |
10 | parser = argparse.ArgumentParser(description='Script to generate using model')
11 | parser.add_argument(
12 | "--params", type=str, required=True,
13 | help="path to saved model params"
14 | )
15 | parser.add_argument(
16 | "--ckpt", type=str, required=True,
17 | help="path to saved model weights"
18 | )
19 | parser.add_argument(
20 | "--tau", type=float, required=True,
21 | help="temperature to use in sampling"
22 | )
23 | parser.add_argument(
24 | "--n", type=int, required=True,
25 | help="number of samples to generate"
26 | )
27 | parser.add_argument(
28 | "--length", type=int, required=True,
29 | help="maximum length of generated samples"
30 | )
31 | parser.add_argument(
32 | "--prefix", type=str, required=False,
33 | help="prefix to read until generation starts"
34 | )
35 | parser.add_argument(
36 | "--wordlist", type=str, required=False,
37 | help="path to word list with words and contexts"
38 | )
39 | parser.add_argument(
40 | "--w2v_binary_path", type=str, required=False,
41 | help="path to binary w2v file"
42 | )
43 | parser.add_argument(
44 | "--ada_binary_path", type=str, required=False,
45 | help="path to binary ada file"
46 | )
47 | parser.add_argument(
48 | "--prep_ada_path", type=str, required=False,
49 | help="path to prep_ada.jl script"
50 | )
51 | args = parser.parse_args()
52 |
53 | with open(args.params, "r") as infile:
54 | model_params = json.load(infile)
55 |
56 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
57 | model = DefinitionModelingModel(model_params).to(device)
58 | model.load_state_dict(torch.load(args.ckpt)["state_dict"])
59 | voc = Vocabulary()
60 | voc.load(model_params["voc"])
61 | to_input = {
62 | "model": model,
63 | "voc": voc,
64 | "tau": args.tau,
65 | "n": args.n,
66 | "length": args.length,
67 | "device": device,
68 | }
69 | if model.params["pretrain"]:
70 | to_input["prefix"] = args.prefix
71 | print(generate(**to_input))
72 | else:
73 | assert args.wordlist is not None, ("to generate definitions in --pretrain "
74 | "False mode --wordlist is required")
75 |
76 | with open(args.wordlist, "r") as infile:
77 | data = infile.readlines()
78 |
79 | if model.is_w2v:
80 | assert args.w2v_binary_path is not None, ("model.is_w2v True => "
81 | "--w2v_binary_path is "
82 | "required")
83 | input_vecs = torch.from_numpy(
84 | prepare_w2v_vectors(args.wordlist, args.w2v_binary_path)
85 | )
86 | if model.is_ada:
87 | assert args.ada_binary_path is not None, ("model.is_ada True => "
88 | "--ada_binary_path is "
89 | "required")
90 | assert args.prep_ada_path is not None, ("model.is_ada True => "
91 | "--prep_ada_path is "
92 | "required")
93 | input_vecs = torch.from_numpy(
94 | prepare_ada_vectors_from_python(
95 | args.wordlist,
96 | args.prep_ada_path,
97 | args.ada_binary_path
98 | )
99 | )
100 | if model.is_attn:
101 | context_voc = Vocabulary()
102 | context_voc.load(model.params["context_voc"])
103 | to_input["context_voc"] = context_voc
104 | if model.params["use_ch"]:
105 | ch_voc = Vocabulary()
106 | ch_voc.load(model.params["ch_voc"])
107 | to_input["ch_voc"] = ch_voc
108 | for i in range(len(data)):
109 | word, context = data[i].split('\t')
110 | context = context.strip()
111 | if model.is_w2v or model.is_ada:
112 | to_input["input"] = input_vecs[i]
113 | if model.is_attn:
114 | to_input["word"] = word
115 | to_input["context"] = context
116 | if model.params["use_ch"]:
117 | to_input["CH_word"] = word
118 | if model.params["use_seed"]:
119 | to_input["prefix"] = word
120 | else:
121 | to_input["prefix"] = BOS
122 | print("Word: {0}".format(word))
123 | print("Context: {0}".format(context))
124 | print(generate(**to_input))
125 |
--------------------------------------------------------------------------------
/prep_ada.jl:
--------------------------------------------------------------------------------
1 | using ArgParse
2 | using AdaGram
3 | using JSON
4 | using NPZ
5 |
6 | function main(args)
7 |
8 | s = ArgParseSettings(description = "Prepare word vectors for Input-Adaptive conditioning")
9 |
10 | @add_arg_table s begin
11 | "--defs"
12 | nargs = '+'
13 | arg_type = String
14 | required = true
15 | help = "location of json file with definitions."
16 | "--save"
17 | nargs = '+'
18 | arg_type = String
19 | required = true
20 | help = "where to save files"
21 | "--ada"
22 | arg_type = String
23 | required = true
24 | help = "location of AdaGram file"
25 | end
26 |
27 | parsed_args = parse_args(s)
28 | if length(parsed_args["defs"]) != length(parsed_args["save"])
29 | error("Number of defs files must match number of save locations")
30 | end
31 |
32 | vm, dict = load_model(parsed_args["ada"]);
33 | for i = 1:length(parsed_args["defs"])
34 | open(parsed_args["defs"][i], "r") do f
35 | global definitions = JSON.parse(readstring(f))
36 | end
37 | global vectors = zeros(length(definitions), length(vm.In[:, 1, 1]))
38 | for (k, elem) in enumerate(definitions)
39 | if haskey(dict.word2id, elem[1][1])
40 | global good_context = []
41 | for w in elem[3]
42 | if haskey(dict.word2id, w)
43 | push!(good_context, w)
44 | end
45 | end
46 | mxval, mxidx = findmax(disambiguate(vm, dict, elem[1][1], split(join(good_context, " "))))
47 | vectors[k, :] = vm.In[:, mxidx, dict.word2id[elem[1][1]]]
48 | end
49 | end
50 | npzwrite(parsed_args["save"][i], vectors)
51 | end
52 |
53 | end
54 |
55 | main(ARGS)
--------------------------------------------------------------------------------
/prep_embedding_matrix.py:
--------------------------------------------------------------------------------
1 | from source.datasets import Vocabulary
2 | import argparse
3 | from gensim.models import KeyedVectors
4 | import torch
5 | import numpy as np
6 |
7 | parser = argparse.ArgumentParser(
8 | description='Prepare word vectors for embedding layer in the model'
9 | )
10 | parser.add_argument(
11 | "--voc", type=str, required=True,
12 | help="location of model vocabulary file"
13 | )
14 | parser.add_argument(
15 | "--w2v", type=str, required=True,
16 | help="location of binary w2v file"
17 | )
18 | parser.add_argument(
19 | "--save", type=str, required=True,
20 | help="where to save prepaired matrix"
21 | )
22 | args = parser.parse_args()
23 | word_vectors = KeyedVectors.load_word2vec_format(args.w2v, binary=True)
24 | voc = Vocabulary()
25 | voc.load(args.voc)
26 | vecs = []
27 | initrange = 0.5 / word_vectors.vector_size
28 | for key in voc.tok2id.keys():
29 | if key in word_vectors:
30 | vecs.append(word_vectors[key])
31 | else:
32 | vecs.append(
33 | np.random.uniform(
34 | low=-initrange,
35 | high=initrange,
36 | size=word_vectors.vector_size)
37 | )
38 | torch.save(torch.from_numpy(np.array(vecs)).float(), args.save)
39 |
--------------------------------------------------------------------------------
/prep_vocab.py:
--------------------------------------------------------------------------------
1 | from source.datasets import Vocabulary
2 | import argparse
3 | import json
4 |
5 | parser = argparse.ArgumentParser(description='Prepare vocabularies for model')
6 | parser.add_argument(
7 | '--defs', type=str, required=True, nargs="+",
8 | help='location of json file with definitions.'
9 | )
10 | parser.add_argument(
11 | "--lm", type=str, required=False, nargs="+",
12 | help="location of txt file with text for LM pre-training"
13 | )
14 | parser.add_argument(
15 | '--same', dest='same', action='store_true',
16 | help="use same vocab for definitions and contexts"
17 | )
18 | parser.set_defaults(same=False)
19 | parser.add_argument(
20 | "--save", type=str, required=True,
21 | help="where to save prepaired vocabulary (for words from definitions)"
22 | )
23 | parser.add_argument(
24 | "--save_context", type=str, required=False,
25 | help="where to save vocabulary (for words from contexts)"
26 | )
27 | parser.add_argument(
28 | "--save_chars", type=str, required=True,
29 | help="where to save char vocabulary (for chars from all words)"
30 | )
31 | args = parser.parse_args()
32 | if not args.same and args.save_context is None:
33 |     parser.error("--save_context is required unless --same is used")
34 |
35 |
36 | voc = Vocabulary()
37 | char_voc = Vocabulary()
38 | if not args.same:
39 | context_voc = Vocabulary()
40 |
41 | definitions = []
42 | for f in args.defs:
43 | with open(f, "r") as infile:
44 | definitions.extend(json.load(infile))
45 |
46 | if args.lm is not None:
47 | lm_texts = ""
48 | for f in args.lm:
49 | lm_texts = lm_texts + open(f).read().lower() + " "
50 | lm_texts = lm_texts.split()
51 |
52 | for word in lm_texts:
53 | voc.add_token(word)
54 |
55 | for elem in definitions:
56 | voc.add_token(elem[0][0])
57 | char_voc.tok_maxlen = max(len(elem[0][0]), char_voc.tok_maxlen)
58 | for c in elem[0][0]:
59 | char_voc.add_token(c)
60 | for i in range(len(elem[1])):
61 | voc.add_token(elem[1][i])
62 | if args.same:
63 | for i in range(len(elem[2])):
64 | voc.add_token(elem[2][i])
65 | else:
66 | context_voc.add_token(elem[0][0])
67 | for i in range(len(elem[2])):
68 | context_voc.add_token(elem[2][i])
69 |
70 |
71 | voc.save(args.save)
72 | char_voc.save(args.save_chars)
73 | if not args.same:
74 | context_voc.save(args.save_context)
75 |
--------------------------------------------------------------------------------
/prep_w2v.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from gensim.models import KeyedVectors
3 | import numpy as np
4 | import json
5 |
6 | parser = argparse.ArgumentParser(
7 | description='Prepare word vectors for Input conditioning'
8 | )
9 |
10 | parser.add_argument(
11 | '--defs', type=str, required=True, nargs="+",
12 | help='location of json file with definitions.'
13 | )
14 |
15 | parser.add_argument(
16 | '--save', type=str, required=True, nargs="+",
17 | help='where to save files'
18 | )
19 |
20 | parser.add_argument(
21 | "--w2v", type=str, required=True,
22 | help="location of binary w2v file"
23 | )
24 | args = parser.parse_args()
25 |
26 | if len(args.defs) != len(args.save):
27 | parser.error("Number of defs files must match number of save locations")
28 |
29 | word_vectors = KeyedVectors.load_word2vec_format(args.w2v, binary=True)
30 | for i in range(len(args.defs)):
31 | vectors = []
32 | with open(args.defs[i], "r") as infile:
33 | definitions = json.load(infile)
34 | for elem in definitions:
35 | if elem[0][0] in word_vectors:
36 | vectors.append(word_vectors[elem[0][0]])
37 | else:
38 | vectors.append(np.zeros(word_vectors.vector_size))
39 | vectors = np.array(vectors)
40 | np.save(args.save[i], vectors)
41 |
--------------------------------------------------------------------------------
/source/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/agadetsky/pytorch-definitions/03e7fb2e02c03ce5774f5e2cd174c7f224373a3e/source/__init__.py
--------------------------------------------------------------------------------
/source/attention_skipgram.py:
--------------------------------------------------------------------------------
1 | from .layers import InputAttention
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class AttentionSkipGram(nn.Module):
8 |
9 | def __init__(self, n_attn_tokens, n_attn_embsize,
10 | n_attn_hid, attn_dropout, sparse=False):
11 | super(AttentionSkipGram, self).__init__()
12 | self.n_attn_tokens = n_attn_tokens
13 | self.n_attn_embsize = n_attn_embsize
14 | self.n_attn_hid = n_attn_hid
15 | self.attn_dropout = attn_dropout
16 | self.sparse = sparse
17 |
18 | self.emb0_lookup = InputAttention(
19 | n_attn_tokens=self.n_attn_tokens,
20 | n_attn_embsize=self.n_attn_embsize,
21 | n_attn_hid=self.n_attn_hid,
22 | attn_dropout=self.attn_dropout,
23 | sparse=self.sparse
24 | )
25 | self.emb1_lookup = nn.Embedding(
26 | num_embeddings=self.n_attn_tokens,
27 | embedding_dim=self.n_attn_embsize,
28 | sparse=self.sparse
29 | )
30 | self.emb1_lookup.weight.data.zero_()
31 |
32 | def forward(self, words, context, neg):
33 | idx = torch.LongTensor(words.size(0), 1).random_(
34 | 0, context.size(1)
35 | ).to(words.device)
36 | labels = context.gather(1, idx).squeeze(1)
37 |
38 | w_embs = self.emb0_lookup(words, context)
39 | c_embs = self.emb1_lookup(labels)
40 | n_embs = self.emb1_lookup(neg)
41 |
42 | pos_ips = torch.sum(w_embs * c_embs, 1)
43 | neg_ips = torch.bmm(
44 | n_embs, torch.unsqueeze(w_embs, 1).permute(0, 2, 1)
45 | ).squeeze(2)
46 |
47 | # Neg Log Likelihood
48 | pos_loss = -torch.mean(F.logsigmoid(pos_ips))
49 | neg_loss = -torch.mean(F.logsigmoid(-neg_ips).sum(1))
50 |
51 | return pos_loss + neg_loss
52 |
--------------------------------------------------------------------------------
/source/constants.py:
--------------------------------------------------------------------------------
1 | # FOR MODEL VOCABULARY
2 | PAD = '<pad>'
3 | UNK = '<unk>'
4 | BOS = '<bos>'
5 | EOS = '<eos>'
6 | PAD_IDX = 0
7 | UNK_IDX = 1
8 | BOS_IDX = 2
9 | EOS_IDX = 3
10 |
--------------------------------------------------------------------------------
/source/datasets.py:
--------------------------------------------------------------------------------
1 | from . import constants
2 | from torch.utils.data import Dataset
3 | import json
4 | import numpy as np
5 | import math
6 |
7 |
8 | class Vocabulary:
9 | """Word/char vocabulary"""
10 |
11 | def __init__(self):
12 | self.tok2id = {
13 | constants.PAD: constants.PAD_IDX,
14 | constants.UNK: constants.UNK_IDX,
15 | constants.BOS: constants.BOS_IDX,
16 | constants.EOS: constants.EOS_IDX
17 | }
18 | self.id2tok = {
19 | constants.PAD_IDX: constants.PAD,
20 | constants.UNK_IDX: constants.UNK,
21 | constants.BOS_IDX: constants.BOS,
22 | constants.EOS_IDX: constants.EOS
23 | }
24 |
25 |         # we need this for maxlen of word being defined in CH conditioning
26 | self.tok_maxlen = -float("inf")
27 |
28 | def encode(self, tok):
29 | if tok in self.tok2id:
30 | return self.tok2id[tok]
31 | else:
32 | return constants.UNK_IDX
33 |
34 | def decode(self, idx):
35 | if idx in self.id2tok:
36 | return self.id2tok[idx]
37 | else:
38 | raise ValueError("No such idx: {0}".format(idx))
39 |
40 | def encode_seq(self, arr):
41 | ret = []
42 | for elem in arr:
43 | ret.append(self.encode(elem))
44 | return ret
45 |
46 | def decode_seq(self, arr):
47 | ret = []
48 | for elem in arr:
49 | ret.append(self.decode(elem))
50 | return ret
51 |
52 | def add_token(self, tok):
53 | if tok not in self.tok2id:
54 | self.tok2id[tok] = len(self.tok2id)
55 | self.id2tok[len(self.id2tok)] = tok
56 |
57 | def save(self, path):
58 | with open(path, "w") as outfile:
59 | json.dump([self.id2tok, self.tok_maxlen], outfile, indent=4)
60 |
61 | def load(self, path):
62 | with open(path, "r") as infile:
63 | self.id2tok, self.tok_maxlen = json.load(infile)
64 | self.id2tok = {int(k): v for k, v in self.id2tok.items()}
65 | self.tok2id = {}
66 | for i in self.id2tok.keys():
67 | self.tok2id[self.id2tok[i]] = i
68 |
69 |
70 | def pad(seq, size, value):
71 | if len(seq) < size:
72 | seq.extend([value] * (size - len(seq)))
73 | return seq
74 |
75 |
76 | class LanguageModelingDataset(Dataset):
77 | """LanguageModeling dataset."""
78 |
79 | def __init__(self, file, vocab_path, bptt):
80 | """
81 | Args:
82 | file (string): Path to the file
83 | vocab_path (string): path to word vocab to use
84 | bptt (int): length of one sentence
85 | """
86 | with open(file, "r") as infile:
87 | self.data = infile.read().lower().split()
88 | self.voc = Vocabulary()
89 | self.voc.load(vocab_path)
90 | self.bptt = bptt
91 |
92 | def __len__(self):
93 | return math.ceil(len(self.data) / (self.bptt + 1))
94 |
95 | def __getitem__(self, idx):
96 | i = idx + self.bptt * idx
97 | sample = {
98 | "x": self.voc.encode_seq(self.data[i: i + self.bptt]),
99 | "y": self.voc.encode_seq(self.data[i + 1: i + self.bptt + 1]),
100 | }
101 | return sample
102 |
103 |
104 | def LanguageModelingCollate(batch):
105 | batch_x = []
106 | batch_y = []
107 | maxlen = -float("inf")
108 | for i in range(len(batch)):
109 | batch_x.append(batch[i]["x"])
110 | batch_y.append(batch[i]["y"])
111 | maxlen = max(maxlen, len(batch[i]["x"]), len(batch[i]["y"]))
112 |
113 | for i in range(len(batch)):
114 | batch_x[i] = pad(batch_x[i], maxlen, constants.PAD_IDX)
115 | batch_y[i] = pad(batch_y[i], maxlen, constants.PAD_IDX)
116 |
117 | ret_batch = {
118 | "x": np.array(batch_x),
119 | "y": np.array(batch_y),
120 | }
121 | return ret_batch
122 |
123 |
124 | class DefinitionModelingDataset(Dataset):
125 | """DefinitionModeling dataset."""
126 |
127 | def __init__(self, file, vocab_path, input_vectors_path=None,
128 | input_adaptive_vectors_path=None, context_vocab_path=None,
129 | ch_vocab_path=None, use_seed=False, wordlist_path=None):
130 | """
131 | Args:
132 | file (string): path to the file
133 | vocab_path (string): path to word vocab to use
134 | input_vectors_path (string): path to vectors for Input conditioning
135 | input_adaptive_vectors_path (string): path to vectors for Input-Adaptive conditioning
136 | context_vocab_path (string): path to vocab for context words for Input-Attention
137 | ch_vocab_path (string): path to char vocab for CH conditioning
138 | use_seed (bool): whether to use Seed conditioning or not
139 | wordlist_path (string): path to wordlist with words
140 | """
141 | with open(file, "r") as infile:
142 | self.data = json.load(infile)
143 | self.voc = Vocabulary()
144 | self.voc.load(vocab_path)
145 | if context_vocab_path is not None:
146 | self.context_voc = Vocabulary()
147 | self.context_voc.load(context_vocab_path)
148 | if ch_vocab_path is not None:
149 | self.ch_voc = Vocabulary()
150 | self.ch_voc.load(ch_vocab_path)
151 | if input_vectors_path is not None:
152 | self.input_vectors = np.load(input_vectors_path).astype(np.float32)
153 | if input_adaptive_vectors_path is not None:
154 | self.input_adaptive_vectors = np.load(
155 | input_adaptive_vectors_path
156 | ).astype(np.float32)
157 | if wordlist_path is not None:
158 | wordlist = set(
159 | [elem.strip() for elem in open(wordlist_path, "r").readlines()]
160 | )
161 | data = []
162 | if input_vectors_path is not None:
163 | input_vectors = []
164 | if input_adaptive_vectors_path is not None:
165 | input_adaptive_vectors = []
166 | for i in range(len(self.data)):
167 | if self.data[i][0][0] in wordlist:
168 | data.append(self.data[i])
169 | if input_vectors_path is not None:
170 | input_vectors.append(
171 | self.input_vectors[i]
172 | )
173 | if input_adaptive_vectors_path is not None:
174 | input_adaptive_vectors.append(
175 | self.input_adaptive_vectors[i]
176 | )
177 | assert len(data) > 0, "You provided bad wordlist, no words found"
178 | if input_vectors_path is not None:
179 | self.input_vectors = np.array(input_vectors).astype(np.float32)
180 | if input_adaptive_vectors_path is not None:
181 | self.input_adaptive_vectors = np.array(
182 | input_adaptive_vectors
183 | ).astype(np.float32)
184 | self.data = data
185 | self.use_seed = use_seed
186 |
187 | def __len__(self):
188 | return len(self.data)
189 |
190 | def __getitem__(self, idx):
191 | sample = {
192 | "x": self.voc.encode_seq(self.data[idx][1]),
193 | "y": self.voc.encode_seq(self.data[idx][1][1:] + [constants.EOS]),
194 | }
195 | if hasattr(self, "input_vectors"):
196 | sample["input"] = self.input_vectors[idx]
197 | if hasattr(self, "input_adaptive_vectors"):
198 | sample["input_adaptive"] = self.input_adaptive_vectors[idx]
199 | if hasattr(self, "context_voc"):
200 | sample["word"] = self.context_voc.encode(self.data[idx][0][0])
201 | sample["context"] = self.context_voc.encode_seq(self.data[idx][2])
202 | if hasattr(self, "ch_voc"):
203 | sample["CH"] = [constants.BOS_IDX] + \
204 | self.ch_voc.encode_seq(list(self.data[idx][0][0])) + \
205 | [constants.EOS_IDX]
206 | # CH_maxlen: +2 because EOS + BOS
207 | sample["CH_maxlen"] = self.ch_voc.tok_maxlen + 2
208 | if self.use_seed:
209 | sample["y"] = [sample["x"][0]] + sample["y"]
210 | sample["x"] = self.voc.encode_seq(self.data[idx][0]) + sample["x"]
211 | return sample
212 |
213 |
214 | def DefinitionModelingCollate(batch):
215 | batch_x = []
216 | batch_y = []
217 | is_w2v = "input" in batch[0]
218 | is_ada = "input_adaptive" in batch[0]
219 | is_attn = "word" in batch[0] and "context" in batch[0]
220 | is_ch = "CH" in batch[0] and "CH_maxlen" in batch[0]
221 | if is_w2v:
222 | batch_input = []
223 | if is_ada:
224 | batch_input_adaptive = []
225 | if is_attn:
226 | batch_word = []
227 | batch_context = []
228 | context_maxlen = -float("inf")
229 | if is_ch:
230 | batch_ch = []
231 | CH_maxlen = batch[0]["CH_maxlen"]
232 |
233 | definition_lengths = []
234 | for i in range(len(batch)):
235 | batch_x.append(batch[i]["x"])
236 | batch_y.append(batch[i]["y"])
237 | if is_w2v:
238 | batch_input.append(batch[i]["input"])
239 | if is_ada:
240 | batch_input_adaptive.append(batch[i]["input_adaptive"])
241 | if is_attn:
242 | batch_word.append(batch[i]["word"])
243 | batch_context.append(batch[i]["context"])
244 | context_maxlen = max(context_maxlen, len(batch_context[-1]))
245 | if is_ch:
246 | batch_ch.append(batch[i]["CH"])
247 | definition_lengths.append(len(batch_x[-1]))
248 |
249 | definition_maxlen = max(definition_lengths)
250 |
251 | for i in range(len(batch)):
252 | batch_x[i] = pad(batch_x[i], definition_maxlen, constants.PAD_IDX)
253 | batch_y[i] = pad(batch_y[i], definition_maxlen, constants.PAD_IDX)
254 | if is_attn:
255 | batch_context[i] = pad(
256 | batch_context[i], context_maxlen, constants.PAD_IDX
257 | )
258 | if is_ch:
259 | batch_ch[i] = pad(batch_ch[i], CH_maxlen, constants.PAD_IDX)
260 |
261 | order = np.argsort(definition_lengths)[::-1]
262 | batch_x = np.array(batch_x)[order]
263 | batch_y = np.array(batch_y)[order]
264 | ret_batch = {
265 | "x": batch_x,
266 | "y": batch_y,
267 | }
268 | if is_w2v:
269 | batch_input = np.array(batch_input, dtype=np.float32)[order]
270 | ret_batch["input"] = batch_input
271 | if is_ada:
272 | batch_input_adaptive = np.array(
273 | batch_input_adaptive,
274 | dtype=np.float32
275 | )[order]
276 | ret_batch["input_adaptive"] = batch_input_adaptive
277 | if is_attn:
278 | batch_word = np.array(batch_word)[order]
279 | batch_context = np.array(batch_context)[order]
280 | ret_batch["word"] = batch_word
281 | ret_batch["context"] = batch_context
282 | if is_ch:
283 | batch_ch = np.array(batch_ch)[order]
284 | ret_batch["CH"] = batch_ch
285 |
286 | return ret_batch
287 |
--------------------------------------------------------------------------------
/source/layers.py:
--------------------------------------------------------------------------------
1 | from . import constants
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class Input(nn.Module):
8 | """
9 | Class for Input or Input - Adaptive or dummy conditioning
10 | """
11 |
12 | def __init__(self):
13 | super(Input, self).__init__()
14 |
15 | def forward(self, x):
16 | """
17 |         Vectors are already prepared in DataLoaders, so just return them
18 | """
19 | return x
20 |
21 |
22 | class InputAttention(nn.Module):
23 | """
24 | Class for Input Attention conditioning
25 | """
26 |
27 | def __init__(self, n_attn_tokens, n_attn_embsize,
28 | n_attn_hid, attn_dropout, sparse=False):
29 | super(InputAttention, self).__init__()
30 | self.n_attn_tokens = n_attn_tokens
31 | self.n_attn_embsize = n_attn_embsize
32 | self.n_attn_hid = n_attn_hid
33 | self.attn_dropout = attn_dropout
34 | self.sparse = sparse
35 |
36 | self.embs = nn.Embedding(
37 | num_embeddings=self.n_attn_tokens,
38 | embedding_dim=self.n_attn_embsize,
39 | padding_idx=constants.PAD_IDX,
40 | sparse=self.sparse
41 | )
42 |
43 | self.ann = nn.Sequential(
44 | nn.Dropout(p=self.attn_dropout),
45 | nn.Linear(
46 | in_features=self.n_attn_embsize,
47 | out_features=self.n_attn_hid
48 | ),
49 | nn.Tanh()
50 | ) # maybe use ReLU or other?
51 |
52 | self.a_linear = nn.Linear(
53 | in_features=self.n_attn_hid,
54 | out_features=self.n_attn_embsize
55 | )
56 |
57 | def forward(self, word, context):
58 | x_embs = self.embs(word)
59 | mask = self.get_mask(context)
60 | return mask * x_embs
61 |
62 | def get_mask(self, context):
63 | context_embs = self.embs(context)
64 | lengths = (context != constants.PAD_IDX)
65 | for_sum_mask = lengths.unsqueeze(2).float()
66 | lengths = lengths.sum(1).float().view(-1, 1)
67 | logits = self.a_linear(
68 | (self.ann(context_embs) * for_sum_mask).sum(1) / lengths
69 | )
70 | return F.sigmoid(logits)
71 |
72 | def init_attn(self, freeze):
73 | initrange = 0.5 / self.n_attn_embsize
74 | with torch.no_grad():
75 | nn.init.uniform_(self.embs.weight, -initrange, initrange)
76 | nn.init.xavier_uniform_(self.a_linear.weight)
77 | nn.init.constant_(self.a_linear.bias, 0)
78 | nn.init.xavier_uniform_(self.ann[1].weight)
79 | nn.init.constant_(self.ann[1].bias, 0)
80 | self.embs.weight.requires_grad = not freeze
81 |
82 | def init_attn_from_pretrained(self, weights, freeze):
83 | self.load_state_dict(weights)
84 | self.embs.weight.requires_grad = not freeze
85 |
86 |
87 | class CharCNN(nn.Module):
88 | """
89 | Class for CH conditioning
90 | """
91 |
92 | def __init__(self, n_ch_tokens, ch_maxlen, ch_emb_size,
93 | ch_feature_maps, ch_kernel_sizes):
94 | super(CharCNN, self).__init__()
95 | assert len(ch_feature_maps) == len(ch_kernel_sizes)
96 |
97 | self.n_ch_tokens = n_ch_tokens
98 | self.ch_maxlen = ch_maxlen
99 | self.ch_emb_size = ch_emb_size
100 | self.ch_feature_maps = ch_feature_maps
101 | self.ch_kernel_sizes = ch_kernel_sizes
102 |
103 | self.feature_mappers = nn.ModuleList()
104 | for i in range(len(self.ch_feature_maps)):
105 | reduced_length = self.ch_maxlen - self.ch_kernel_sizes[i] + 1
106 | self.feature_mappers.append(
107 | nn.Sequential(
108 | nn.Conv2d(
109 | in_channels=1,
110 | out_channels=self.ch_feature_maps[i],
111 | kernel_size=(
112 | self.ch_kernel_sizes[i],
113 | self.ch_emb_size
114 | )
115 | ),
116 | nn.Tanh(),
117 | nn.MaxPool2d(kernel_size=(reduced_length, 1))
118 | )
119 | )
120 |
121 | self.embs = nn.Embedding(
122 | self.n_ch_tokens,
123 | self.ch_emb_size,
124 | padding_idx=constants.PAD_IDX
125 | )
126 |
127 | def forward(self, x):
128 | # x - [batch_size x maxlen]
129 | bsize, length = x.size()
130 | assert length == self.ch_maxlen
131 | x_embs = self.embs(x).view(bsize, 1, self.ch_maxlen, self.ch_emb_size)
132 |
133 | cnn_features = []
134 | for i in range(len(self.ch_feature_maps)):
135 | cnn_features.append(
136 | self.feature_mappers[i](x_embs).view(bsize, -1)
137 | )
138 |
139 | return torch.cat(cnn_features, dim=1)
140 |
141 | def init_ch(self):
142 | initrange = 0.5 / self.ch_emb_size
143 | with torch.no_grad():
144 | nn.init.uniform_(self.embs.weight, -initrange, initrange)
145 | for name, p in self.feature_mappers.named_parameters():
146 | if "bias" in name:
147 | nn.init.constant_(p, 0)
148 | elif "weight" in name:
149 | nn.init.xavier_uniform_(p)
150 |
151 |
152 | class Hidden(nn.Module):
153 | """
154 | Class for Hidden conditioning
155 | """
156 |
157 | def __init__(self, cond_size, hidden_size, out_size):
158 | super(Hidden, self).__init__()
159 | self.cond_size = cond_size
160 | self.hidden_size = hidden_size
161 | self.out_size = out_size
162 | self.in_size = self.cond_size + self.hidden_size
163 | self.linear = nn.Linear(
164 | in_features=self.in_size,
165 | out_features=self.out_size
166 | )
167 |
168 | def forward(self, hidden, conds):
169 | seqlen = hidden.size(1) # batch_first=True
170 | repeated_conds = conds.view(-1).repeat(seqlen)
171 | repeated_conds = repeated_conds.view(seqlen, *conds.size())
172 | repeated_conds = repeated_conds.permute(
173 | 1, 0, 2
174 | ) # batchsize x seqlen x cond_dim
175 | concat = torch.cat(
176 | [repeated_conds, hidden], dim=2
177 | ) # concat by last dim
178 | return F.tanh(self.linear(concat))
179 |
180 | def init_hidden(self):
181 | with torch.no_grad():
182 | nn.init.xavier_uniform_(self.linear.weight)
183 | nn.init.constant_(self.linear.bias, 0)
184 |
185 |
186 | class Gated(nn.Module):
187 | """
188 | Class for Gated conditioning
189 | """
190 |
191 | def __init__(self, cond_size, hidden_size):
192 | super(Gated, self).__init__()
193 | self.cond_size = cond_size
194 | self.hidden_size = hidden_size
195 | self.in_size = self.cond_size + self.hidden_size
196 | self.linear1 = nn.Linear(
197 | in_features=self.in_size,
198 | out_features=self.hidden_size
199 | )
200 | self.linear2 = nn.Linear(
201 | in_features=self.in_size,
202 | out_features=self.cond_size
203 | )
204 | self.linear3 = nn.Linear(
205 | in_features=self.in_size,
206 | out_features=self.hidden_size
207 | )
208 |
209 | def forward(self, hidden, conds):
210 | seqlen = hidden.size(1) # batch_first=True
211 | repeated_conds = conds.view(-1).repeat(seqlen)
212 | repeated_conds = repeated_conds.view(seqlen, *conds.size())
213 | repeated_conds = repeated_conds.permute(
214 | 1, 0, 2
215 | ) # batchsize x seqlen x cond_dim
216 | concat = torch.cat(
217 | [repeated_conds, hidden], dim=2
218 | ) # concat by last dim
219 | z_t = F.sigmoid(self.linear1(concat))
220 | r_t = F.sigmoid(self.linear2(concat))
221 | masked_concat = torch.cat(
222 | [repeated_conds * r_t, hidden], dim=2
223 | )
224 | hat_s_t = F.tanh(self.linear3(masked_concat))
225 | return (1 - z_t) * hidden + z_t * hat_s_t
226 |
227 | def init_gated(self):
228 | with torch.no_grad():
229 | nn.init.xavier_uniform_(self.linear1.weight)
230 | nn.init.xavier_uniform_(self.linear2.weight)
231 | nn.init.xavier_uniform_(self.linear3.weight)
232 | nn.init.constant_(self.linear1.bias, 0)
233 | nn.init.constant_(self.linear2.bias, 0)
234 | nn.init.constant_(self.linear3.bias, 0)
235 |
--------------------------------------------------------------------------------
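A minimal shape sketch for the `Gated` conditioning layer above: it broadcasts a per-example condition vector over the time dimension and gates it into the RNN states. The dimensions below are arbitrary and chosen only for illustration, and the snippet assumes it is run from the repository root so that `source` is importable:

```
import torch
from source.layers import Gated

batch_size, seqlen, nhid, cond_size = 2, 5, 8, 4
gated = Gated(cond_size=cond_size, hidden_size=nhid)
gated.init_gated()

hidden = torch.randn(batch_size, seqlen, nhid)  # LSTM outputs, batch_first=True
conds = torch.randn(batch_size, cond_size)      # one condition vector per example

out = gated(hidden, conds)                      # GRU-style gate between hidden and conds
print(out.shape)                                # torch.Size([2, 5, 8]), same as `hidden`
```
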
/source/model.py:
--------------------------------------------------------------------------------
1 | from . import constants
2 | from .layers import Input, InputAttention, CharCNN, Hidden, Gated
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn.utils.rnn import pad_packed_sequence as unpack
6 | from torch.nn.utils.rnn import pack_padded_sequence as pack
7 |
8 |
9 | class DefinitionModelingModel(nn.Module):
10 | """Definition modeling class"""
11 |
12 | def __init__(self, params):
13 | super(DefinitionModelingModel, self).__init__()
14 | self.params = params
15 |
16 | self.embs = nn.Embedding(
17 | num_embeddings=self.params["ntokens"],
18 | embedding_dim=self.params["nx"],
19 | padding_idx=constants.PAD_IDX
20 | )
21 | self.dropout = nn.Dropout(p=self.params["rnn_dropout"])
22 |
23 | self.n_rnn_input = self.params["nx"]
24 | if not self.params["pretrain"]:
25 | self.input_used = self.params["use_input"]
26 | self.input_used += self.params["use_input_adaptive"]
27 | self.input_used += self.params["use_input_attention"]
28 | self.hidden_used = self.params["use_hidden"]
29 | self.hidden_used += self.params["use_hidden_adaptive"]
30 | self.hidden_used += self.params["use_hidden_attention"]
31 | self.gated_used = self.params["use_gated"]
32 | self.gated_used += self.params["use_gated_adaptive"]
33 | self.gated_used += self.params["use_gated_attention"]
34 | self.is_w2v = self.params["use_input"]
35 | self.is_w2v += self.params["use_hidden"]
36 | self.is_w2v += self.params["use_gated"]
37 | self.is_ada = self.params["use_input_adaptive"]
38 | self.is_ada += self.params["use_hidden_adaptive"]
39 | self.is_ada += self.params["use_gated_adaptive"]
40 | self.is_attn = self.params["use_input_attention"]
41 | self.is_attn += self.params["use_hidden_attention"]
42 | self.is_attn += self.params["use_gated_attention"]
43 | self.is_conditioned = self.input_used
44 | self.is_conditioned += self.hidden_used
45 | self.is_conditioned += self.gated_used
46 |
47 |             # check that at most one of the Input*, Hidden* and Gated*
48 |             # conditioning families is used
49 | assert self.input_used + self.hidden_used + \
50 | self.gated_used <= 1, "Too many conditionings used"
51 |
52 | if not self.is_conditioned and self.params["use_ch"]:
53 | raise ValueError("Don't use CH conditioning without others")
54 |
55 | self.cond_size = 0
56 | if self.is_w2v:
57 | self.input = Input()
58 | self.cond_size += self.params["input_dim"]
59 | elif self.is_ada:
60 | self.input_adaptive = Input()
61 | self.cond_size += self.params["input_adaptive_dim"]
62 | elif self.is_attn:
63 | self.input_attention = InputAttention(
64 | n_attn_tokens=self.params["n_attn_tokens"],
65 | n_attn_embsize=self.params["n_attn_embsize"],
66 | n_attn_hid=self.params["n_attn_hid"],
67 | attn_dropout=self.params["attn_dropout"],
68 | sparse=self.params["attn_sparse"]
69 | )
70 | self.cond_size += self.params["n_attn_embsize"]
71 |
72 | if self.params["use_ch"]:
73 | self.ch = CharCNN(
74 | n_ch_tokens=self.params["n_ch_tokens"],
75 | ch_maxlen=self.params["ch_maxlen"],
76 | ch_emb_size=self.params["ch_emb_size"],
77 | ch_feature_maps=self.params["ch_feature_maps"],
78 | ch_kernel_sizes=self.params["ch_kernel_sizes"]
79 | )
80 | self.cond_size += sum(self.params["ch_feature_maps"])
81 |
82 | if self.input_used:
83 | self.n_rnn_input += self.cond_size
84 |
85 | if self.hidden_used:
86 | self.hidden = Hidden(
87 | cond_size=self.cond_size,
88 | hidden_size=self.params["nhid"],
89 | out_size=self.params["nhid"]
90 | )
91 | elif self.gated_used:
92 | self.gated = Gated(
93 | cond_size=self.cond_size,
94 | hidden_size=self.params["nhid"]
95 | )
96 |
97 | self.rnn = nn.LSTM(
98 | input_size=self.n_rnn_input,
99 | hidden_size=self.params["nhid"],
100 | num_layers=self.params["nlayers"],
101 | batch_first=True,
102 | dropout=self.params["rnn_dropout"]
103 | )
104 | self.linear = nn.Linear(
105 | in_features=self.params["nhid"],
106 | out_features=self.params["ntokens"]
107 | )
108 |
109 | self.init_weights()
110 |
111 | def forward(self, x, input=None, word=None, context=None, CH_word=None, hidden=None):
112 | """
113 | x - definitions/LM_sequence to read
114 | input - vectors for Input, Input-Adaptive or dummy conditioning
115 | word - words for Input-Attention conditioning
116 | context - contexts for Input-Attention conditioning
117 | CH_word - words for CH conditioning
118 | hidden - hidden states of RNN
119 | """
120 | lengths = (x != constants.PAD_IDX).sum(dim=1).detach()
121 | maxlen = lengths.max().item()
122 | embs = self.embs(x)
123 | if not self.params["pretrain"]:
124 | all_conds = []
125 | if self.is_w2v:
126 | all_conds.append(self.input(input))
127 | elif self.is_ada:
128 | all_conds.append(self.input_adaptive(input))
129 | elif self.is_attn:
130 | all_conds.append(self.input_attention(word, context))
131 | if self.params["use_ch"]:
132 | all_conds.append(self.ch(CH_word))
133 | if self.is_conditioned:
134 | all_conds = torch.cat(all_conds, dim=1)
135 |
136 | if self.input_used:
137 | repeated_conds = all_conds.view(-1).repeat(maxlen)
138 | repeated_conds = repeated_conds.view(maxlen, *all_conds.size())
139 | repeated_conds = repeated_conds.permute(1, 0, 2)
140 | embs = torch.cat([repeated_conds, embs], dim=-1)
141 |
142 | embs = pack(embs, lengths, batch_first=True)
143 | output, hidden = self.rnn(embs, hidden)
144 | output = unpack(output, batch_first=True)[0]
145 | output = self.dropout(output)
146 |
147 | if not self.params["pretrain"]:
148 | if self.hidden_used:
149 | output = self.hidden(output, all_conds)
150 | elif self.gated_used:
151 | output = self.gated(output, all_conds)
152 |
153 | decoded = self.linear(
154 | output.contiguous().view(
155 | output.size(0) * output.size(1),
156 | output.size(2)
157 | )
158 | )
159 |
160 | return decoded, hidden
161 |
162 | def init_embeddings(self, freeze):
163 | initrange = 0.5 / self.params["nx"]
164 | with torch.no_grad():
165 | nn.init.uniform_(self.embs.weight, -initrange, initrange)
166 | self.embs.weight.requires_grad = not freeze
167 |
168 | def init_embeddings_from_pretrained(self, weights, freeze):
169 | self.embs = self.embs.from_pretrained(weights, freeze)
170 |
171 | def init_rnn(self):
172 | with torch.no_grad():
173 | for name, p in self.rnn.named_parameters():
174 | if "bias" in name:
175 | nn.init.constant_(p, 0)
176 | elif "weight" in name:
177 | nn.init.xavier_uniform_(p)
178 |
179 | def init_rnn_from_pretrained(self, weights):
180 |         # strip the "rnn." prefix (k[4:]) because the keys of
181 |         # self.rnn.state_dict() do not carry it
182 | correct_state_dict = {
183 | k[4:]: v for k, v in weights.items() if k[:4] == "rnn."
184 | }
185 |         # weight_ih_l0 needs special handling: with Input* conditioning its
186 |         # input dimension differs from the pretrained LM's, so only the part
187 |         # acting on the token embeddings is copied; the remaining weights
188 |         # match as long as the hidden size equals the one used in pretraining
189 | if self.input_used:
190 | w = torch.empty(4 * self.params["nhid"], self.n_rnn_input)
191 | nn.init.xavier_uniform_(w)
192 | w[:, self.cond_size:] = correct_state_dict["weight_ih_l0"]
193 | correct_state_dict["weight_ih_l0"] = w
194 | self.rnn.load_state_dict(correct_state_dict)
195 |
196 | def init_linear(self):
197 | with torch.no_grad():
198 | nn.init.xavier_uniform_(self.linear.weight)
199 | nn.init.constant_(self.linear.bias, 0)
200 |
201 | def init_linear_from_pretrained(self, weights):
202 |         # strip the "linear." prefix (k[7:]) because the keys of
203 |         # self.linear.state_dict() do not carry it
204 | self.linear.load_state_dict(
205 | {k[7:]: v for k, v in weights.items() if k[:7] == "linear."}
206 | )
207 |
208 | def init_weights(self):
209 | if self.params["pretrain"]:
210 | if self.params["w2v_weights"] is not None:
211 | self.init_embeddings_from_pretrained(
212 | torch.load(self.params["w2v_weights"]),
213 | self.params["fix_embeddings"]
214 | )
215 | else:
216 | self.init_embeddings(self.params["fix_embeddings"])
217 | self.init_rnn()
218 | self.init_linear()
219 | else:
220 | if self.params["lm_ckpt"] is not None:
221 | lm_ckpt_weights = torch.load(self.params["lm_ckpt"])
222 | lm_ckpt_weights = lm_ckpt_weights["state_dict"]
223 | self.init_embeddings_from_pretrained(
224 | lm_ckpt_weights["embs.weight"],
225 | self.params["fix_embeddings"]
226 | )
227 | self.init_rnn_from_pretrained(lm_ckpt_weights)
228 | self.init_linear_from_pretrained(lm_ckpt_weights)
229 | else:
230 | if self.params["w2v_weights"] is not None:
231 | self.init_embeddings_from_pretrained(
232 | torch.load(self.params["w2v_weights"]),
233 | self.params["fix_embeddings"]
234 | )
235 | else:
236 | self.init_embeddings(self.params["fix_embeddings"])
237 | self.init_rnn()
238 | self.init_linear()
239 | if self.is_attn:
240 | if self.params["attn_ckpt"] is not None:
241 | self.input_attention.init_attn_from_pretrained(
242 | torch.load(self.params["attn_ckpt"])["state_dict"],
243 | self.params["fix_attn_embeddings"]
244 | )
245 | else:
246 | self.input_attention.init_attn(
247 | self.params["fix_attn_embeddings"]
248 | )
249 | if self.hidden_used:
250 | self.hidden.init_hidden()
251 | if self.gated_used:
252 | self.gated.init_gated()
253 | if self.params["use_ch"]:
254 | self.ch.init_ch()
255 |
--------------------------------------------------------------------------------
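A minimal end-to-end sketch of the model in plain LM-pretraining mode (no conditioning branches), just to make the expected shapes concrete. The hyperparameter values are arbitrary, not the ones used in the paper, and the snippet assumes it is run from the repository root:

```
import torch
from source import constants
from source.model import DefinitionModelingModel

# Hypothetical hyperparameters, chosen only to make the shapes concrete.
params = {
    "pretrain": True,         # plain LM mode: no conditioning branches are built
    "ntokens": 100,
    "nx": 16,
    "nhid": 32,
    "nlayers": 1,
    "rnn_dropout": 0.0,
    "w2v_weights": None,
    "fix_embeddings": False,
}
model = DefinitionModelingModel(params)

# A toy batch of token ids without padding, so every sequence has full length.
x = torch.randint(0, params["ntokens"] - 1, (3, 7), dtype=torch.long)
x[x >= constants.PAD_IDX] += 1  # make sure PAD_IDX never appears in the batch

decoded, hidden = model(x)
print(decoded.shape)  # (3 * 7, ntokens): a row of logits for every token position
```
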
/source/pipeline.py:
--------------------------------------------------------------------------------
1 | from . import constants
2 | from .datasets import pad
3 | from torch.nn.utils import clip_grad_norm_
4 | import torch.nn.functional as F
5 | from tqdm import tqdm
6 | import torch
7 | import numpy as np
8 |
9 |
10 | def train_epoch(dataloader, model, optimizer, device, clip_to, logfile):
11 | """
12 | Function for training the model one epoch
13 | dataloader - either LanguageModeling or DefinitionModeling dataloader
14 | model - DefinitionModelingModel
15 | optimizer - optimizer to use (usually Adam)
16 | device - cuda/cpu
17 | clip_to - value to clip gradients
18 | logfile - where to log training
19 | """
20 | # switch model to training mode
21 | model.train()
22 | # train
23 | mean_batch_loss = 0
24 | for batch in tqdm(dataloader, file=logfile):
25 | y_true = torch.from_numpy(batch.pop("y")).to(device).view(-1)
26 | # prepare model args
27 | to_input = {"x": torch.from_numpy(batch["x"]).to(device)}
28 | if not model.params["pretrain"]:
29 | if model.is_w2v:
30 | to_input["input"] = torch.from_numpy(batch["input"]).to(device)
31 | if model.is_ada:
32 | to_input["input"] = torch.from_numpy(
33 | batch["input_adaptive"]
34 | ).to(device)
35 | if model.is_attn:
36 | to_input["word"] = torch.from_numpy(batch["word"]).to(device)
37 | to_input["context"] = torch.from_numpy(
38 | batch["context"]
39 | ).to(device)
40 | if model.params["use_ch"]:
41 | to_input["CH_word"] = torch.from_numpy(
42 | batch["CH"]
43 | ).to(device)
44 |
45 | y_pred, hidden = model(**to_input)
46 | batch_loss = F.cross_entropy(
47 | y_pred, y_true,
48 | ignore_index=constants.PAD_IDX
49 | )
50 | optimizer.zero_grad()
51 | batch_loss.backward()
52 | clip_grad_norm_(
53 | filter(lambda p: p.requires_grad, model.parameters()), clip_to
54 | )
55 | optimizer.step()
56 | logfile.flush()
57 | mean_batch_loss += batch_loss.item()
58 |
59 | mean_batch_loss = mean_batch_loss / len(dataloader)
60 | logfile.write(
61 | "Mean training loss on epoch: {0}\n".format(mean_batch_loss)
62 | )
63 | logfile.flush()
64 |
65 |
66 | def test(dataloader, model, device, logfile):
67 | """
68 | Function for testing the model on dataloader
69 | dataloader - either LanguageModeling or DefinitionModeling dataloader
70 | model - DefinitionModelingModel
71 | device - cuda/cpu
72 | logfile - where to log evaluation
73 | """
74 | # switch model to evaluation mode
75 | model.eval()
76 | # eval
77 | lengths_sum = 0
78 | loss_sum = 0
79 | with torch.no_grad():
80 | for batch in tqdm(dataloader, file=logfile):
81 | y_true = torch.from_numpy(batch.pop("y")).to(device).view(-1)
82 | # prepare model args
83 | to_input = {"x": torch.from_numpy(batch["x"]).to(device)}
84 | if not model.params["pretrain"]:
85 | if model.is_w2v:
86 | to_input["input"] = torch.from_numpy(
87 | batch["input"]
88 | ).to(device)
89 | if model.is_ada:
90 | to_input["input"] = torch.from_numpy(
91 | batch["input_adaptive"]
92 | ).to(device)
93 | if model.is_attn:
94 | to_input["word"] = torch.from_numpy(
95 | batch["word"]
96 | ).to(device)
97 | to_input["context"] = torch.from_numpy(
98 | batch["context"]
99 | ).to(device)
100 | if model.params["use_ch"]:
101 | to_input["CH_word"] = torch.from_numpy(
102 | batch["CH"]
103 | ).to(device)
104 |
105 | y_pred, hidden = model(**to_input)
106 | loss_sum += F.cross_entropy(
107 | y_pred,
108 | y_true,
109 | ignore_index=constants.PAD_IDX,
110 | size_average=False
111 | ).item()
112 | lengths_sum += (to_input["x"] != constants.PAD_IDX).sum().item()
113 | logfile.flush()
114 |
115 | perplexity = np.exp(loss_sum / lengths_sum)
116 | logfile.write(
117 | "Perplexity: {0}\n".format(perplexity)
118 | )
119 | logfile.flush()
120 | return perplexity
121 |
122 |
123 | def generate(model, voc, tau, n, length, device, prefix=None,
124 | input=None, word=None, context=None, context_voc=None,
125 | CH_word=None, ch_voc=None):
126 | """
127 | model - DefinitionModelingModel
128 | voc - model Vocabulary
129 | tau - temperature to generate with
130 | n - number of samples
131 | length - length of the sample
132 | device - cuda/cpu
133 |     prefix - prefix fed to the model before generation starts
134 | input - vectors for Input/InputAdaptive conditioning
135 | word - word for InputAttention conditioning
136 | context - context for InputAttention conditioning
137 | context_voc - Vocabulary for InputAttention conditioning
138 | CH_word - word for CH conditioning
139 | ch_voc - Vocabulary for CH conditioning
140 | """
141 | model.eval()
142 | to_input = {}
143 | if not model.params["pretrain"]:
144 | if model.is_w2v or model.is_ada:
145 |             assert input is not None, ("input argument is required because "
146 |                                        "model uses w2v or adagram vectors")
147 |             assert input.dim() == 1, ("input argument must be a vector, "
148 |                                       "but its dim is {0}".format(input.dim()))
149 | to_input["input"] = input.repeat(n).view(n, -1).to(device)
150 | if model.is_attn:
151 |             assert word is not None, ("word argument is required because "
152 |                                       "model uses attention")
153 |             assert context is not None, ("context argument is required because "
154 |                                          "model uses attention")
155 |             assert context_voc is not None, ("context_voc argument is required "
156 |                                              "because model uses attention")
157 | assert isinstance(word, str), ("word argument must be string")
158 | assert isinstance(context, str), ("context argument must be "
159 | "string")
160 | to_input["word"] = torch.LongTensor(
161 | [context_voc.encode(word)]
162 | ).repeat(n).view(n).to(device)
163 | to_input["context"] = torch.LongTensor(
164 | context_voc.encode_seq(context.split())
165 | ).repeat(n).view(n, -1).to(device)
166 | if model.params["use_ch"]:
167 |             assert CH_word is not None, ("CH_word argument is required "
168 |                                          "because model uses CH conditioning")
169 |             assert ch_voc is not None, ("ch_voc argument is required "
170 |                                         "because model uses CH conditioning")
171 | assert isinstance(CH_word, str), ("CH_word must be string")
172 | to_input["CH_word"] = torch.LongTensor(
173 | pad(
174 | [constants.BOS_IDX] +
175 | ch_voc.encode_seq(list(CH_word)) +
176 | [constants.EOS_IDX], ch_voc.tok_maxlen + 2,
177 | constants.PAD_IDX
178 | )
179 | ).repeat(n).view(n, -1).to(device)
180 |
181 | to_input["x"] = None
182 |     to_input["hidden"] = None  # PyTorch initializes the hidden state to zeros
183 | ret = [[] for i in range(n)]
184 | if prefix is not None:
185 | assert isinstance(prefix, str), "prefix argument must be string"
186 | if len(prefix.split()) > 0:
187 | to_input["x"] = torch.LongTensor(
188 | voc.encode_seq(prefix.split())
189 | ).repeat(n).view(n, -1).to(device)
190 | else:
191 | to_input["x"] = torch.randint(
192 | model.params["ntokens"], size=(1, ), dtype=torch.long
193 | ).repeat(n).view(n, -1).to(device)
194 | prefix = voc.decode(to_input["x"][0][0].item())
195 | with torch.no_grad():
196 | for i in range(length):
197 | output, to_input["hidden"] = model(**to_input)
198 | output = output.view((n, -1, model.params["ntokens"]))[:, -1, :]
199 | to_input["x"] = F.softmax(
200 | output / tau, dim=1
201 | ).multinomial(num_samples=1)
202 |             for k in range(n):  # `k`, not `i`: avoid shadowing the outer loop
203 |                 ret[k].append(to_input["x"][k][0].item())
204 |
205 | output = [[] for i in range(n)]
206 | for i in range(n):
207 | decoded = voc.decode_seq(ret[i])
208 | for j in range(length):
209 | if decoded[j] == constants.EOS:
210 | break
211 | output[i].append(decoded[j])
212 | output[i] = " ".join(map(str, output[i]))
213 |
214 | return "\n".join(output)
215 |
--------------------------------------------------------------------------------
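The perplexity reported by `test` above is the exponent of the average per-token cross-entropy, with padding excluded. A tiny self-contained numeric sketch of that computation, assuming a padding index of 0 as a stand-in for `constants.PAD_IDX` and using `reduction="sum"`, the modern spelling of `size_average=False`:

```
import numpy as np
import torch
import torch.nn.functional as F

PAD_IDX = 0  # stand-in for constants.PAD_IDX

logits = torch.randn(6, 10)                       # 6 target positions, vocab of 10
targets = torch.tensor([3, 1, 4, 1, 5, PAD_IDX])  # last position is padding

loss_sum = F.cross_entropy(
    logits, targets, ignore_index=PAD_IDX, reduction="sum"
).item()
n_tokens = (targets != PAD_IDX).sum().item()      # 5 real tokens
print(np.exp(loss_sum / n_tokens))                # perplexity on this toy batch
```
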
/source/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import subprocess
3 | import os
4 | import numpy as np
5 | from gensim.models import KeyedVectors
6 | import random
7 |
8 |
9 | def prepare_ada_vectors_from_python(file, julia_script, ada_binary_path):
10 | """
11 | file - path to file with words and contexts on each line separated by \t,
12 | words and punctuation marks in contexts are separated by spaces
13 | julia_script - path to prep_ada.jl script
14 | ada_binary_path - path to ada binary file
15 | """
16 | data = open(file, "r").readlines()
17 | tmp = []
18 | for i in range(len(data)):
19 | word, context = data[i].split('\t')
20 | context = context.strip().split()
21 | tmp.append([[word], [], context])
22 | tmp_name = "./tmp" + str(random.randint(1, 999999)) + ".txt"
23 | tmp_script_name = "./tmp_script" + str(random.randint(1, 999999)) + ".sh"
24 | tmp_vecs_name = "./tmp_vecs" + str(random.randint(1, 999999))
25 | with open(tmp_name, "w") as outfile:
26 | json.dump(tmp, outfile, indent=4)
27 | with open(tmp_script_name, "w") as outfile:
28 | outfile.write(
29 | "julia " + julia_script + " --defs " + tmp_name +
30 | " --save " + tmp_vecs_name +
31 | " --ada " + ada_binary_path
32 | )
33 | subprocess.call(["/bin/bash", "-i", tmp_script_name])
34 | vecs = np.load(tmp_vecs_name).astype(np.float32)
35 | os.remove(tmp_name)
36 | os.remove(tmp_script_name)
37 | os.remove(tmp_vecs_name)
38 | return vecs
39 |
40 |
41 | def prepare_w2v_vectors(file, w2v_binary_path):
42 | """
43 | file - path to file with words and contexts on each line separated by \t,
44 | words and punctuation marks in contexts are separated by spaces
45 | w2v_binary_path - path to w2v binary
46 | """
47 | data = open(file, "r").readlines()
48 | word_vectors = KeyedVectors.load_word2vec_format(
49 | w2v_binary_path, binary=True
50 | )
51 | vecs = []
52 | initrange = 0.5 / word_vectors.vector_size
53 | for i in range(len(data)):
54 | word, context = data[i].split('\t')
55 | context = context.strip().split()
56 | if word in word_vectors:
57 | vecs.append(word_vectors[word])
58 | else:
59 | vecs.append(
60 | np.random.uniform(
61 | low=-initrange,
62 | high=initrange,
63 | size=word_vectors.vector_size
64 | )
65 | )
66 | return np.array(vecs, dtype=np.float32)
67 |
68 |
69 | class MultipleOptimizer(object):
70 |
71 | def __init__(self, *op):
72 | self.optimizers = op
73 |
74 | def zero_grad(self):
75 | for op in self.optimizers:
76 | op.zero_grad()
77 |
78 | def step(self):
79 | for op in self.optimizers:
80 | op.step()
81 |
--------------------------------------------------------------------------------
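Both helpers above read a plain-text file with one `word<TAB>context` pair per line, the context tokens separated by spaces. A small sketch of producing such a file; the file name and the word2vec binary path are placeholders, so the call itself is left commented out:

```
examples = [
    "bank\tthe boat drifted towards the river bank",
    "bank\tshe opened an account at the bank",
]
with open("contexts.txt", "w") as f:
    f.write("\n".join(examples) + "\n")

# from source.utils import prepare_w2v_vectors
# vecs = prepare_w2v_vectors("contexts.txt", "GoogleNews-vectors-negative300.bin")
# vecs.shape -> (2, 300)
```
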
/train.py:
--------------------------------------------------------------------------------
1 | from source.datasets import LanguageModelingDataset, LanguageModelingCollate
2 | from source.datasets import DefinitionModelingDataset, DefinitionModelingCollate
3 | from source.model import DefinitionModelingModel
4 | from source.pipeline import train_epoch, test
5 | from torch.utils.data import DataLoader
6 | import torch
7 | import torch.optim as optim
8 | from tqdm import tqdm
9 | import argparse
10 | import json
11 | import numpy as np
12 |
13 | # Read all arguments and prepare all stuff for training
14 |
15 | parser = argparse.ArgumentParser(description='Script to train a model')
16 | # Type of training
17 | parser.add_argument(
18 | '--pretrain', dest='pretrain', action="store_true",
19 | help='whether to pretrain model on LM dataset or train on definitions'
20 | )
21 | # Common data arguments
22 | parser.add_argument(
23 | "--voc", type=str, required=True, help="location of vocabulary file"
24 | )
25 | # Definitions data arguments
26 | parser.add_argument(
27 | '--train_defs', type=str, required=False,
28 | help="location of json file with train definitions."
29 | )
30 | parser.add_argument(
31 | '--eval_defs', type=str, required=False,
32 | help="location of json file with eval definitions."
33 | )
34 | parser.add_argument(
35 | '--test_defs', type=str, required=False,
36 | help="location of json file with test definitions"
37 | )
38 | parser.add_argument(
39 | '--input_train', type=str, required=False,
40 | help="location of train vectors for Input conditioning"
41 | )
42 | parser.add_argument(
43 | '--input_eval', type=str, required=False,
44 | help="location of eval vectors for Input conditioning"
45 | )
46 | parser.add_argument(
47 | '--input_test', type=str, required=False,
48 | help="location of test vectors for Input conditioning"
49 | )
50 | parser.add_argument(
51 | '--input_adaptive_train', type=str, required=False,
52 | help="location of train vectors for InputAdaptive conditioning"
53 | )
54 | parser.add_argument(
55 | '--input_adaptive_eval', type=str, required=False,
56 | help="location of eval vectors for InputAdaptive conditioning"
57 | )
58 | parser.add_argument(
59 | '--input_adaptive_test', type=str, required=False,
60 |     help="location of test vectors for InputAdaptive conditioning"
61 | )
62 | parser.add_argument(
63 | '--context_voc', type=str, required=False,
64 | help="location of context vocabulary file"
65 | )
66 | parser.add_argument(
67 | '--ch_voc', type=str, required=False,
68 | help="location of CH vocabulary file"
69 | )
70 | # LM data arguments
71 | parser.add_argument(
72 | '--train_lm', type=str, required=False,
73 |     help="location of txt file with train LM data"
74 | )
75 | parser.add_argument(
76 | '--eval_lm', type=str, required=False,
77 |     help="location of txt file with eval LM data"
78 | )
79 | parser.add_argument(
80 | '--test_lm', type=str, required=False,
81 |     help="location of txt file with test LM data"
82 | )
83 | parser.add_argument(
84 | '--bptt', type=int, required=False,
85 | help="sequence length for BackPropThroughTime in LM pretraining"
86 | )
87 | # Model parameters arguments
88 | parser.add_argument(
89 | '--nx', type=int, required=True,
90 | help="size of embeddings"
91 | )
92 | parser.add_argument(
93 | '--nlayers', type=int, required=True,
94 | help="number of LSTM layers"
95 | )
96 | parser.add_argument(
97 | '--nhid', type=int, required=True,
98 | help="size of hidden states"
99 | )
100 | parser.add_argument(
101 | '--rnn_dropout', type=float, required=True,
102 | help="probability of RNN dropout"
103 | )
104 | parser.add_argument(
105 | '--use_seed', dest="use_seed", action="store_true",
106 | help="whether to use Seed conditioning or not"
107 | )
108 | parser.add_argument(
109 | '--use_input', dest="use_input", action="store_true",
110 | help="whether to use Input conditioning or not"
111 | )
112 | parser.add_argument(
113 | '--use_input_adaptive', dest="use_input_adaptive", action="store_true",
114 | help="whether to use InputAdaptive conditioning or not"
115 | )
116 | parser.add_argument(
117 | '--use_input_attention', dest="use_input_attention",
118 | action="store_true",
119 | help="whether to use InputAttention conditioning or not"
120 | )
121 | parser.add_argument(
122 | '--n_attn_embsize', type=int, required=False,
123 | help="size of InputAttention embeddings"
124 | )
125 | parser.add_argument(
126 | '--n_attn_hid', type=int, required=False,
127 | help="size of InputAttention linear layer"
128 | )
129 | parser.add_argument(
130 | '--attn_dropout', type=float, required=False,
131 | help="probability of InputAttention dropout"
132 | )
133 | parser.add_argument(
134 | '--attn_sparse', dest="attn_sparse", action="store_true",
135 | help="whether to use sparse embeddings in InputAttention or not"
136 | )
137 | parser.add_argument(
138 | '--use_ch', dest="use_ch", action="store_true",
139 | help="whether to use CH conditioning or not"
140 | )
141 | parser.add_argument(
142 | '--ch_emb_size', type=int, required=False,
143 | help="size of embeddings in CH conditioning"
144 | )
145 | parser.add_argument(
146 | '--ch_feature_maps', type=int, required=False, nargs="+",
147 | help="list of feature map sizes in CH conditioning"
148 | )
149 | parser.add_argument(
150 | '--ch_kernel_sizes', type=int, required=False, nargs="+",
151 | help="list of kernel sizes in CH conditioning"
152 | )
153 | parser.add_argument(
154 | '--use_hidden', dest="use_hidden", action="store_true",
155 | help="whether to use Hidden conditioning or not"
156 | )
157 | parser.add_argument(
158 | '--use_hidden_adaptive', dest="use_hidden_adaptive",
159 | action="store_true",
160 | help="whether to use HiddenAdaptive conditioning or not"
161 | )
162 | parser.add_argument(
163 | '--use_hidden_attention', dest="use_hidden_attention",
164 | action="store_true",
165 | help="whether to use HiddenAttention conditioning or not"
166 | )
167 | parser.add_argument(
168 | '--use_gated', dest="use_gated", action="store_true",
169 | help="whether to use Gated conditioning or not"
170 | )
171 | parser.add_argument(
172 | '--use_gated_adaptive', dest="use_gated_adaptive", action="store_true",
173 | help="whether to use GatedAdaptive conditioning or not"
174 | )
175 | parser.add_argument(
176 | '--use_gated_attention', dest="use_gated_attention", action="store_true",
177 | help="whether to use GatedAttention conditioning or not"
178 | )
179 | # Training arguments
180 | parser.add_argument(
181 | '--lr', type=float, required=True,
182 | help="initial lr"
183 | )
184 | parser.add_argument(
185 | "--decay_factor", type=float, required=True,
186 | help="factor to decay lr"
187 | )
188 | parser.add_argument(
189 | '--decay_patience', type=int, required=True,
190 |     help="number of epochs without improvement after which lr is decayed"
191 | )
192 | parser.add_argument(
193 | '--num_epochs', type=int, required=True,
194 | help="number of epochs to train"
195 | )
196 | parser.add_argument(
197 | '--batch_size', type=int, required=True,
198 | help="batch size"
199 | )
200 | parser.add_argument(
201 | "--clip", type=float, required=True,
202 | help="value to clip norm of gradients to"
203 | )
204 | parser.add_argument(
205 | "--random_seed", type=int, required=True,
206 | help="random seed"
207 | )
208 | # Utility arguments
209 | parser.add_argument(
210 | "--exp_dir", type=str, required=True,
211 |     help="directory to save checkpoints, logs and params (should end with /)"
212 | )
213 | parser.add_argument(
214 | "--w2v_weights", type=str, required=False,
215 | help="path to pretrained embeddings to init"
216 | )
217 | parser.add_argument(
218 | "--fix_embeddings", dest="fix_embeddings", action="store_true",
219 |     help="fix the embedding matrix (do not update it during training)"
220 | )
221 | parser.add_argument(
222 | "--fix_attn_embeddings", dest="fix_attn_embeddings", action="store_true",
223 |     help="fix the attention embedding matrix (do not update it during training)"
224 | )
225 | parser.add_argument(
226 | "--lm_ckpt", type=str, required=False,
227 | help="path to pretrained language model weights"
228 | )
229 | parser.add_argument(
230 | "--attn_ckpt", type=str, required=False,
231 | help="path to pretrained Attention module"
232 | )
233 | # read args
234 | args = vars(parser.parse_args())
235 |
236 | logfile = open(args["exp_dir"] + "training_log", "a")
237 | #import sys
238 | #logfile = sys.stdout
239 |
240 | if args["pretrain"]:
241 | assert args["train_lm"] is not None, "--train_lm is required if --pretrain"
242 | assert args["eval_lm"] is not None, "--eval_lm is required if --pretrain"
243 | assert args["test_lm"] is not None, "--test_lm is required if --pretrain"
244 | assert args["bptt"] is not None, "--bptt is required if --pretrain"
245 |
246 | train_dataset = LanguageModelingDataset(
247 | file=args["train_lm"],
248 | vocab_path=args["voc"],
249 | bptt=args["bptt"],
250 | )
251 | train_dataloader = DataLoader(
252 | train_dataset, batch_size=args["batch_size"],
253 | collate_fn=LanguageModelingCollate
254 | )
255 | eval_dataset = LanguageModelingDataset(
256 | file=args["eval_lm"],
257 | vocab_path=args["voc"],
258 | bptt=args["bptt"],
259 | )
260 | eval_dataloader = DataLoader(
261 | eval_dataset, batch_size=args["batch_size"],
262 | collate_fn=LanguageModelingCollate
263 | )
264 | else:
265 | assert args["train_defs"] is not None, ("--pretrain is False,"
266 | " --train_defs is required")
267 | assert args["eval_defs"] is not None, ("--pretrain is False,"
268 | " --eval_defs is required")
269 | assert args["test_defs"] is not None, ("--pretrain is False,"
270 | " --test_defs is required")
271 |
272 | train_dataset = DefinitionModelingDataset(
273 | file=args["train_defs"],
274 | vocab_path=args["voc"],
275 | input_vectors_path=args["input_train"],
276 | input_adaptive_vectors_path=args["input_adaptive_train"],
277 | context_vocab_path=args["context_voc"],
278 | ch_vocab_path=args["ch_voc"],
279 | use_seed=args["use_seed"]
280 | )
281 | train_dataloader = DataLoader(
282 | train_dataset,
283 | batch_size=args["batch_size"],
284 | collate_fn=DefinitionModelingCollate
285 | )
286 | eval_dataset = DefinitionModelingDataset(
287 | file=args["eval_defs"],
288 | vocab_path=args["voc"],
289 | input_vectors_path=args["input_eval"],
290 | input_adaptive_vectors_path=args["input_adaptive_eval"],
291 | context_vocab_path=args["context_voc"],
292 | ch_vocab_path=args["ch_voc"],
293 | use_seed=args["use_seed"]
294 | )
295 | eval_dataloader = DataLoader(
296 | eval_dataset,
297 | batch_size=args["batch_size"],
298 | collate_fn=DefinitionModelingCollate
299 | )
300 |
301 | if args["use_input"] or args["use_hidden"] or args["use_gated"]:
302 | assert args["input_train"] is not None, ("--use_input or "
303 | "--use_hidden or "
304 | "--use_gated is used "
305 | "--input_train is required")
306 | assert args["input_eval"] is not None, ("--use_input or "
307 | "--use_hidden or "
308 | "--use_gated is used "
309 | "--input_eval is required")
310 | assert args["input_test"] is not None, ("--use_input or "
311 | "--use_hidden or "
312 | "--use_gated is used "
313 | "--input_test is required")
314 | args["input_dim"] = train_dataset.input_vectors.shape[1]
315 |
316 | if args["use_input_adaptive"] or args["use_hidden_adaptive"] or args["use_gated_adaptive"]:
317 | assert args["input_adaptive_train"] is not None, ("--use_input_adaptive or "
318 | "--use_hidden_adaptive or "
319 | "--use_gated_adaptive is used "
320 | "--input_adaptive_train is required")
321 | assert args["input_adaptive_eval"] is not None, ("--use_input_adaptive or "
322 | "--use_hidden_adaptive or "
323 | "--use_gated_adaptive is used "
324 | "--input_adaptive_eval is required")
325 | assert args["input_adaptive_test"] is not None, ("--use_input_adaptive or "
326 | "--use_hidden_adaptive or "
327 | "--use_gated_adaptive is used "
328 | "--input_adaptive_test is required")
329 | args["input_adaptive_dim"] = train_dataset.input_adaptive_vectors.shape[1]
330 |
331 | if args["use_input_attention"] or args["use_hidden_attention"] or args["use_gated_attention"]:
332 | assert args["context_voc"] is not None, ("--use_input_attention or "
333 | "--use_hidden_attention or "
334 | "--use_gated_attention is used "
335 | "--context_voc is required")
336 | assert args["n_attn_embsize"] is not None, ("--use_input_attention or "
337 | "--use_hidden_attention or "
338 | "--use_gated_attention is used "
339 | "--n_attn_embsize is required")
340 | assert args["n_attn_hid"] is not None, ("--use_input_attention or "
341 | "--use_hidden_attention or "
342 | "--use_gated_attention is used "
343 | "--n_attn_hid is required")
344 | assert args["attn_dropout"] is not None, ("--use_input_attention or "
345 | "--use_hidden_attention or "
346 | "--use_gated_attention is used "
347 | "--attn_dropout is required")
348 |
349 | args["n_attn_tokens"] = len(train_dataset.context_voc.tok2id)
350 |
351 | if args["use_ch"]:
352 | assert args["ch_voc"] is not None, ("--ch_voc is required "
353 | "if --use_ch")
354 | assert args["ch_emb_size"] is not None, ("--ch_emb_size is required "
355 | "if --use_ch")
356 | assert args["ch_feature_maps"] is not None, ("--ch_feature_maps is "
357 | "required if --use_ch")
358 | assert args["ch_kernel_sizes"] is not None, ("--ch_kernel_sizes is "
359 | "required if --use_ch")
360 |
361 | args["n_ch_tokens"] = len(train_dataset.ch_voc.tok2id)
362 | args["ch_maxlen"] = train_dataset.ch_voc.tok_maxlen + 2
363 |
364 |
365 | args["ntokens"] = len(train_dataset.voc.tok2id)
366 |
367 | np.random.seed(args["random_seed"])
368 | torch.manual_seed(args["random_seed"])
369 | if torch.cuda.is_available():
370 | torch.cuda.manual_seed(args["random_seed"])
371 |
372 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
373 | model = DefinitionModelingModel(args).to(device)
374 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(
375 | optim.Adam(
376 | filter(lambda p: p.requires_grad, model.parameters()), lr=args["lr"]
377 | ),
378 | factor=args["decay_factor"],
379 | patience=args["decay_patience"]
380 | )
381 |
382 | best_ppl = float("inf")
383 | for epoch in tqdm(range(args["num_epochs"]), file=logfile):
384 | train_epoch(
385 | train_dataloader,
386 | model,
387 | scheduler.optimizer,
388 | device,
389 | args["clip"],
390 | logfile
391 | )
392 | eval_ppl = test(eval_dataloader, model, device, logfile)
393 | if eval_ppl < best_ppl:
394 | best_ppl = eval_ppl
395 | torch.save(
396 | {"state_dict": model.state_dict()},
397 | args["exp_dir"] + "weights.pth"
398 | )
399 | scheduler.step(metrics=eval_ppl)
400 |
401 | with open(args["exp_dir"] + "params.json", "w") as outfile:
402 | json.dump(args, outfile, indent=4)
403 |
--------------------------------------------------------------------------------
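As a usage reference, one possible LM-pretraining invocation of this script is sketched below. Every flag exists in the parser above, but the paths and hyperparameter values are purely illustrative, and `--exp_dir` should end with a slash because file names are concatenated to it directly:

```
python train.py \
    --pretrain \
    --voc data/vocab.json \
    --train_lm data/wikitext-103/wiki.train.tokens \
    --eval_lm data/wikitext-103/wiki.valid.tokens \
    --test_lm data/wikitext-103/wiki.test.tokens \
    --bptt 35 \
    --nx 300 --nhid 300 --nlayers 2 --rnn_dropout 0.5 \
    --lr 0.001 --decay_factor 0.1 --decay_patience 0 \
    --num_epochs 10 --batch_size 64 --clip 5 --random_seed 42 \
    --exp_dir experiments/lm/
```
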
/train_attention_skipgram.py:
--------------------------------------------------------------------------------
1 | from source.attention_skipgram import AttentionSkipGram
2 | from source.utils import MultipleOptimizer
3 | import argparse
4 | import numpy as np
5 | import os.path
6 | from tqdm import tqdm
7 | from collections import Counter
8 | import json
9 | from source.datasets import Vocabulary
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import torch.optim as optim
14 | from itertools import chain
15 |
16 | parser = argparse.ArgumentParser(
17 |     description='Script to train an AttentionSkipGram model'
18 | )
19 | parser.add_argument(
20 | '--data', type=str, required=False,
21 | help="path to data"
22 | )
23 | parser.add_argument(
24 | '--context_voc', type=str, required=True,
25 |     help=("path to the context voc of DefinitionModelingModel; necessary to "
26 |           "save the pretrained attention module, particularly its embedding matrix")
27 | )
28 | parser.add_argument(
29 | '--prepared', dest='prepared', action="store_true",
30 |     help='use already prepared data from --exp_dir instead of preparing it from --data'
31 | )
32 | parser.add_argument(
33 | "--window", type=int, required=True,
34 | help="window for AttentionSkipGram model"
35 | )
36 | parser.add_argument(
37 | "--random_seed", type=int, required=True,
38 | help="random seed for training"
39 | )
40 | parser.add_argument(
41 | "--sparse", dest="sparse", action="store_true",
42 | help="whether to use sparse embeddings or not"
43 | )
44 | parser.add_argument(
45 | "--vec_dim", type=int, required=True,
46 | help="vector dim to train"
47 | )
48 | parser.add_argument(
49 | "--attn_hid", type=int, required=True,
50 | help="hidden size in attention module"
51 | )
52 | parser.add_argument(
53 | "--attn_dropout", type=float, required=True,
54 | help="dropout prob in attention module"
55 | )
56 | parser.add_argument(
57 | "--lr", type=float, required=True,
58 | help="initial lr to use"
59 | )
60 | parser.add_argument(
61 | "--batch_size", type=int, required=True,
62 | help="batch size to use"
63 | )
64 | parser.add_argument(
65 | "--num_epochs", type=int, required=True,
66 | help="number of epochs to train"
67 | )
68 | parser.add_argument(
69 | "--exp_dir", type=str, required=True,
70 | help="where to save weights, prepared data and logs"
71 | )
72 | args = vars(parser.parse_args())
73 | logfile = open(args["exp_dir"] + "training_log", "a")
74 |
75 | context_voc = Vocabulary()
76 | context_voc.load(args["context_voc"])
77 |
78 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
79 |
80 | np.random.seed(args["random_seed"])
81 | torch.manual_seed(args["random_seed"])
82 | if torch.cuda.is_available():
83 | torch.cuda.manual_seed(args["random_seed"])
84 |
85 | if args["prepared"]:
86 | assert os.path.isfile(args["exp_dir"] + "data.npz"), ("prepared data "
87 | "does not exist")
88 | assert os.path.isfile(args["exp_dir"] + "voc.json"), ("prepared voc "
89 | "does not exist")
90 |
91 | tqdm.write("Loading data!", file=logfile)
92 | logfile.flush()
93 |
94 | data = np.load(args["exp_dir"] + "data.npz")
95 | words_idx = data['words_idx']
96 | cnt_idx = data['cnt_idx']
97 | freqs = data['freqs']
98 |
99 | with open(args["exp_dir"] + "voc.json", 'r') as f:
100 | voc = json.load(f)
101 |
102 | word2id = voc['word2id']
103 | id2word = voc['id2word']
104 |
105 | else:
106 | assert args["data"] is not None, "--prepared False, provide --data"
107 |
108 | tqdm.write("Preparing data!", file=logfile)
109 | logfile.flush()
110 |
111 | with open(args["data"], 'r') as f:
112 | data = f.read()
113 |
114 | data = data.lower().split()
115 | counter = Counter(data)
116 | word2id = {}
117 | id2word = {}
118 | i = 0
119 | words = []
120 | counts = []
121 | for w, c in counter.most_common():
122 | words.append(w)
123 | counts.append(c)
124 | word2id[words[-1]] = i
125 | id2word[i] = w
126 | i += 1
127 |
128 | freqs = np.array(counts)
129 | freqs = freqs / freqs.sum()
130 | freqs = np.sqrt(freqs)
131 | freqs = freqs / freqs.sum()
132 | data = list(map(lambda w: word2id[w], data))
133 |
134 | words_idx = np.zeros(len(data) - 2 * args["window"], dtype=np.int)
135 | cnt_idx = np.zeros(
136 | (len(data) - 2 * args["window"], 2 * args["window"]), dtype=np.int
137 | )
138 |
139 | for i in tqdm(range(args["window"], len(data) - args["window"]), file=logfile):
140 | words_idx[i - args["window"]] = data[i]
141 | cnt_idx[i - args["window"]] = np.array(
142 | data[i - args["window"]:i] + data[i + 1:i + args["window"] + 1]
143 | )
144 |
145 | np.savez(
146 | args["exp_dir"] + "data",
147 | words_idx=words_idx,
148 | cnt_idx=cnt_idx,
149 | freqs=freqs
150 | )
151 | with open(args["exp_dir"] + "voc.json", 'w') as f:
152 | json.dump({'word2id': word2id, 'id2word': id2word}, f)
153 |
154 | tqdm.write("Data prepared and saved!", file=logfile)
155 | logfile.flush()
156 |
157 |
158 | def generate_neg(batch_size, negative=10):
159 | return np.random.choice(freqs.size, size=(batch_size, negative), p=freqs)
160 |
161 |
162 | def generate_batch(batch_size=128):
163 | shuffle = np.random.permutation(words_idx.shape[0])
164 | words_idx_shuffled = words_idx[shuffle]
165 | cnt_idx_shuffled = cnt_idx[shuffle]
166 | for i in tqdm(range(0, words_idx.shape[0], batch_size), file=logfile):
167 | start = i
168 | end = min(i + batch_size, words_idx.shape[0])
169 | words = words_idx_shuffled[start:end]
170 | context = cnt_idx_shuffled[start:end]
171 | neg = generate_neg(end - start)
172 |
173 | context = torch.from_numpy(context).to(device)
174 | words = torch.from_numpy(words).to(device)
175 | neg = torch.from_numpy(neg).to(device)
176 |
177 | yield words, context, neg
178 |
179 | del words_idx_shuffled
180 | del cnt_idx_shuffled
181 |
182 | tqdm.write("Initialising model!", file=logfile)
183 | logfile.flush()
184 |
185 | model = AttentionSkipGram(
186 | n_attn_tokens=len(word2id),
187 | n_attn_embsize=args["vec_dim"],
188 | n_attn_hid=args["attn_hid"],
189 | attn_dropout=args["attn_dropout"],
190 | sparse=args["sparse"]
191 | ).to(device)
192 |
193 |
194 | if args["sparse"]:
195 | optimizer = MultipleOptimizer(
196 | optim.SparseAdam(chain(
197 | model.emb0_lookup.embs.parameters(),
198 | model.emb1_lookup.parameters()
199 | ), lr=args["lr"]),
200 | optim.Adam(chain(
201 | model.emb0_lookup.ann.parameters(),
202 | model.emb0_lookup.a_linear.parameters()
203 | ), lr=args["lr"])
204 | )
205 | else:
206 | optimizer = optim.Adam(model.parameters(), lr=args["lr"])
207 |
208 | tqdm.write("Start training!", file=logfile)
209 | logfile.flush()
210 |
211 | model.train()
212 |
213 | for _ in range(args["num_epochs"]):
214 | for w, c, n in generate_batch(batch_size=args["batch_size"]):
215 | optimizer.zero_grad()
216 | loss = model(w, c, n)
217 | loss.backward()
218 | optimizer.step()
219 |
220 |
221 | tqdm.write("Training ended! Saving model!", file=logfile)
222 | logfile.flush()
223 |
224 | state_dict = model.emb0_lookup.state_dict()
225 | initrange = 0.5 / args["vec_dim"]
226 | embs_weights = np.random.uniform(
227 | low=-initrange,
228 | high=initrange,
229 | size=(len(context_voc.tok2id), args["vec_dim"]),
230 | ).astype(np.float32)
231 | for word in context_voc.tok2id.keys():
232 | if word in word2id:
233 | new_id = context_voc.tok2id[word]
234 | old_id = word2id[word]
235 | embs_weights[new_id] = state_dict["embs.weight"][old_id].cpu().numpy()
236 |
237 | state_dict["embs.weight"] = torch.from_numpy(embs_weights).to(device)
238 |
239 | torch.save(
240 | {"state_dict": state_dict},
241 | args["exp_dir"] + "weights.pth"
242 | )
243 |
244 | with open(args["exp_dir"] + "params.json", "w") as outfile:
245 | json.dump(args, outfile, indent=4)
246 |
--------------------------------------------------------------------------------
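The data-preparation branch of the script above slides a fixed window over the corpus: `words_idx[i]` keeps the centre token and `cnt_idx[i]` keeps the `window` tokens on each side. A toy illustration of that construction with made-up token ids:

```
import numpy as np

data = [10, 11, 12, 13, 14, 15]  # token ids of a tiny corpus
window = 2

words_idx = np.zeros(len(data) - 2 * window, dtype=np.int64)
cnt_idx = np.zeros((len(data) - 2 * window, 2 * window), dtype=np.int64)
for i in range(window, len(data) - window):
    words_idx[i - window] = data[i]
    cnt_idx[i - window] = np.array(data[i - window:i] + data[i + 1:i + window + 1])

print(words_idx)  # [12 13]
print(cnt_idx)    # [[10 11 13 14]
                  #  [11 12 14 15]]
```
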