├── .gitignore ├── README.md ├── bleu.py ├── eval.py ├── generate.py ├── prep_ada.jl ├── prep_embedding_matrix.py ├── prep_vocab.py ├── prep_w2v.py ├── source ├── __init__.py ├── attention_skipgram.py ├── constants.py ├── datasets.py ├── layers.py ├── model.py ├── pipeline.py └── utils.py ├── train.py └── train_attention_skipgram.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | data/ 3 | nohup.out 4 | .ipynb_checkpoints 5 | sftp-config.json 6 | .DS_Store 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Conditional Generators of Words Definitions 2 | 3 | This repo contains code for our paper [Conditional Generators of Words Definitions](https://arxiv.org/abs/1806.10090). 4 | 5 | __Abstract__ 6 | 7 | We explore recently introduced definition modeling technique that provided the tool for evaluation of different distributed 8 | vector representations of words through modeling dictionary definitions of words. In this work, we study the problem of word ambiguities in definition modeling and propose a possible solution by employing latent variable modeling and soft attention mechanisms. Our quantitative and qualitative evaluation and analysis of the model shows that taking into account words ambiguity and polysemy leads to performance improvement. 9 | 10 | # Citation 11 | 12 | ``` 13 | @InProceedings{P18-2043, 14 | author = "Gadetsky, Artyom and Yakubovskiy, Ilya and Vetrov, Dmitry", 15 | title = "Conditional Generators of Words Definitions", 16 | booktitle = "Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)", 17 | year = "2018", 18 | publisher = "Association for Computational Linguistics", 19 | pages = "266--271", 20 | location = "Melbourne, Australia", 21 | url = "http://aclweb.org/anthology/P18-2043" 22 | } 23 | ``` 24 | 25 | # Environment requirements and Data Preparation 26 | 27 | * Install conda environment with the following packages: 28 | 29 | ``` 30 | Python 3.6 31 | Pytorch 0.4 32 | Numpy 1.14 33 | Tqdm 4.23 34 | Gensim 3.4 35 | ``` 36 | 37 | * To install AdaGram software to use Adaptive conditioning: 38 | 39 | Download Julia 0.6 binaries from [official site](https://julialang.org/downloads/) and add alias in ~/.bashrc 40 | ``` 41 | alias julia='JULIA_BINARY_PATH/bin/julia' 42 | ``` 43 | Use `source ~/.bashrc` to reload ~/.bashrc 44 | 45 | Then activate julia interpreter using `julia` and install following packages: 46 | ``` 47 | Pkg.clone("https://github.com/mirestrepo/AdaGram.jl") 48 | Pkg.build("AdaGram") 49 | Pkg.add("ArgParse") 50 | Pkg.add("JSON") 51 | Pkg.add("NPZ") 52 | exit() 53 | ``` 54 | Then add in ~/.bashrc 55 | ``` 56 | export PATH="JULIA_BINARY_PATH/bin:$PATH" 57 | export LD_LIBRARY_PATH="JULIA_INSTALL_PATH/v0.6/AdaGram/lib:$LD_LIBRARY_PATH" 58 | ``` 59 | And finally to apply exports 60 | ``` 61 | source ~/.bashrc 62 | ``` 63 | * To install Mosesdecoder (for BLEU) follow instructions on the [official site](http://www.statmt.org/moses/?n=Development.GetStarted) 64 | 65 | * To get data for language model (LM) pretraining: 66 | ``` 67 | cd pytorch-definitions 68 | mkdir data 69 | cd data 70 | wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip 71 | unzip wikitext-103-v1.zip 72 | ``` 73 | * To get data for Google word vectors use [official site](https://code.google.com/archive/p/word2vec/). 
You need .bin.gz file. Don't forget to `gunzip` downloaded file to extract binaries 74 | 75 | * Adaptive Skip-gram vectors are available upon request. Also you can train your owns using instructions in the [official repo](https://github.com/sbos/AdaGram.jl) 76 | 77 | * The Definition Modeling data is available upon request because of Oxford Dictionaries distribution license. Also you can collect your own. If you want to collect your own, then you should prepare 3 datasplits: train, test and val. Each datasplit is python array with the following format saved as json file: 78 | 79 | ``` 80 | data = [ 81 | [ 82 | ["word"], 83 | ["word1", "word2", ...], 84 | ["word1", "word2", ...] 85 | ], 86 | ... 87 | ] 88 | So i-th element of the data: 89 | data[i][0][0] - word being defined (string) 90 | data[i][1] - definition (list of strings) 91 | data[i][2] - context to understand word meaning (list of strings) 92 | ``` 93 | 94 | # Usage 95 | Firstly, you need to prepare vocabs, vectors and etc for using model: 96 | 97 | * To prepare vocabs use `python prep_vocab.py` 98 | 99 | ``` 100 | usage: prep_vocab.py [-h] --defs DEFS [DEFS ...] [--lm LM [LM ...]] [--same] 101 | --save SAVE [--save_context SAVE_CONTEXT] --save_chars 102 | SAVE_CHARS 103 | 104 | Prepare vocabularies for model 105 | 106 | optional arguments: 107 | -h, --help show this help message and exit 108 | --defs DEFS [DEFS ...] 109 | location of json file with definitions. 110 | --lm LM [LM ...] location of txt file with text for LM pre-training 111 | --same use same vocab for definitions and contexts 112 | --save SAVE where to save prepaired vocabulary (for words from 113 | definitions) 114 | --save_context SAVE_CONTEXT 115 | where to save vocabulary (for words from contexts) 116 | --save_chars SAVE_CHARS 117 | where to save char vocabulary (for chars from all 118 | words) 119 | ``` 120 | 121 | * To prepare w2v vectors use `python prep_w2v.py` 122 | ``` 123 | usage: prep_w2v.py [-h] --defs DEFS [DEFS ...] --save SAVE [SAVE ...] --w2v 124 | W2V 125 | 126 | Prepare word vectors for Input conditioning 127 | 128 | optional arguments: 129 | -h, --help show this help message and exit 130 | --defs DEFS [DEFS ...] 131 | location of json file with definitions. 132 | --save SAVE [SAVE ...] 133 | where to save files 134 | --w2v W2V location of binary w2v file 135 | ``` 136 | 137 | * To prepare Adagram vectors use `julia prep_ada.jl` 138 | ``` 139 | usage: prep_ada.jl --defs DEFS [DEFS...] --save SAVE [SAVE...] 140 | --ada ADA [-h] 141 | 142 | Prepare word vectors for Input-Adaptive conditioning 143 | 144 | optional arguments: 145 | --defs DEFS [DEFS...] 146 | location of json file with definitions. 147 | --save SAVE [SAVE...] 148 | where to save files 149 | --ada ADA location of AdaGram file 150 | -h, --help show this help message and exit 151 | ``` 152 | * If you want to init embedding matrix of the model with Google word vectors then prepare it using
153 | `python prep_embedding_matrix.py` and then use path to saved weights as `--w2v_weights` in `train.py` 154 | ``` 155 | usage: prep_embedding_matrix.py [-h] --voc VOC --w2v W2V --save SAVE 156 | 157 | Prepare word vectors for embedding layer in the model 158 | 159 | optional arguments: 160 | -h, --help show this help message and exit 161 | --voc VOC location of model vocabulary file 162 | --w2v W2V location of binary w2v file 163 | --save SAVE where to save prepaired matrix 164 | ``` 165 | 166 | Now all is already ready for model usage! 167 | 168 | * To train model use `python train.py` 169 | ``` 170 | usage: train.py [-h] [--pretrain] --voc VOC [--train_defs TRAIN_DEFS] 171 | [--eval_defs EVAL_DEFS] [--test_defs TEST_DEFS] 172 | [--input_train INPUT_TRAIN] [--input_eval INPUT_EVAL] 173 | [--input_test INPUT_TEST] 174 | [--input_adaptive_train INPUT_ADAPTIVE_TRAIN] 175 | [--input_adaptive_eval INPUT_ADAPTIVE_EVAL] 176 | [--input_adaptive_test INPUT_ADAPTIVE_TEST] 177 | [--context_voc CONTEXT_VOC] [--ch_voc CH_VOC] 178 | [--train_lm TRAIN_LM] [--eval_lm EVAL_LM] [--test_lm TEST_LM] 179 | [--bptt BPTT] --nx NX --nlayers NLAYERS --nhid NHID 180 | --rnn_dropout RNN_DROPOUT [--use_seed] [--use_input] 181 | [--use_input_adaptive] [--use_input_attention] 182 | [--n_attn_embsize N_ATTN_EMBSIZE] [--n_attn_hid N_ATTN_HID] 183 | [--attn_dropout ATTN_DROPOUT] [--attn_sparse] [--use_ch] 184 | [--ch_emb_size CH_EMB_SIZE] 185 | [--ch_feature_maps CH_FEATURE_MAPS [CH_FEATURE_MAPS ...]] 186 | [--ch_kernel_sizes CH_KERNEL_SIZES [CH_KERNEL_SIZES ...]] 187 | [--use_hidden] [--use_hidden_adaptive] 188 | [--use_hidden_attention] [--use_gated] [--use_gated_adaptive] 189 | [--use_gated_attention] --lr LR --decay_factor DECAY_FACTOR 190 | --decay_patience DECAY_PATIENCE --num_epochs NUM_EPOCHS 191 | --batch_size BATCH_SIZE --clip CLIP --random_seed RANDOM_SEED 192 | --exp_dir EXP_DIR [--w2v_weights W2V_WEIGHTS] 193 | [--fix_embeddings] [--fix_attn_embeddings] [--lm_ckpt LM_CKPT] 194 | [--attn_ckpt ATTN_CKPT] 195 | 196 | Script to train a model 197 | 198 | optional arguments: 199 | -h, --help show this help message and exit 200 | --pretrain whether to pretrain model on LM dataset or train on 201 | definitions 202 | --voc VOC location of vocabulary file 203 | --train_defs TRAIN_DEFS 204 | location of json file with train definitions. 205 | --eval_defs EVAL_DEFS 206 | location of json file with eval definitions. 
207 | --test_defs TEST_DEFS 208 | location of json file with test definitions 209 | --input_train INPUT_TRAIN 210 | location of train vectors for Input conditioning 211 | --input_eval INPUT_EVAL 212 | location of eval vectors for Input conditioning 213 | --input_test INPUT_TEST 214 | location of test vectors for Input conditioning 215 | --input_adaptive_train INPUT_ADAPTIVE_TRAIN 216 | location of train vectors for InputAdaptive 217 | conditioning 218 | --input_adaptive_eval INPUT_ADAPTIVE_EVAL 219 | location of eval vectors for InputAdaptive 220 | conditioning 221 | --input_adaptive_test INPUT_ADAPTIVE_TEST 222 | location test vectors for InputAdaptive conditioning 223 | --context_voc CONTEXT_VOC 224 | location of context vocabulary file 225 | --ch_voc CH_VOC location of CH vocabulary file 226 | --train_lm TRAIN_LM location of txt file train LM data 227 | --eval_lm EVAL_LM location of txt file eval LM data 228 | --test_lm TEST_LM location of txt file test LM data 229 | --bptt BPTT sequence length for BackPropThroughTime in LM 230 | pretraining 231 | --nx NX size of embeddings 232 | --nlayers NLAYERS number of LSTM layers 233 | --nhid NHID size of hidden states 234 | --rnn_dropout RNN_DROPOUT 235 | probability of RNN dropout 236 | --use_seed whether to use Seed conditioning or not 237 | --use_input whether to use Input conditioning or not 238 | --use_input_adaptive whether to use InputAdaptive conditioning or not 239 | --use_input_attention 240 | whether to use InputAttention conditioning or not 241 | --n_attn_embsize N_ATTN_EMBSIZE 242 | size of InputAttention embeddings 243 | --n_attn_hid N_ATTN_HID 244 | size of InputAttention linear layer 245 | --attn_dropout ATTN_DROPOUT 246 | probability of InputAttention dropout 247 | --attn_sparse whether to use sparse embeddings in InputAttention or 248 | not 249 | --use_ch whether to use CH conditioning or not 250 | --ch_emb_size CH_EMB_SIZE 251 | size of embeddings in CH conditioning 252 | --ch_feature_maps CH_FEATURE_MAPS [CH_FEATURE_MAPS ...] 253 | list of feature map sizes in CH conditioning 254 | --ch_kernel_sizes CH_KERNEL_SIZES [CH_KERNEL_SIZES ...] 
255 | list of kernel sizes in CH conditioning 256 | --use_hidden whether to use Hidden conditioning or not 257 | --use_hidden_adaptive 258 | whether to use HiddenAdaptive conditioning or not 259 | --use_hidden_attention 260 | whether to use HiddenAttention conditioning or not 261 | --use_gated whether to use Gated conditioning or not 262 | --use_gated_adaptive whether to use GatedAdaptive conditioning or not 263 | --use_gated_attention 264 | whether to use GatedAttention conditioning or not 265 | --lr LR initial lr 266 | --decay_factor DECAY_FACTOR 267 | factor to decay lr 268 | --decay_patience DECAY_PATIENCE 269 | after number of patience epochs - decay lr 270 | --num_epochs NUM_EPOCHS 271 | number of epochs to train 272 | --batch_size BATCH_SIZE 273 | batch size 274 | --clip CLIP value to clip norm of gradients to 275 | --random_seed RANDOM_SEED 276 | random seed 277 | --exp_dir EXP_DIR where to save all stuff about training 278 | --w2v_weights W2V_WEIGHTS 279 | path to pretrained embeddings to init 280 | --fix_embeddings whether to update embedding matrix or not 281 | --fix_attn_embeddings 282 | whether to update attention embedding matrix or not 283 | --lm_ckpt LM_CKPT path to pretrained language model weights 284 | --attn_ckpt ATTN_CKPT 285 | path to pretrained Attention module 286 | ``` 287 | 288 | For example to train simple language model use: 289 | ``` 290 | python train.py --voc VOC_PATH --nx 300 --nhid 300 --rnn_dropout 0.5 --lr 0.001 --decay_factor 0.1 --decay_patience 0 291 | --num_epochs 1 --batch_size 16 --clip 5 --random_seed 42 --exp_dir DIR_PATH -bptt 30 292 | --pretrain --train_lm PATH_TO_WIKI_103_TRAIN --eval_lm PATH_TO_WIKI_103_EVAL --test_lm PATH_TO_WIKI_103_TEST 293 | ``` 294 | 295 | For example to train `Seed + Input` model use: 296 | ``` 297 | python train.py --voc VOC_PATH --nx 300 --nhid 300 --rnn_dropout 0.5 --lr 0.001 --decay_factor 0.1 --decay_patience 0 298 | --num_epochs 1 --batch_size 16 --clip 5 --random_seed 42 --exp_dir DIR_PATH 299 | --train_defs TRAIN_SPLIT_PATH --eval_defs EVAL_DEFS_PATH --test_defs TEST_DEFS_PATH --use_seed 300 | --use_input --input_train PREPARED_W2V_TRAIN_VECS --input_eval PREPARED_W2V_EVAL_VECS --input_test PREPARED_W2V_TEST_VECS 301 | ``` 302 | 303 | To train `Seed + Input` model with pretraining as unconditional LM provide path to pretrained LM weights
as `--lm_ckpt` argument in `train.py`
304 | 
305 | * To generate using a trained model use `python generate.py`
306 | ```
307 | usage: generate.py [-h] --params PARAMS --ckpt CKPT --tau TAU --n N --length
308 |                    LENGTH [--prefix PREFIX] [--wordlist WORDLIST]
309 |                    [--w2v_binary_path W2V_BINARY_PATH]
310 |                    [--ada_binary_path ADA_BINARY_PATH]
311 |                    [--prep_ada_path PREP_ADA_PATH]
312 | 
313 | Script to generate using model
314 | 
315 | optional arguments:
316 |   -h, --help            show this help message and exit
317 |   --params PARAMS       path to saved model params
318 |   --ckpt CKPT           path to saved model weights
319 |   --tau TAU             temperature to use in sampling
320 |   --n N                 number of samples to generate
321 |   --length LENGTH       maximum length of generated samples
322 |   --prefix PREFIX       prefix to read until generation starts
323 |   --wordlist WORDLIST   path to word list with words and contexts
324 |   --w2v_binary_path W2V_BINARY_PATH
325 |                         path to binary w2v file
326 |   --ada_binary_path ADA_BINARY_PATH
327 |                         path to binary ada file
328 |   --prep_ada_path PREP_ADA_PATH
329 |                         path to prep_ada.jl script
330 | ```
331 | 
332 | * To evaluate a trained model use `python eval.py`
333 | ```
334 | usage: eval.py [-h] --params PARAMS --ckpt CKPT --datasplit DATASPLIT --type
335 |                TYPE [--wordlist WORDLIST] [--tau TAU] [--n N]
336 |                [--length LENGTH]
337 | 
338 | Script to evaluate model
339 | 
340 | optional arguments:
341 |   -h, --help            show this help message and exit
342 |   --params PARAMS       path to saved model params
343 |   --ckpt CKPT           path to saved model weights
344 |   --datasplit DATASPLIT
345 |                         train, val or test set to evaluate on
346 |   --type TYPE           compute ppl or bleu
347 |   --wordlist WORDLIST   word list to evaluate on (by default all data will be
348 |                         used)
349 |   --tau TAU             temperature to use in sampling
350 |   --n N                 number of samples to generate
351 |   --length LENGTH       maximum length of generated samples
352 | ```
353 | 
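For example, to compute perplexity of a trained model on the test split use (PARAMS_PATH and CKPT_PATH are placeholders for your saved params and checkpoint files):
```
python eval.py --params PARAMS_PATH --ckpt CKPT_PATH --datasplit test --type ppl
```
With `--type bleu` (which additionally requires `--tau`, `--n` and `--length`) the script also writes `generated_<datasplit>_tau=..._n=..._length=....txt` and `refs_<datasplit>.txt` into the experiment directory recorded in the params file; these files are the `--hyp` and `--ref` inputs for `bleu.py` described below.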
354 | * To measure BLEU for a trained model, first evaluate it with `--type bleu` in `eval.py`
355 | and then compute BLEU using `python bleu.py`
356 | ```
357 | usage: bleu.py [-h] --ref REF --hyp HYP --n N [--with_contexts] --bleu_path
358 |                BLEU_PATH --mode MODE
359 | 
360 | Script to compute BLEU
361 | 
362 | optional arguments:
363 |   -h, --help            show this help message and exit
364 |   --ref REF             path to file with references
365 |   --hyp HYP             path to file with hypotheses
366 |   --n N                 --n argument used to generate --ref file using eval.py
367 |   --with_contexts       whether to consider contexts or not when compute BLEU
368 |   --bleu_path BLEU_PATH
369 |                         path to mosesdecoder sentence-bleu binary
370 |   --mode MODE           whether to average or take random example per word
371 | ```
372 | 
373 | * You can also pretrain the Attention module using `python train_attention_skipgram.py` and
374 | then use path to saved weights as `--attn_ckpt` argument in `train.py` 375 | ``` 376 | usage: train_attention_skipgram.py [-h] [--data DATA] --context_voc 377 | CONTEXT_VOC [--prepared] --window WINDOW 378 | --random_seed RANDOM_SEED [--sparse] 379 | --vec_dim VEC_DIM --attn_hid ATTN_HID 380 | --attn_dropout ATTN_DROPOUT --lr LR 381 | --batch_size BATCH_SIZE --num_epochs 382 | NUM_EPOCHS --exp_dir EXP_DIR 383 | 384 | Script to train a AttentionSkipGram model 385 | 386 | optional arguments: 387 | -h, --help show this help message and exit 388 | --data DATA path to data 389 | --context_voc CONTEXT_VOC 390 | path to context voc for DefinitionModelingModel is 391 | necessary to save pretrained attention module, 392 | particulary embedding matrix 393 | --prepared whether to prepare data or use already prepared 394 | --window WINDOW window for AttentionSkipGram model 395 | --random_seed RANDOM_SEED 396 | random seed for training 397 | --sparse whether to use sparse embeddings or not 398 | --vec_dim VEC_DIM vector dim to train 399 | --attn_hid ATTN_HID hidden size in attention module 400 | --attn_dropout ATTN_DROPOUT 401 | dropout prob in attention module 402 | --lr LR initial lr to use 403 | --batch_size BATCH_SIZE 404 | batch size to use 405 | --num_epochs NUM_EPOCHS 406 | number of epochs to train 407 | --exp_dir EXP_DIR where to save weights, prepared data and logs 408 | ``` 409 | -------------------------------------------------------------------------------- /bleu.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | from subprocess import Popen, PIPE 4 | import os 5 | import sys 6 | from itertools import islice 7 | 8 | parser = argparse.ArgumentParser(description='Script to compute BLEU') 9 | parser.add_argument( 10 | "--ref", type=str, required=True, 11 | help="path to file with references" 12 | ) 13 | parser.add_argument( 14 | "--hyp", type=str, required=True, 15 | help="path to file with hypotheses" 16 | ) 17 | parser.add_argument( 18 | "--n", type=int, required=True, 19 | help="--n argument used to generate --ref file using eval.py" 20 | ) 21 | parser.add_argument( 22 | "--with_contexts", dest="with_contexts", action="store_true", 23 | help="whether to consider contexts or not when compute BLEU" 24 | ) 25 | parser.add_argument( 26 | "--bleu_path", type=str, required=True, 27 | help="path to mosesdecoder sentence-bleu binary" 28 | ) 29 | parser.add_argument( 30 | "--mode", type=str, required=True, 31 | help="whether to average or take random example per word" 32 | ) 33 | args = parser.parse_args() 34 | assert args.mode in ["average", "random"], "--mode must be averange or random" 35 | 36 | 37 | def next_n_lines(file_opened, N): 38 | return [x.strip() for x in islice(file_opened, N)] 39 | 40 | 41 | def read_def_file(file, n, with_contexts=False): 42 | defs = {} 43 | while True: 44 | lines = next_n_lines(file, n + 2) 45 | if len(lines) == 0: 46 | break 47 | assert len(lines) == n + 2, "Something bad in hyps file" 48 | word = lines[0].split("Word:")[1].strip() 49 | context = lines[1].split("Context:")[1].strip() 50 | dict_key = word + " " + context if with_contexts else word 51 | if dict_key not in defs: 52 | defs[dict_key] = [] 53 | for i in range(2, n + 2): 54 | defs[dict_key].append(lines[i].strip()) 55 | return defs 56 | 57 | 58 | def read_ref_file(file, with_contexts=False): 59 | defs = {} 60 | while True: 61 | lines = next_n_lines(file, 3) 62 | if len(lines) == 0: 63 | break 64 | assert len(lines) == 3, 
"Something bad in refs file" 65 | word = lines[0].split("Word:")[1].strip() 66 | context = lines[1].split("Context:")[1].strip() 67 | definition = lines[2].split("Definition:")[1].strip() 68 | dict_key = word + " " + context if with_contexts else word 69 | if dict_key not in defs: 70 | defs[dict_key] = [] 71 | defs[dict_key].append(definition) 72 | return defs 73 | 74 | 75 | def get_bleu_score(bleu_path, all_ref_paths, d, hyp_path): 76 | with open(hyp_path, 'w') as ofp: 77 | ofp.write(d) 78 | read_cmd = ['cat', hyp_path] 79 | bleu_cmd = [bleu_path] + all_ref_paths 80 | rp = Popen(read_cmd, stdout=PIPE) 81 | bp = Popen(bleu_cmd, stdin=rp.stdout, stdout=PIPE, stderr=devnull) 82 | out, err = bp.communicate() 83 | if err is None: 84 | return float(out.strip()) 85 | else: 86 | return None 87 | 88 | with open(args.ref) as ifp: 89 | refs = read_ref_file(ifp, args.with_contexts) 90 | with open(args.hyp) as ifp: 91 | hyps = read_def_file(ifp, args.n, args.with_contexts) 92 | 93 | assert len(refs) == len(hyps), "Number of words being defined mismatched!" 94 | tmp_dir = "/tmp" 95 | suffix = str(random.random()) 96 | words = refs.keys() 97 | hyp_path = os.path.join(tmp_dir, 'hyp' + suffix) 98 | to_be_deleted = set() 99 | to_be_deleted.add(hyp_path) 100 | 101 | # Computing BLEU 102 | devnull = open(os.devnull, 'w') 103 | score = 0 104 | count = 0 105 | total_refs = 0 106 | total_hyps = 0 107 | for word in words: 108 | if word not in refs or word not in hyps: 109 | continue 110 | wrefs = refs[word] 111 | whyps = hyps[word] 112 | # write out references 113 | all_ref_paths = [] 114 | for i, d in enumerate(wrefs): 115 | ref_path = os.path.join(tmp_dir, 'ref' + suffix + str(i)) 116 | with open(ref_path, 'w') as ofp: 117 | ofp.write(d) 118 | all_ref_paths.append(ref_path) 119 | to_be_deleted.add(ref_path) 120 | total_refs += len(all_ref_paths) 121 | # score for each output 122 | micro_score = 0 123 | micro_count = 0 124 | if args.mode == "average": 125 | for d in whyps: 126 | rhscore = get_bleu_score( 127 | args.bleu_path, all_ref_paths, d, hyp_path) 128 | if rhscore is not None: 129 | micro_score += rhscore 130 | micro_count += 1 131 | elif args.mode == "random": 132 | d = random.choice(whyps) 133 | rhscore = get_bleu_score(args.bleu_path, all_ref_paths, d, hyp_path) 134 | if rhscore is not None: 135 | micro_score += rhscore 136 | micro_count += 1 137 | total_hyps += micro_count 138 | score += micro_score / micro_count 139 | count += 1 140 | devnull.close() 141 | 142 | # delete tmp files 143 | for f in to_be_deleted: 144 | os.remove(f) 145 | print("BLEU: ", score / count) 146 | print("NUM HYPS USED: ", total_hyps) 147 | print("NUM REFS USED: ", total_refs) 148 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | from source.datasets import LanguageModelingDataset, LanguageModelingCollate 2 | from source.datasets import DefinitionModelingDataset, DefinitionModelingCollate 3 | from source.datasets import Vocabulary 4 | from source.model import DefinitionModelingModel 5 | from source.constants import BOS 6 | from source.pipeline import test 7 | from source.pipeline import generate 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | import argparse 11 | import json 12 | import torch 13 | 14 | parser = argparse.ArgumentParser(description='Script to evaluate model') 15 | parser.add_argument( 16 | "--params", type=str, required=True, 17 | help="path to saved model 
params" 18 | ) 19 | parser.add_argument( 20 | "--ckpt", type=str, required=True, 21 | help="path to saved model weights" 22 | ) 23 | parser.add_argument( 24 | "--datasplit", type=str, required=True, 25 | help="train, val or test set to evaluate on" 26 | ) 27 | parser.add_argument( 28 | "--type", type=str, required=True, 29 | help="compute ppl or bleu" 30 | ) 31 | parser.add_argument( 32 | "--wordlist", type=str, required=False, 33 | help="word list to evaluate on (by default all data will be used)" 34 | ) 35 | # params for BLEU 36 | parser.add_argument( 37 | "--tau", type=float, required=False, 38 | help="temperature to use in sampling" 39 | ) 40 | parser.add_argument( 41 | "--n", type=int, required=False, 42 | help="number of samples to generate" 43 | ) 44 | parser.add_argument( 45 | "--length", type=int, required=False, 46 | help="maximum length of generated samples" 47 | ) 48 | args = parser.parse_args() 49 | assert args.datasplit in ["train", "val", "test"], ("--datasplit must be " 50 | "train, val or test") 51 | assert args.type in ["ppl", "bleu"], ("--type must be ppl or bleu") 52 | 53 | with open(args.params, "r") as infile: 54 | model_params = json.load(infile) 55 | 56 | logfile = open(model_params["exp_dir"] + "eval_log", "a") 57 | #import sys 58 | #logfile = sys.stdout 59 | 60 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 61 | model = DefinitionModelingModel(model_params).to(device) 62 | model.load_state_dict(torch.load(args.ckpt)["state_dict"]) 63 | 64 | if model.params["pretrain"]: 65 | assert args.type == "ppl", "if --pretrain True => evaluate only ppl mode" 66 | if args.datasplit == "train": 67 | dataset = LanguageModelingDataset( 68 | file=model.params["train_lm"], 69 | vocab_path=model.params["voc"], 70 | bptt=model.params["bptt"], 71 | ) 72 | elif args.datasplit == "val": 73 | dataset = LanguageModelingDataset( 74 | file=model.params["eval_lm"], 75 | vocab_path=model.params["voc"], 76 | bptt=model.params["bptt"], 77 | ) 78 | elif args.datasplit == "test": 79 | dataset = LanguageModelingDataset( 80 | file=model.params["test_lm"], 81 | vocab_path=model.params["voc"], 82 | bptt=model.params["bptt"], 83 | ) 84 | dataloader = DataLoader( 85 | dataset, batch_size=model.params["batch_size"], 86 | collate_fn=LanguageModelingCollate 87 | ) 88 | else: 89 | if args.datasplit == "train": 90 | dataset = DefinitionModelingDataset( 91 | file=model.params["train_defs"], 92 | vocab_path=model.params["voc"], 93 | input_vectors_path=model.params["input_train"], 94 | input_adaptive_vectors_path=model.params["input_adaptive_train"], 95 | context_vocab_path=model.params["context_voc"], 96 | ch_vocab_path=model.params["ch_voc"], 97 | use_seed=model.params["use_seed"], 98 | wordlist_path=args.wordlist 99 | ) 100 | elif args.datasplit == "val": 101 | dataset = DefinitionModelingDataset( 102 | file=model.params["eval_defs"], 103 | vocab_path=model.params["voc"], 104 | input_vectors_path=model.params["input_eval"], 105 | input_adaptive_vectors_path=model.params["input_adaptive_eval"], 106 | context_vocab_path=model.params["context_voc"], 107 | ch_vocab_path=model.params["ch_voc"], 108 | use_seed=model.params["use_seed"], 109 | wordlist_path=args.wordlist 110 | ) 111 | elif args.datasplit == "test": 112 | dataset = DefinitionModelingDataset( 113 | file=model.params["test_defs"], 114 | vocab_path=model.params["voc"], 115 | input_vectors_path=model.params["input_test"], 116 | input_adaptive_vectors_path=model.params["input_adaptive_test"], 117 | 
context_vocab_path=model.params["context_voc"], 118 | ch_vocab_path=model.params["ch_voc"], 119 | use_seed=model.params["use_seed"], 120 | wordlist_path=args.wordlist 121 | ) 122 | dataloader = DataLoader( 123 | dataset, 124 | batch_size=1 if args.type == "bleu" else model.params["batch_size"], 125 | collate_fn=DefinitionModelingCollate 126 | ) 127 | if args.type == "ppl": 128 | eval_ppl = test(dataloader, model, device, logfile) 129 | else: 130 | assert args.tau is not None, "--tau is required if --type bleu" 131 | assert args.n is not None, "--n is required if --type bleu" 132 | assert args.length is not None, "--length is required if --type bleu" 133 | defsave = open( 134 | model.params["exp_dir"] + "generated_" + 135 | args.datasplit + "_tau=" + 136 | str(args.tau) + "_n=" + str(args.n) + 137 | "_length=" + str(args.length) + ".txt", 138 | "w" 139 | ) 140 | refsave = open( 141 | model.params["exp_dir"] + "refs_" + args.datasplit + ".txt", 142 | "w" 143 | ) 144 | #defsave = sys.stdout 145 | voc = Vocabulary() 146 | voc.load(model.params["voc"]) 147 | to_input = { 148 | "model": model, 149 | "voc": voc, 150 | "tau": args.tau, 151 | "n": args.n, 152 | "length": args.length, 153 | "device": device, 154 | } 155 | if model.is_attn: 156 | context_voc = Vocabulary() 157 | context_voc.load(model.params["context_voc"]) 158 | to_input["context_voc"] = context_voc 159 | if model.params["use_ch"]: 160 | ch_voc = Vocabulary() 161 | ch_voc.load(model.params["ch_voc"]) 162 | to_input["ch_voc"] = ch_voc 163 | for i in tqdm(range(len(dataset)), file=logfile): 164 | if model.is_w2v: 165 | to_input["input"] = torch.from_numpy(dataset.input_vectors[i]) 166 | if model.is_ada: 167 | to_input["input"] = torch.from_numpy( 168 | dataset.input_adaptive_vectors[i] 169 | ) 170 | if model.is_attn: 171 | to_input["word"] = dataset.data[i][0][0] 172 | to_input["context"] = " ".join(dataset.data[i][2]) 173 | if model.params["use_ch"]: 174 | to_input["CH_word"] = dataset.data[i][0][0] 175 | if model.params["use_seed"]: 176 | to_input["prefix"] = dataset.data[i][0][0] 177 | else: 178 | to_input["prefix"] = BOS 179 | defsave.write( 180 | "Word: {0}\nContext: {1}\n".format( 181 | dataset.data[i][0][0], 182 | " ".join(dataset.data[i][2]) 183 | ) 184 | ) 185 | defsave.write(generate(**to_input) + "\n") 186 | refsave.write( 187 | "Word: {0}\nContext: {1}\nDefinition: {2}\n".format( 188 | dataset.data[i][0][0], 189 | " ".join(dataset.data[i][2]), 190 | " ".join(dataset.data[i][1]) 191 | ) 192 | ) 193 | defsave.flush() 194 | logfile.flush() 195 | refsave.flush() 196 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | from source.model import DefinitionModelingModel 2 | from source.pipeline import generate 3 | from source.datasets import Vocabulary 4 | from source.utils import prepare_ada_vectors_from_python, prepare_w2v_vectors 5 | from source.constants import BOS 6 | import argparse 7 | import torch 8 | import json 9 | 10 | parser = argparse.ArgumentParser(description='Script to generate using model') 11 | parser.add_argument( 12 | "--params", type=str, required=True, 13 | help="path to saved model params" 14 | ) 15 | parser.add_argument( 16 | "--ckpt", type=str, required=True, 17 | help="path to saved model weights" 18 | ) 19 | parser.add_argument( 20 | "--tau", type=float, required=True, 21 | help="temperature to use in sampling" 22 | ) 23 | parser.add_argument( 24 | "--n", type=int, 
required=True, 25 | help="number of samples to generate" 26 | ) 27 | parser.add_argument( 28 | "--length", type=int, required=True, 29 | help="maximum length of generated samples" 30 | ) 31 | parser.add_argument( 32 | "--prefix", type=str, required=False, 33 | help="prefix to read until generation starts" 34 | ) 35 | parser.add_argument( 36 | "--wordlist", type=str, required=False, 37 | help="path to word list with words and contexts" 38 | ) 39 | parser.add_argument( 40 | "--w2v_binary_path", type=str, required=False, 41 | help="path to binary w2v file" 42 | ) 43 | parser.add_argument( 44 | "--ada_binary_path", type=str, required=False, 45 | help="path to binary ada file" 46 | ) 47 | parser.add_argument( 48 | "--prep_ada_path", type=str, required=False, 49 | help="path to prep_ada.jl script" 50 | ) 51 | args = parser.parse_args() 52 | 53 | with open(args.params, "r") as infile: 54 | model_params = json.load(infile) 55 | 56 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 57 | model = DefinitionModelingModel(model_params).to(device) 58 | model.load_state_dict(torch.load(args.ckpt)["state_dict"]) 59 | voc = Vocabulary() 60 | voc.load(model_params["voc"]) 61 | to_input = { 62 | "model": model, 63 | "voc": voc, 64 | "tau": args.tau, 65 | "n": args.n, 66 | "length": args.length, 67 | "device": device, 68 | } 69 | if model.params["pretrain"]: 70 | to_input["prefix"] = args.prefix 71 | print(generate(**to_input)) 72 | else: 73 | assert args.wordlist is not None, ("to generate definitions in --pretrain " 74 | "False mode --wordlist is required") 75 | 76 | with open(args.wordlist, "r") as infile: 77 | data = infile.readlines() 78 | 79 | if model.is_w2v: 80 | assert args.w2v_binary_path is not None, ("model.is_w2v True => " 81 | "--w2v_binary_path is " 82 | "required") 83 | input_vecs = torch.from_numpy( 84 | prepare_w2v_vectors(args.wordlist, args.w2v_binary_path) 85 | ) 86 | if model.is_ada: 87 | assert args.ada_binary_path is not None, ("model.is_ada True => " 88 | "--ada_binary_path is " 89 | "required") 90 | assert args.prep_ada_path is not None, ("model.is_ada True => " 91 | "--prep_ada_path is " 92 | "required") 93 | input_vecs = torch.from_numpy( 94 | prepare_ada_vectors_from_python( 95 | args.wordlist, 96 | args.prep_ada_path, 97 | args.ada_binary_path 98 | ) 99 | ) 100 | if model.is_attn: 101 | context_voc = Vocabulary() 102 | context_voc.load(model.params["context_voc"]) 103 | to_input["context_voc"] = context_voc 104 | if model.params["use_ch"]: 105 | ch_voc = Vocabulary() 106 | ch_voc.load(model.params["ch_voc"]) 107 | to_input["ch_voc"] = ch_voc 108 | for i in range(len(data)): 109 | word, context = data[i].split('\t') 110 | context = context.strip() 111 | if model.is_w2v or model.is_ada: 112 | to_input["input"] = input_vecs[i] 113 | if model.is_attn: 114 | to_input["word"] = word 115 | to_input["context"] = context 116 | if model.params["use_ch"]: 117 | to_input["CH_word"] = word 118 | if model.params["use_seed"]: 119 | to_input["prefix"] = word 120 | else: 121 | to_input["prefix"] = BOS 122 | print("Word: {0}".format(word)) 123 | print("Context: {0}".format(context)) 124 | print(generate(**to_input)) 125 | -------------------------------------------------------------------------------- /prep_ada.jl: -------------------------------------------------------------------------------- 1 | using ArgParse 2 | using AdaGram 3 | using JSON 4 | using NPZ 5 | 6 | function main(args) 7 | 8 | s = ArgParseSettings(description = "Prepare word vectors for Input-Adaptive 
conditioning") 9 | 10 | @add_arg_table s begin 11 | "--defs" 12 | nargs = '+' 13 | arg_type = String 14 | required = true 15 | help = "location of json file with definitions." 16 | "--save" 17 | nargs = '+' 18 | arg_type = String 19 | required = true 20 | help = "where to save files" 21 | "--ada" 22 | arg_type = String 23 | required = true 24 | help = "location of AdaGram file" 25 | end 26 | 27 | parsed_args = parse_args(s) 28 | if length(parsed_args["defs"]) != length(parsed_args["save"]) 29 | error("Number of defs files must match number of save locations") 30 | end 31 | 32 | vm, dict = load_model(parsed_args["ada"]); 33 | for i = 1:length(parsed_args["defs"]) 34 | open(parsed_args["defs"][i], "r") do f 35 | global definitions = JSON.parse(readstring(f)) 36 | end 37 | global vectors = zeros(length(definitions), length(vm.In[:, 1, 1])) 38 | for (k, elem) in enumerate(definitions) 39 | if haskey(dict.word2id, elem[1][1]) 40 | global good_context = [] 41 | for w in elem[3] 42 | if haskey(dict.word2id, w) 43 | push!(good_context, w) 44 | end 45 | end 46 | mxval, mxidx = findmax(disambiguate(vm, dict, elem[1][1], split(join(good_context, " ")))) 47 | vectors[k, :] = vm.In[:, mxidx, dict.word2id[elem[1][1]]] 48 | end 49 | end 50 | npzwrite(parsed_args["save"][i], vectors) 51 | end 52 | 53 | end 54 | 55 | main(ARGS) -------------------------------------------------------------------------------- /prep_embedding_matrix.py: -------------------------------------------------------------------------------- 1 | from source.datasets import Vocabulary 2 | import argparse 3 | from gensim.models import KeyedVectors 4 | import torch 5 | import numpy as np 6 | 7 | parser = argparse.ArgumentParser( 8 | description='Prepare word vectors for embedding layer in the model' 9 | ) 10 | parser.add_argument( 11 | "--voc", type=str, required=True, 12 | help="location of model vocabulary file" 13 | ) 14 | parser.add_argument( 15 | "--w2v", type=str, required=True, 16 | help="location of binary w2v file" 17 | ) 18 | parser.add_argument( 19 | "--save", type=str, required=True, 20 | help="where to save prepaired matrix" 21 | ) 22 | args = parser.parse_args() 23 | word_vectors = KeyedVectors.load_word2vec_format(args.w2v, binary=True) 24 | voc = Vocabulary() 25 | voc.load(args.voc) 26 | vecs = [] 27 | initrange = 0.5 / word_vectors.vector_size 28 | for key in voc.tok2id.keys(): 29 | if key in word_vectors: 30 | vecs.append(word_vectors[key]) 31 | else: 32 | vecs.append( 33 | np.random.uniform( 34 | low=-initrange, 35 | high=initrange, 36 | size=word_vectors.vector_size) 37 | ) 38 | torch.save(torch.from_numpy(np.array(vecs)).float(), args.save) 39 | -------------------------------------------------------------------------------- /prep_vocab.py: -------------------------------------------------------------------------------- 1 | from source.datasets import Vocabulary 2 | import argparse 3 | import json 4 | 5 | parser = argparse.ArgumentParser(description='Prepare vocabularies for model') 6 | parser.add_argument( 7 | '--defs', type=str, required=True, nargs="+", 8 | help='location of json file with definitions.' 
9 | ) 10 | parser.add_argument( 11 | "--lm", type=str, required=False, nargs="+", 12 | help="location of txt file with text for LM pre-training" 13 | ) 14 | parser.add_argument( 15 | '--same', dest='same', action='store_true', 16 | help="use same vocab for definitions and contexts" 17 | ) 18 | parser.set_defaults(same=False) 19 | parser.add_argument( 20 | "--save", type=str, required=True, 21 | help="where to save prepaired vocabulary (for words from definitions)" 22 | ) 23 | parser.add_argument( 24 | "--save_context", type=str, required=False, 25 | help="where to save vocabulary (for words from contexts)" 26 | ) 27 | parser.add_argument( 28 | "--save_chars", type=str, required=True, 29 | help="where to save char vocabulary (for chars from all words)" 30 | ) 31 | args = parser.parse_args() 32 | if not args.same and args.save_context is None: 33 | parser.error("--save_context required if --same didn't used") 34 | 35 | 36 | voc = Vocabulary() 37 | char_voc = Vocabulary() 38 | if not args.same: 39 | context_voc = Vocabulary() 40 | 41 | definitions = [] 42 | for f in args.defs: 43 | with open(f, "r") as infile: 44 | definitions.extend(json.load(infile)) 45 | 46 | if args.lm is not None: 47 | lm_texts = "" 48 | for f in args.lm: 49 | lm_texts = lm_texts + open(f).read().lower() + " " 50 | lm_texts = lm_texts.split() 51 | 52 | for word in lm_texts: 53 | voc.add_token(word) 54 | 55 | for elem in definitions: 56 | voc.add_token(elem[0][0]) 57 | char_voc.tok_maxlen = max(len(elem[0][0]), char_voc.tok_maxlen) 58 | for c in elem[0][0]: 59 | char_voc.add_token(c) 60 | for i in range(len(elem[1])): 61 | voc.add_token(elem[1][i]) 62 | if args.same: 63 | for i in range(len(elem[2])): 64 | voc.add_token(elem[2][i]) 65 | else: 66 | context_voc.add_token(elem[0][0]) 67 | for i in range(len(elem[2])): 68 | context_voc.add_token(elem[2][i]) 69 | 70 | 71 | voc.save(args.save) 72 | char_voc.save(args.save_chars) 73 | if not args.same: 74 | context_voc.save(args.save_context) 75 | -------------------------------------------------------------------------------- /prep_w2v.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from gensim.models import KeyedVectors 3 | import numpy as np 4 | import json 5 | 6 | parser = argparse.ArgumentParser( 7 | description='Prepare word vectors for Input conditioning' 8 | ) 9 | 10 | parser.add_argument( 11 | '--defs', type=str, required=True, nargs="+", 12 | help='location of json file with definitions.' 
13 | ) 14 | 15 | parser.add_argument( 16 | '--save', type=str, required=True, nargs="+", 17 | help='where to save files' 18 | ) 19 | 20 | parser.add_argument( 21 | "--w2v", type=str, required=True, 22 | help="location of binary w2v file" 23 | ) 24 | args = parser.parse_args() 25 | 26 | if len(args.defs) != len(args.save): 27 | parser.error("Number of defs files must match number of save locations") 28 | 29 | word_vectors = KeyedVectors.load_word2vec_format(args.w2v, binary=True) 30 | for i in range(len(args.defs)): 31 | vectors = [] 32 | with open(args.defs[i], "r") as infile: 33 | definitions = json.load(infile) 34 | for elem in definitions: 35 | if elem[0][0] in word_vectors: 36 | vectors.append(word_vectors[elem[0][0]]) 37 | else: 38 | vectors.append(np.zeros(word_vectors.vector_size)) 39 | vectors = np.array(vectors) 40 | np.save(args.save[i], vectors) 41 | -------------------------------------------------------------------------------- /source/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/agadetsky/pytorch-definitions/03e7fb2e02c03ce5774f5e2cd174c7f224373a3e/source/__init__.py -------------------------------------------------------------------------------- /source/attention_skipgram.py: -------------------------------------------------------------------------------- 1 | from .layers import InputAttention 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class AttentionSkipGram(nn.Module): 8 | 9 | def __init__(self, n_attn_tokens, n_attn_embsize, 10 | n_attn_hid, attn_dropout, sparse=False): 11 | super(AttentionSkipGram, self).__init__() 12 | self.n_attn_tokens = n_attn_tokens 13 | self.n_attn_embsize = n_attn_embsize 14 | self.n_attn_hid = n_attn_hid 15 | self.attn_dropout = attn_dropout 16 | self.sparse = sparse 17 | 18 | self.emb0_lookup = InputAttention( 19 | n_attn_tokens=self.n_attn_tokens, 20 | n_attn_embsize=self.n_attn_embsize, 21 | n_attn_hid=self.n_attn_hid, 22 | attn_dropout=self.attn_dropout, 23 | sparse=self.sparse 24 | ) 25 | self.emb1_lookup = nn.Embedding( 26 | num_embeddings=self.n_attn_tokens, 27 | embedding_dim=self.n_attn_embsize, 28 | sparse=self.sparse 29 | ) 30 | self.emb1_lookup.weight.data.zero_() 31 | 32 | def forward(self, words, context, neg): 33 | idx = torch.LongTensor(words.size(0), 1).random_( 34 | 0, context.size(1) 35 | ).to(words.device) 36 | labels = context.gather(1, idx).squeeze(1) 37 | 38 | w_embs = self.emb0_lookup(words, context) 39 | c_embs = self.emb1_lookup(labels) 40 | n_embs = self.emb1_lookup(neg) 41 | 42 | pos_ips = torch.sum(w_embs * c_embs, 1) 43 | neg_ips = torch.bmm( 44 | n_embs, torch.unsqueeze(w_embs, 1).permute(0, 2, 1) 45 | ).squeeze(2) 46 | 47 | # Neg Log Likelihood 48 | pos_loss = -torch.mean(F.logsigmoid(pos_ips)) 49 | neg_loss = -torch.mean(F.logsigmoid(-neg_ips).sum(1)) 50 | 51 | return pos_loss + neg_loss 52 | -------------------------------------------------------------------------------- /source/constants.py: -------------------------------------------------------------------------------- 1 | # FOR MODEL VOCABULARY 2 | PAD = '' 3 | UNK = '' 4 | BOS = '' 5 | EOS = '' 6 | PAD_IDX = 0 7 | UNK_IDX = 1 8 | BOS_IDX = 2 9 | EOS_IDX = 3 10 | -------------------------------------------------------------------------------- /source/datasets.py: -------------------------------------------------------------------------------- 1 | from . 
import constants 2 | from torch.utils.data import Dataset 3 | import json 4 | import numpy as np 5 | import math 6 | 7 | 8 | class Vocabulary: 9 | """Word/char vocabulary""" 10 | 11 | def __init__(self): 12 | self.tok2id = { 13 | constants.PAD: constants.PAD_IDX, 14 | constants.UNK: constants.UNK_IDX, 15 | constants.BOS: constants.BOS_IDX, 16 | constants.EOS: constants.EOS_IDX 17 | } 18 | self.id2tok = { 19 | constants.PAD_IDX: constants.PAD, 20 | constants.UNK_IDX: constants.UNK, 21 | constants.BOS_IDX: constants.BOS, 22 | constants.EOS_IDX: constants.EOS 23 | } 24 | 25 | # we need this for maxlen of word being definedR in CH conditioning 26 | self.tok_maxlen = -float("inf") 27 | 28 | def encode(self, tok): 29 | if tok in self.tok2id: 30 | return self.tok2id[tok] 31 | else: 32 | return constants.UNK_IDX 33 | 34 | def decode(self, idx): 35 | if idx in self.id2tok: 36 | return self.id2tok[idx] 37 | else: 38 | raise ValueError("No such idx: {0}".format(idx)) 39 | 40 | def encode_seq(self, arr): 41 | ret = [] 42 | for elem in arr: 43 | ret.append(self.encode(elem)) 44 | return ret 45 | 46 | def decode_seq(self, arr): 47 | ret = [] 48 | for elem in arr: 49 | ret.append(self.decode(elem)) 50 | return ret 51 | 52 | def add_token(self, tok): 53 | if tok not in self.tok2id: 54 | self.tok2id[tok] = len(self.tok2id) 55 | self.id2tok[len(self.id2tok)] = tok 56 | 57 | def save(self, path): 58 | with open(path, "w") as outfile: 59 | json.dump([self.id2tok, self.tok_maxlen], outfile, indent=4) 60 | 61 | def load(self, path): 62 | with open(path, "r") as infile: 63 | self.id2tok, self.tok_maxlen = json.load(infile) 64 | self.id2tok = {int(k): v for k, v in self.id2tok.items()} 65 | self.tok2id = {} 66 | for i in self.id2tok.keys(): 67 | self.tok2id[self.id2tok[i]] = i 68 | 69 | 70 | def pad(seq, size, value): 71 | if len(seq) < size: 72 | seq.extend([value] * (size - len(seq))) 73 | return seq 74 | 75 | 76 | class LanguageModelingDataset(Dataset): 77 | """LanguageModeling dataset.""" 78 | 79 | def __init__(self, file, vocab_path, bptt): 80 | """ 81 | Args: 82 | file (string): Path to the file 83 | vocab_path (string): path to word vocab to use 84 | bptt (int): length of one sentence 85 | """ 86 | with open(file, "r") as infile: 87 | self.data = infile.read().lower().split() 88 | self.voc = Vocabulary() 89 | self.voc.load(vocab_path) 90 | self.bptt = bptt 91 | 92 | def __len__(self): 93 | return math.ceil(len(self.data) / (self.bptt + 1)) 94 | 95 | def __getitem__(self, idx): 96 | i = idx + self.bptt * idx 97 | sample = { 98 | "x": self.voc.encode_seq(self.data[i: i + self.bptt]), 99 | "y": self.voc.encode_seq(self.data[i + 1: i + self.bptt + 1]), 100 | } 101 | return sample 102 | 103 | 104 | def LanguageModelingCollate(batch): 105 | batch_x = [] 106 | batch_y = [] 107 | maxlen = -float("inf") 108 | for i in range(len(batch)): 109 | batch_x.append(batch[i]["x"]) 110 | batch_y.append(batch[i]["y"]) 111 | maxlen = max(maxlen, len(batch[i]["x"]), len(batch[i]["y"])) 112 | 113 | for i in range(len(batch)): 114 | batch_x[i] = pad(batch_x[i], maxlen, constants.PAD_IDX) 115 | batch_y[i] = pad(batch_y[i], maxlen, constants.PAD_IDX) 116 | 117 | ret_batch = { 118 | "x": np.array(batch_x), 119 | "y": np.array(batch_y), 120 | } 121 | return ret_batch 122 | 123 | 124 | class DefinitionModelingDataset(Dataset): 125 | """DefinitionModeling dataset.""" 126 | 127 | def __init__(self, file, vocab_path, input_vectors_path=None, 128 | input_adaptive_vectors_path=None, context_vocab_path=None, 129 | ch_vocab_path=None, 
use_seed=False, wordlist_path=None): 130 | """ 131 | Args: 132 | file (string): path to the file 133 | vocab_path (string): path to word vocab to use 134 | input_vectors_path (string): path to vectors for Input conditioning 135 | input_adaptive_vectors_path (string): path to vectors for Input-Adaptive conditioning 136 | context_vocab_path (string): path to vocab for context words for Input-Attention 137 | ch_vocab_path (string): path to char vocab for CH conditioning 138 | use_seed (bool): whether to use Seed conditioning or not 139 | wordlist_path (string): path to wordlist with words 140 | """ 141 | with open(file, "r") as infile: 142 | self.data = json.load(infile) 143 | self.voc = Vocabulary() 144 | self.voc.load(vocab_path) 145 | if context_vocab_path is not None: 146 | self.context_voc = Vocabulary() 147 | self.context_voc.load(context_vocab_path) 148 | if ch_vocab_path is not None: 149 | self.ch_voc = Vocabulary() 150 | self.ch_voc.load(ch_vocab_path) 151 | if input_vectors_path is not None: 152 | self.input_vectors = np.load(input_vectors_path).astype(np.float32) 153 | if input_adaptive_vectors_path is not None: 154 | self.input_adaptive_vectors = np.load( 155 | input_adaptive_vectors_path 156 | ).astype(np.float32) 157 | if wordlist_path is not None: 158 | wordlist = set( 159 | [elem.strip() for elem in open(wordlist_path, "r").readlines()] 160 | ) 161 | data = [] 162 | if input_vectors_path is not None: 163 | input_vectors = [] 164 | if input_adaptive_vectors_path is not None: 165 | input_adaptive_vectors = [] 166 | for i in range(len(self.data)): 167 | if self.data[i][0][0] in wordlist: 168 | data.append(self.data[i]) 169 | if input_vectors_path is not None: 170 | input_vectors.append( 171 | self.input_vectors[i] 172 | ) 173 | if input_adaptive_vectors_path is not None: 174 | input_adaptive_vectors.append( 175 | self.input_adaptive_vectors[i] 176 | ) 177 | assert len(data) > 0, "You provided bad wordlist, no words found" 178 | if input_vectors_path is not None: 179 | self.input_vectors = np.array(input_vectors).astype(np.float32) 180 | if input_adaptive_vectors_path is not None: 181 | self.input_adaptive_vectors = np.array( 182 | input_adaptive_vectors 183 | ).astype(np.float32) 184 | self.data = data 185 | self.use_seed = use_seed 186 | 187 | def __len__(self): 188 | return len(self.data) 189 | 190 | def __getitem__(self, idx): 191 | sample = { 192 | "x": self.voc.encode_seq(self.data[idx][1]), 193 | "y": self.voc.encode_seq(self.data[idx][1][1:] + [constants.EOS]), 194 | } 195 | if hasattr(self, "input_vectors"): 196 | sample["input"] = self.input_vectors[idx] 197 | if hasattr(self, "input_adaptive_vectors"): 198 | sample["input_adaptive"] = self.input_adaptive_vectors[idx] 199 | if hasattr(self, "context_voc"): 200 | sample["word"] = self.context_voc.encode(self.data[idx][0][0]) 201 | sample["context"] = self.context_voc.encode_seq(self.data[idx][2]) 202 | if hasattr(self, "ch_voc"): 203 | sample["CH"] = [constants.BOS_IDX] + \ 204 | self.ch_voc.encode_seq(list(self.data[idx][0][0])) + \ 205 | [constants.EOS_IDX] 206 | # CH_maxlen: +2 because EOS + BOS 207 | sample["CH_maxlen"] = self.ch_voc.tok_maxlen + 2 208 | if self.use_seed: 209 | sample["y"] = [sample["x"][0]] + sample["y"] 210 | sample["x"] = self.voc.encode_seq(self.data[idx][0]) + sample["x"] 211 | return sample 212 | 213 | 214 | def DefinitionModelingCollate(batch): 215 | batch_x = [] 216 | batch_y = [] 217 | is_w2v = "input" in batch[0] 218 | is_ada = "input_adaptive" in batch[0] 219 | is_attn = "word" in batch[0] 
and "context" in batch[0] 220 | is_ch = "CH" in batch[0] and "CH_maxlen" in batch[0] 221 | if is_w2v: 222 | batch_input = [] 223 | if is_ada: 224 | batch_input_adaptive = [] 225 | if is_attn: 226 | batch_word = [] 227 | batch_context = [] 228 | context_maxlen = -float("inf") 229 | if is_ch: 230 | batch_ch = [] 231 | CH_maxlen = batch[0]["CH_maxlen"] 232 | 233 | definition_lengths = [] 234 | for i in range(len(batch)): 235 | batch_x.append(batch[i]["x"]) 236 | batch_y.append(batch[i]["y"]) 237 | if is_w2v: 238 | batch_input.append(batch[i]["input"]) 239 | if is_ada: 240 | batch_input_adaptive.append(batch[i]["input_adaptive"]) 241 | if is_attn: 242 | batch_word.append(batch[i]["word"]) 243 | batch_context.append(batch[i]["context"]) 244 | context_maxlen = max(context_maxlen, len(batch_context[-1])) 245 | if is_ch: 246 | batch_ch.append(batch[i]["CH"]) 247 | definition_lengths.append(len(batch_x[-1])) 248 | 249 | definition_maxlen = max(definition_lengths) 250 | 251 | for i in range(len(batch)): 252 | batch_x[i] = pad(batch_x[i], definition_maxlen, constants.PAD_IDX) 253 | batch_y[i] = pad(batch_y[i], definition_maxlen, constants.PAD_IDX) 254 | if is_attn: 255 | batch_context[i] = pad( 256 | batch_context[i], context_maxlen, constants.PAD_IDX 257 | ) 258 | if is_ch: 259 | batch_ch[i] = pad(batch_ch[i], CH_maxlen, constants.PAD_IDX) 260 | 261 | order = np.argsort(definition_lengths)[::-1] 262 | batch_x = np.array(batch_x)[order] 263 | batch_y = np.array(batch_y)[order] 264 | ret_batch = { 265 | "x": batch_x, 266 | "y": batch_y, 267 | } 268 | if is_w2v: 269 | batch_input = np.array(batch_input, dtype=np.float32)[order] 270 | ret_batch["input"] = batch_input 271 | if is_ada: 272 | batch_input_adaptive = np.array( 273 | batch_input_adaptive, 274 | dtype=np.float32 275 | )[order] 276 | ret_batch["input_adaptive"] = batch_input_adaptive 277 | if is_attn: 278 | batch_word = np.array(batch_word)[order] 279 | batch_context = np.array(batch_context)[order] 280 | ret_batch["word"] = batch_word 281 | ret_batch["context"] = batch_context 282 | if is_ch: 283 | batch_ch = np.array(batch_ch)[order] 284 | ret_batch["CH"] = batch_ch 285 | 286 | return ret_batch 287 | -------------------------------------------------------------------------------- /source/layers.py: -------------------------------------------------------------------------------- 1 | from . 
import constants 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Input(nn.Module): 8 | """ 9 | Class for Input or Input - Adaptive or dummy conditioning 10 | """ 11 | 12 | def __init__(self): 13 | super(Input, self).__init__() 14 | 15 | def forward(self, x): 16 | """ 17 | Vectors are already prepaired in DataLoaders, so just return them 18 | """ 19 | return x 20 | 21 | 22 | class InputAttention(nn.Module): 23 | """ 24 | Class for Input Attention conditioning 25 | """ 26 | 27 | def __init__(self, n_attn_tokens, n_attn_embsize, 28 | n_attn_hid, attn_dropout, sparse=False): 29 | super(InputAttention, self).__init__() 30 | self.n_attn_tokens = n_attn_tokens 31 | self.n_attn_embsize = n_attn_embsize 32 | self.n_attn_hid = n_attn_hid 33 | self.attn_dropout = attn_dropout 34 | self.sparse = sparse 35 | 36 | self.embs = nn.Embedding( 37 | num_embeddings=self.n_attn_tokens, 38 | embedding_dim=self.n_attn_embsize, 39 | padding_idx=constants.PAD_IDX, 40 | sparse=self.sparse 41 | ) 42 | 43 | self.ann = nn.Sequential( 44 | nn.Dropout(p=self.attn_dropout), 45 | nn.Linear( 46 | in_features=self.n_attn_embsize, 47 | out_features=self.n_attn_hid 48 | ), 49 | nn.Tanh() 50 | ) # maybe use ReLU or other? 51 | 52 | self.a_linear = nn.Linear( 53 | in_features=self.n_attn_hid, 54 | out_features=self.n_attn_embsize 55 | ) 56 | 57 | def forward(self, word, context): 58 | x_embs = self.embs(word) 59 | mask = self.get_mask(context) 60 | return mask * x_embs 61 | 62 | def get_mask(self, context): 63 | context_embs = self.embs(context) 64 | lengths = (context != constants.PAD_IDX) 65 | for_sum_mask = lengths.unsqueeze(2).float() 66 | lengths = lengths.sum(1).float().view(-1, 1) 67 | logits = self.a_linear( 68 | (self.ann(context_embs) * for_sum_mask).sum(1) / lengths 69 | ) 70 | return F.sigmoid(logits) 71 | 72 | def init_attn(self, freeze): 73 | initrange = 0.5 / self.n_attn_embsize 74 | with torch.no_grad(): 75 | nn.init.uniform_(self.embs.weight, -initrange, initrange) 76 | nn.init.xavier_uniform_(self.a_linear.weight) 77 | nn.init.constant_(self.a_linear.bias, 0) 78 | nn.init.xavier_uniform_(self.ann[1].weight) 79 | nn.init.constant_(self.ann[1].bias, 0) 80 | self.embs.weight.requires_grad = not freeze 81 | 82 | def init_attn_from_pretrained(self, weights, freeze): 83 | self.load_state_dict(weights) 84 | self.embs.weight.requires_grad = not freeze 85 | 86 | 87 | class CharCNN(nn.Module): 88 | """ 89 | Class for CH conditioning 90 | """ 91 | 92 | def __init__(self, n_ch_tokens, ch_maxlen, ch_emb_size, 93 | ch_feature_maps, ch_kernel_sizes): 94 | super(CharCNN, self).__init__() 95 | assert len(ch_feature_maps) == len(ch_kernel_sizes) 96 | 97 | self.n_ch_tokens = n_ch_tokens 98 | self.ch_maxlen = ch_maxlen 99 | self.ch_emb_size = ch_emb_size 100 | self.ch_feature_maps = ch_feature_maps 101 | self.ch_kernel_sizes = ch_kernel_sizes 102 | 103 | self.feature_mappers = nn.ModuleList() 104 | for i in range(len(self.ch_feature_maps)): 105 | reduced_length = self.ch_maxlen - self.ch_kernel_sizes[i] + 1 106 | self.feature_mappers.append( 107 | nn.Sequential( 108 | nn.Conv2d( 109 | in_channels=1, 110 | out_channels=self.ch_feature_maps[i], 111 | kernel_size=( 112 | self.ch_kernel_sizes[i], 113 | self.ch_emb_size 114 | ) 115 | ), 116 | nn.Tanh(), 117 | nn.MaxPool2d(kernel_size=(reduced_length, 1)) 118 | ) 119 | ) 120 | 121 | self.embs = nn.Embedding( 122 | self.n_ch_tokens, 123 | self.ch_emb_size, 124 | padding_idx=constants.PAD_IDX 125 | ) 126 | 127 | def forward(self, x): 128 | 
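        """
        Embed the character indices, run each Conv2d -> Tanh -> MaxPool
        feature mapper over the character sequence, and concatenate the
        pooled feature maps into a single vector per word.
        """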
# x - [batch_size x maxlen] 129 | bsize, length = x.size() 130 | assert length == self.ch_maxlen 131 | x_embs = self.embs(x).view(bsize, 1, self.ch_maxlen, self.ch_emb_size) 132 | 133 | cnn_features = [] 134 | for i in range(len(self.ch_feature_maps)): 135 | cnn_features.append( 136 | self.feature_mappers[i](x_embs).view(bsize, -1) 137 | ) 138 | 139 | return torch.cat(cnn_features, dim=1) 140 | 141 | def init_ch(self): 142 | initrange = 0.5 / self.ch_emb_size 143 | with torch.no_grad(): 144 | nn.init.uniform_(self.embs.weight, -initrange, initrange) 145 | for name, p in self.feature_mappers.named_parameters(): 146 | if "bias" in name: 147 | nn.init.constant_(p, 0) 148 | elif "weight" in name: 149 | nn.init.xavier_uniform_(p) 150 | 151 | 152 | class Hidden(nn.Module): 153 | """ 154 | Class for Hidden conditioning 155 | """ 156 | 157 | def __init__(self, cond_size, hidden_size, out_size): 158 | super(Hidden, self).__init__() 159 | self.cond_size = cond_size 160 | self.hidden_size = hidden_size 161 | self.out_size = out_size 162 | self.in_size = self.cond_size + self.hidden_size 163 | self.linear = nn.Linear( 164 | in_features=self.in_size, 165 | out_features=self.out_size 166 | ) 167 | 168 | def forward(self, hidden, conds): 169 | seqlen = hidden.size(1) # batch_first=True 170 | repeated_conds = conds.view(-1).repeat(seqlen) 171 | repeated_conds = repeated_conds.view(seqlen, *conds.size()) 172 | repeated_conds = repeated_conds.permute( 173 | 1, 0, 2 174 | ) # batchsize x seqlen x cond_dim 175 | concat = torch.cat( 176 | [repeated_conds, hidden], dim=2 177 | ) # concat by last dim 178 | return F.tanh(self.linear(concat)) 179 | 180 | def init_hidden(self): 181 | with torch.no_grad(): 182 | nn.init.xavier_uniform_(self.linear.weight) 183 | nn.init.constant_(self.linear.bias, 0) 184 | 185 | 186 | class Gated(nn.Module): 187 | """ 188 | Class for Gated conditioning 189 | """ 190 | 191 | def __init__(self, cond_size, hidden_size): 192 | super(Gated, self).__init__() 193 | self.cond_size = cond_size 194 | self.hidden_size = hidden_size 195 | self.in_size = self.cond_size + self.hidden_size 196 | self.linear1 = nn.Linear( 197 | in_features=self.in_size, 198 | out_features=self.hidden_size 199 | ) 200 | self.linear2 = nn.Linear( 201 | in_features=self.in_size, 202 | out_features=self.cond_size 203 | ) 204 | self.linear3 = nn.Linear( 205 | in_features=self.in_size, 206 | out_features=self.hidden_size 207 | ) 208 | 209 | def forward(self, hidden, conds): 210 | seqlen = hidden.size(1) # batch_first=True 211 | repeated_conds = conds.view(-1).repeat(seqlen) 212 | repeated_conds = repeated_conds.view(seqlen, *conds.size()) 213 | repeated_conds = repeated_conds.permute( 214 | 1, 0, 2 215 | ) # batchsize x seqlen x cond_dim 216 | concat = torch.cat( 217 | [repeated_conds, hidden], dim=2 218 | ) # concat by last dim 219 | z_t = F.sigmoid(self.linear1(concat)) 220 | r_t = F.sigmoid(self.linear2(concat)) 221 | masked_concat = torch.cat( 222 | [repeated_conds * r_t, hidden], dim=2 223 | ) 224 | hat_s_t = F.tanh(self.linear3(masked_concat)) 225 | return (1 - z_t) * hidden + z_t * hat_s_t 226 | 227 | def init_gated(self): 228 | with torch.no_grad(): 229 | nn.init.xavier_uniform_(self.linear1.weight) 230 | nn.init.xavier_uniform_(self.linear2.weight) 231 | nn.init.xavier_uniform_(self.linear3.weight) 232 | nn.init.constant_(self.linear1.bias, 0) 233 | nn.init.constant_(self.linear2.bias, 0) 234 | nn.init.constant_(self.linear3.bias, 0) 235 | 
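As a quick orientation before the model definition that follows, here is a minimal, self-contained sketch (not part of the repository; all sizes are made up for illustration) of how the `Hidden` and `Gated` layers above fuse a per-example condition vector with batch-first RNN outputs, assuming PyTorch as listed in the README:

```python
import torch
from source.layers import Gated, Hidden

# Illustrative sizes only: 4 sequences of length 7, LSTM size 300,
# conditioned on a 300-dimensional vector per example.
batch_size, seqlen, nhid, cond_size = 4, 7, 300, 300
rnn_out = torch.randn(batch_size, seqlen, nhid)  # batch_first RNN outputs
conds = torch.randn(batch_size, cond_size)       # one condition vector per example

gated = Gated(cond_size=cond_size, hidden_size=nhid)
gated.init_gated()
hidden = Hidden(cond_size=cond_size, hidden_size=nhid, out_size=nhid)
hidden.init_hidden()

# Both layers repeat the condition over the time dimension, concatenate it
# with the hidden states and return a tensor of the same shape as rnn_out.
print(gated(rnn_out, conds).shape)   # torch.Size([4, 7, 300])
print(hidden(rnn_out, conds).shape)  # torch.Size([4, 7, 300])
```

In `DefinitionModelingModel` below, at most one conditioning family is active and `cond_size` is the concatenation of the chosen w2v / AdaGram / attention vector with the optional CharCNN feature maps.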
-------------------------------------------------------------------------------- /source/model.py: -------------------------------------------------------------------------------- 1 | from . import constants 2 | from .layers import Input, InputAttention, CharCNN, Hidden, Gated 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 6 | from torch.nn.utils.rnn import pack_padded_sequence as pack 7 | 8 | 9 | class DefinitionModelingModel(nn.Module): 10 | """Definition modeling class""" 11 | 12 | def __init__(self, params): 13 | super(DefinitionModelingModel, self).__init__() 14 | self.params = params 15 | 16 | self.embs = nn.Embedding( 17 | num_embeddings=self.params["ntokens"], 18 | embedding_dim=self.params["nx"], 19 | padding_idx=constants.PAD_IDX 20 | ) 21 | self.dropout = nn.Dropout(p=self.params["rnn_dropout"]) 22 | 23 | self.n_rnn_input = self.params["nx"] 24 | if not self.params["pretrain"]: 25 | self.input_used = self.params["use_input"] 26 | self.input_used += self.params["use_input_adaptive"] 27 | self.input_used += self.params["use_input_attention"] 28 | self.hidden_used = self.params["use_hidden"] 29 | self.hidden_used += self.params["use_hidden_adaptive"] 30 | self.hidden_used += self.params["use_hidden_attention"] 31 | self.gated_used = self.params["use_gated"] 32 | self.gated_used += self.params["use_gated_adaptive"] 33 | self.gated_used += self.params["use_gated_attention"] 34 | self.is_w2v = self.params["use_input"] 35 | self.is_w2v += self.params["use_hidden"] 36 | self.is_w2v += self.params["use_gated"] 37 | self.is_ada = self.params["use_input_adaptive"] 38 | self.is_ada += self.params["use_hidden_adaptive"] 39 | self.is_ada += self.params["use_gated_adaptive"] 40 | self.is_attn = self.params["use_input_attention"] 41 | self.is_attn += self.params["use_hidden_attention"] 42 | self.is_attn += self.params["use_gated_attention"] 43 | self.is_conditioned = self.input_used 44 | self.is_conditioned += self.hidden_used 45 | self.is_conditioned += self.gated_used 46 | 47 | # check if either Input* or Hidden/Gated conditioning are used not 48 | # both 49 | assert self.input_used + self.hidden_used + \ 50 | self.gated_used <= 1, "Too many conditionings used" 51 | 52 | if not self.is_conditioned and self.params["use_ch"]: 53 | raise ValueError("Don't use CH conditioning without others") 54 | 55 | self.cond_size = 0 56 | if self.is_w2v: 57 | self.input = Input() 58 | self.cond_size += self.params["input_dim"] 59 | elif self.is_ada: 60 | self.input_adaptive = Input() 61 | self.cond_size += self.params["input_adaptive_dim"] 62 | elif self.is_attn: 63 | self.input_attention = InputAttention( 64 | n_attn_tokens=self.params["n_attn_tokens"], 65 | n_attn_embsize=self.params["n_attn_embsize"], 66 | n_attn_hid=self.params["n_attn_hid"], 67 | attn_dropout=self.params["attn_dropout"], 68 | sparse=self.params["attn_sparse"] 69 | ) 70 | self.cond_size += self.params["n_attn_embsize"] 71 | 72 | if self.params["use_ch"]: 73 | self.ch = CharCNN( 74 | n_ch_tokens=self.params["n_ch_tokens"], 75 | ch_maxlen=self.params["ch_maxlen"], 76 | ch_emb_size=self.params["ch_emb_size"], 77 | ch_feature_maps=self.params["ch_feature_maps"], 78 | ch_kernel_sizes=self.params["ch_kernel_sizes"] 79 | ) 80 | self.cond_size += sum(self.params["ch_feature_maps"]) 81 | 82 | if self.input_used: 83 | self.n_rnn_input += self.cond_size 84 | 85 | if self.hidden_used: 86 | self.hidden = Hidden( 87 | cond_size=self.cond_size, 88 | hidden_size=self.params["nhid"], 89 | 
out_size=self.params["nhid"] 90 | ) 91 | elif self.gated_used: 92 | self.gated = Gated( 93 | cond_size=self.cond_size, 94 | hidden_size=self.params["nhid"] 95 | ) 96 | 97 | self.rnn = nn.LSTM( 98 | input_size=self.n_rnn_input, 99 | hidden_size=self.params["nhid"], 100 | num_layers=self.params["nlayers"], 101 | batch_first=True, 102 | dropout=self.params["rnn_dropout"] 103 | ) 104 | self.linear = nn.Linear( 105 | in_features=self.params["nhid"], 106 | out_features=self.params["ntokens"] 107 | ) 108 | 109 | self.init_weights() 110 | 111 | def forward(self, x, input=None, word=None, context=None, CH_word=None, hidden=None): 112 | """ 113 | x - definitions/LM_sequence to read 114 | input - vectors for Input, Input-Adaptive or dummy conditioning 115 | word - words for Input-Attention conditioning 116 | context - contexts for Input-Attention conditioning 117 | CH_word - words for CH conditioning 118 | hidden - hidden states of RNN 119 | """ 120 | lengths = (x != constants.PAD_IDX).sum(dim=1).detach() 121 | maxlen = lengths.max().item() 122 | embs = self.embs(x) 123 | if not self.params["pretrain"]: 124 | all_conds = [] 125 | if self.is_w2v: 126 | all_conds.append(self.input(input)) 127 | elif self.is_ada: 128 | all_conds.append(self.input_adaptive(input)) 129 | elif self.is_attn: 130 | all_conds.append(self.input_attention(word, context)) 131 | if self.params["use_ch"]: 132 | all_conds.append(self.ch(CH_word)) 133 | if self.is_conditioned: 134 | all_conds = torch.cat(all_conds, dim=1) 135 | 136 | if self.input_used: 137 | repeated_conds = all_conds.view(-1).repeat(maxlen) 138 | repeated_conds = repeated_conds.view(maxlen, *all_conds.size()) 139 | repeated_conds = repeated_conds.permute(1, 0, 2) 140 | embs = torch.cat([repeated_conds, embs], dim=-1) 141 | 142 | embs = pack(embs, lengths, batch_first=True) 143 | output, hidden = self.rnn(embs, hidden) 144 | output = unpack(output, batch_first=True)[0] 145 | output = self.dropout(output) 146 | 147 | if not self.params["pretrain"]: 148 | if self.hidden_used: 149 | output = self.hidden(output, all_conds) 150 | elif self.gated_used: 151 | output = self.gated(output, all_conds) 152 | 153 | decoded = self.linear( 154 | output.contiguous().view( 155 | output.size(0) * output.size(1), 156 | output.size(2) 157 | ) 158 | ) 159 | 160 | return decoded, hidden 161 | 162 | def init_embeddings(self, freeze): 163 | initrange = 0.5 / self.params["nx"] 164 | with torch.no_grad(): 165 | nn.init.uniform_(self.embs.weight, -initrange, initrange) 166 | self.embs.weight.requires_grad = not freeze 167 | 168 | def init_embeddings_from_pretrained(self, weights, freeze): 169 | self.embs = self.embs.from_pretrained(weights, freeze) 170 | 171 | def init_rnn(self): 172 | with torch.no_grad(): 173 | for name, p in self.rnn.named_parameters(): 174 | if "bias" in name: 175 | nn.init.constant_(p, 0) 176 | elif "weight" in name: 177 | nn.init.xavier_uniform_(p) 178 | 179 | def init_rnn_from_pretrained(self, weights): 180 | # k[4:] because we need to remove prefix "rnn." because 181 | # self.rnn.state_dict() is without "rnn." prefix 182 | correct_state_dict = { 183 | k[4:]: v for k, v in weights.items() if k[:4] == "rnn." 
184 | } 185 | # also we need to correctly initialize weight_ih_l0 186 | # with pretrained weights because it has different size with 187 | # self.rnn.state_dict(), other weights has correct shapes if 188 | # hidden sizes have same shape as in the LM pretraining 189 | if self.input_used: 190 | w = torch.empty(4 * self.params["nhid"], self.n_rnn_input) 191 | nn.init.xavier_uniform_(w) 192 | w[:, self.cond_size:] = correct_state_dict["weight_ih_l0"] 193 | correct_state_dict["weight_ih_l0"] = w 194 | self.rnn.load_state_dict(correct_state_dict) 195 | 196 | def init_linear(self): 197 | with torch.no_grad(): 198 | nn.init.xavier_uniform_(self.linear.weight) 199 | nn.init.constant_(self.linear.bias, 0) 200 | 201 | def init_linear_from_pretrained(self, weights): 202 | # k[7: ] because we need to remove prefix "linear." because 203 | # self.linear.state_dict() is without "linear." prefix 204 | self.linear.load_state_dict( 205 | {k[7:]: v for k, v in weights.items() if k[:7] == "linear."} 206 | ) 207 | 208 | def init_weights(self): 209 | if self.params["pretrain"]: 210 | if self.params["w2v_weights"] is not None: 211 | self.init_embeddings_from_pretrained( 212 | torch.load(self.params["w2v_weights"]), 213 | self.params["fix_embeddings"] 214 | ) 215 | else: 216 | self.init_embeddings(self.params["fix_embeddings"]) 217 | self.init_rnn() 218 | self.init_linear() 219 | else: 220 | if self.params["lm_ckpt"] is not None: 221 | lm_ckpt_weights = torch.load(self.params["lm_ckpt"]) 222 | lm_ckpt_weights = lm_ckpt_weights["state_dict"] 223 | self.init_embeddings_from_pretrained( 224 | lm_ckpt_weights["embs.weight"], 225 | self.params["fix_embeddings"] 226 | ) 227 | self.init_rnn_from_pretrained(lm_ckpt_weights) 228 | self.init_linear_from_pretrained(lm_ckpt_weights) 229 | else: 230 | if self.params["w2v_weights"] is not None: 231 | self.init_embeddings_from_pretrained( 232 | torch.load(self.params["w2v_weights"]), 233 | self.params["fix_embeddings"] 234 | ) 235 | else: 236 | self.init_embeddings(self.params["fix_embeddings"]) 237 | self.init_rnn() 238 | self.init_linear() 239 | if self.is_attn: 240 | if self.params["attn_ckpt"] is not None: 241 | self.input_attention.init_attn_from_pretrained( 242 | torch.load(self.params["attn_ckpt"])["state_dict"], 243 | self.params["fix_attn_embeddings"] 244 | ) 245 | else: 246 | self.input_attention.init_attn( 247 | self.params["fix_attn_embeddings"] 248 | ) 249 | if self.hidden_used: 250 | self.hidden.init_hidden() 251 | if self.gated_used: 252 | self.gated.init_gated() 253 | if self.params["use_ch"]: 254 | self.ch.init_ch() 255 | -------------------------------------------------------------------------------- /source/pipeline.py: -------------------------------------------------------------------------------- 1 | from . 
import constants 2 | from .datasets import pad 3 | from torch.nn.utils import clip_grad_norm_ 4 | import torch.nn.functional as F 5 | from tqdm import tqdm 6 | import torch 7 | import numpy as np 8 | 9 | 10 | def train_epoch(dataloader, model, optimizer, device, clip_to, logfile): 11 | """ 12 | Function for training the model one epoch 13 | dataloader - either LanguageModeling or DefinitionModeling dataloader 14 | model - DefinitionModelingModel 15 | optimizer - optimizer to use (usually Adam) 16 | device - cuda/cpu 17 | clip_to - value to clip gradients 18 | logfile - where to log training 19 | """ 20 | # switch model to training mode 21 | model.train() 22 | # train 23 | mean_batch_loss = 0 24 | for batch in tqdm(dataloader, file=logfile): 25 | y_true = torch.from_numpy(batch.pop("y")).to(device).view(-1) 26 | # prepare model args 27 | to_input = {"x": torch.from_numpy(batch["x"]).to(device)} 28 | if not model.params["pretrain"]: 29 | if model.is_w2v: 30 | to_input["input"] = torch.from_numpy(batch["input"]).to(device) 31 | if model.is_ada: 32 | to_input["input"] = torch.from_numpy( 33 | batch["input_adaptive"] 34 | ).to(device) 35 | if model.is_attn: 36 | to_input["word"] = torch.from_numpy(batch["word"]).to(device) 37 | to_input["context"] = torch.from_numpy( 38 | batch["context"] 39 | ).to(device) 40 | if model.params["use_ch"]: 41 | to_input["CH_word"] = torch.from_numpy( 42 | batch["CH"] 43 | ).to(device) 44 | 45 | y_pred, hidden = model(**to_input) 46 | batch_loss = F.cross_entropy( 47 | y_pred, y_true, 48 | ignore_index=constants.PAD_IDX 49 | ) 50 | optimizer.zero_grad() 51 | batch_loss.backward() 52 | clip_grad_norm_( 53 | filter(lambda p: p.requires_grad, model.parameters()), clip_to 54 | ) 55 | optimizer.step() 56 | logfile.flush() 57 | mean_batch_loss += batch_loss.item() 58 | 59 | mean_batch_loss = mean_batch_loss / len(dataloader) 60 | logfile.write( 61 | "Mean training loss on epoch: {0}\n".format(mean_batch_loss) 62 | ) 63 | logfile.flush() 64 | 65 | 66 | def test(dataloader, model, device, logfile): 67 | """ 68 | Function for testing the model on dataloader 69 | dataloader - either LanguageModeling or DefinitionModeling dataloader 70 | model - DefinitionModelingModel 71 | device - cuda/cpu 72 | logfile - where to log evaluation 73 | """ 74 | # switch model to evaluation mode 75 | model.eval() 76 | # eval 77 | lengths_sum = 0 78 | loss_sum = 0 79 | with torch.no_grad(): 80 | for batch in tqdm(dataloader, file=logfile): 81 | y_true = torch.from_numpy(batch.pop("y")).to(device).view(-1) 82 | # prepare model args 83 | to_input = {"x": torch.from_numpy(batch["x"]).to(device)} 84 | if not model.params["pretrain"]: 85 | if model.is_w2v: 86 | to_input["input"] = torch.from_numpy( 87 | batch["input"] 88 | ).to(device) 89 | if model.is_ada: 90 | to_input["input"] = torch.from_numpy( 91 | batch["input_adaptive"] 92 | ).to(device) 93 | if model.is_attn: 94 | to_input["word"] = torch.from_numpy( 95 | batch["word"] 96 | ).to(device) 97 | to_input["context"] = torch.from_numpy( 98 | batch["context"] 99 | ).to(device) 100 | if model.params["use_ch"]: 101 | to_input["CH_word"] = torch.from_numpy( 102 | batch["CH"] 103 | ).to(device) 104 | 105 | y_pred, hidden = model(**to_input) 106 | loss_sum += F.cross_entropy( 107 | y_pred, 108 | y_true, 109 | ignore_index=constants.PAD_IDX, 110 | size_average=False 111 | ).item() 112 | lengths_sum += (to_input["x"] != constants.PAD_IDX).sum().item() 113 | logfile.flush() 114 | 115 | perplexity = np.exp(loss_sum / lengths_sum) 116 | logfile.write( 117 | 
"Perplexity: {0}\n".format(perplexity) 118 | ) 119 | logfile.flush() 120 | return perplexity 121 | 122 | 123 | def generate(model, voc, tau, n, length, device, prefix=None, 124 | input=None, word=None, context=None, context_voc=None, 125 | CH_word=None, ch_voc=None): 126 | """ 127 | model - DefinitionModelingModel 128 | voc - model Vocabulary 129 | tau - temperature to generate with 130 | n - number of samples 131 | length - length of the sample 132 | device - cuda/cpu 133 | prefix - prefix to read until generation 134 | input - vectors for Input/InputAdaptive conditioning 135 | word - word for InputAttention conditioning 136 | context - context for InputAttention conditioning 137 | context_voc - Vocabulary for InputAttention conditioning 138 | CH_word - word for CH conditioning 139 | ch_voc - Vocabulary for CH conditioning 140 | """ 141 | model.eval() 142 | to_input = {} 143 | if not model.params["pretrain"]: 144 | if model.is_w2v or model.is_ada: 145 | assert input is not None, ("input argument is required because" 146 | "model uses w2v or adagram vectors") 147 | assert input.dim() == 1, ("input argument must be vector" 148 | "but its dim is {0}".format(input.dim())) 149 | to_input["input"] = input.repeat(n).view(n, -1).to(device) 150 | if model.is_attn: 151 | assert word is not None, ("word argument is required because" 152 | "model uses attention") 153 | assert context is not None, ("context argument is required because" 154 | "model uses attention") 155 | assert context_voc is not None, ("context_voc argument is required" 156 | "because model uses attention") 157 | assert isinstance(word, str), ("word argument must be string") 158 | assert isinstance(context, str), ("context argument must be " 159 | "string") 160 | to_input["word"] = torch.LongTensor( 161 | [context_voc.encode(word)] 162 | ).repeat(n).view(n).to(device) 163 | to_input["context"] = torch.LongTensor( 164 | context_voc.encode_seq(context.split()) 165 | ).repeat(n).view(n, -1).to(device) 166 | if model.params["use_ch"]: 167 | assert CH_word is not None, ("CH_word argument is required because" 168 | "because model uses CH conditioning") 169 | assert ch_voc is not None, ("ch_voc argument is required because" 170 | "because model uses CH conditioning") 171 | assert isinstance(CH_word, str), ("CH_word must be string") 172 | to_input["CH_word"] = torch.LongTensor( 173 | pad( 174 | [constants.BOS_IDX] + 175 | ch_voc.encode_seq(list(CH_word)) + 176 | [constants.EOS_IDX], ch_voc.tok_maxlen + 2, 177 | constants.PAD_IDX 178 | ) 179 | ).repeat(n).view(n, -1).to(device) 180 | 181 | to_input["x"] = None 182 | to_input["hidden"] = None # pytorch automatically init to zeroes 183 | ret = [[] for i in range(n)] 184 | if prefix is not None: 185 | assert isinstance(prefix, str), "prefix argument must be string" 186 | if len(prefix.split()) > 0: 187 | to_input["x"] = torch.LongTensor( 188 | voc.encode_seq(prefix.split()) 189 | ).repeat(n).view(n, -1).to(device) 190 | else: 191 | to_input["x"] = torch.randint( 192 | model.params["ntokens"], size=(1, ), dtype=torch.long 193 | ).repeat(n).view(n, -1).to(device) 194 | prefix = voc.decode(to_input["x"][0][0].item()) 195 | with torch.no_grad(): 196 | for i in range(length): 197 | output, to_input["hidden"] = model(**to_input) 198 | output = output.view((n, -1, model.params["ntokens"]))[:, -1, :] 199 | to_input["x"] = F.softmax( 200 | output / tau, dim=1 201 | ).multinomial(num_samples=1) 202 | for i in range(n): 203 | ret[i].append(to_input["x"][i][0].item()) 204 | 205 | output = [[] for i in 
range(n)] 206 | for i in range(n): 207 | decoded = voc.decode_seq(ret[i]) 208 | for j in range(length): 209 | if decoded[j] == constants.EOS: 210 | break 211 | output[i].append(decoded[j]) 212 | output[i] = " ".join(map(str, output[i])) 213 | 214 | return "\n".join(output) 215 | -------------------------------------------------------------------------------- /source/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import subprocess 3 | import os 4 | import numpy as np 5 | from gensim.models import KeyedVectors 6 | import random 7 | 8 | 9 | def prepare_ada_vectors_from_python(file, julia_script, ada_binary_path): 10 | """ 11 | file - path to file with words and contexts on each line separated by \t, 12 | words and punctuation marks in contexts are separated by spaces 13 | julia_script - path to prep_ada.jl script 14 | ada_binary_path - path to ada binary file 15 | """ 16 | data = open(file, "r").readlines() 17 | tmp = [] 18 | for i in range(len(data)): 19 | word, context = data[i].split('\t') 20 | context = context.strip().split() 21 | tmp.append([[word], [], context]) 22 | tmp_name = "./tmp" + str(random.randint(1, 999999)) + ".txt" 23 | tmp_script_name = "./tmp_script" + str(random.randint(1, 999999)) + ".sh" 24 | tmp_vecs_name = "./tmp_vecs" + str(random.randint(1, 999999)) 25 | with open(tmp_name, "w") as outfile: 26 | json.dump(tmp, outfile, indent=4) 27 | with open(tmp_script_name, "w") as outfile: 28 | outfile.write( 29 | "julia " + julia_script + " --defs " + tmp_name + 30 | " --save " + tmp_vecs_name + 31 | " --ada " + ada_binary_path 32 | ) 33 | subprocess.call(["/bin/bash", "-i", tmp_script_name]) 34 | vecs = np.load(tmp_vecs_name).astype(np.float32) 35 | os.remove(tmp_name) 36 | os.remove(tmp_script_name) 37 | os.remove(tmp_vecs_name) 38 | return vecs 39 | 40 | 41 | def prepare_w2v_vectors(file, w2v_binary_path): 42 | """ 43 | file - path to file with words and contexts on each line separated by \t, 44 | words and punctuation marks in contexts are separated by spaces 45 | w2v_binary_path - path to w2v binary 46 | """ 47 | data = open(file, "r").readlines() 48 | word_vectors = KeyedVectors.load_word2vec_format( 49 | w2v_binary_path, binary=True 50 | ) 51 | vecs = [] 52 | initrange = 0.5 / word_vectors.vector_size 53 | for i in range(len(data)): 54 | word, context = data[i].split('\t') 55 | context = context.strip().split() 56 | if word in word_vectors: 57 | vecs.append(word_vectors[word]) 58 | else: 59 | vecs.append( 60 | np.random.uniform( 61 | low=-initrange, 62 | high=initrange, 63 | size=word_vectors.vector_size 64 | ) 65 | ) 66 | return np.array(vecs, dtype=np.float32) 67 | 68 | 69 | class MultipleOptimizer(object): 70 | 71 | def __init__(self, *op): 72 | self.optimizers = op 73 | 74 | def zero_grad(self): 75 | for op in self.optimizers: 76 | op.zero_grad() 77 | 78 | def step(self): 79 | for op in self.optimizers: 80 | op.step() 81 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from source.datasets import LanguageModelingDataset, LanguageModelingCollate 2 | from source.datasets import DefinitionModelingDataset, DefinitionModelingCollate 3 | from source.model import DefinitionModelingModel 4 | from source.pipeline import train_epoch, test 5 | from torch.utils.data import DataLoader 6 | import torch 7 | import torch.optim as optim 8 | from tqdm import tqdm 9 | import argparse 10 | 
import json 11 | import numpy as np 12 | 13 | # Read all arguments and prepare all stuff for training 14 | 15 | parser = argparse.ArgumentParser(description='Script to train a model') 16 | # Type of training 17 | parser.add_argument( 18 | '--pretrain', dest='pretrain', action="store_true", 19 | help='whether to pretrain model on LM dataset or train on definitions' 20 | ) 21 | # Common data arguments 22 | parser.add_argument( 23 | "--voc", type=str, required=True, help="location of vocabulary file" 24 | ) 25 | # Definitions data arguments 26 | parser.add_argument( 27 | '--train_defs', type=str, required=False, 28 | help="location of json file with train definitions." 29 | ) 30 | parser.add_argument( 31 | '--eval_defs', type=str, required=False, 32 | help="location of json file with eval definitions." 33 | ) 34 | parser.add_argument( 35 | '--test_defs', type=str, required=False, 36 | help="location of json file with test definitions" 37 | ) 38 | parser.add_argument( 39 | '--input_train', type=str, required=False, 40 | help="location of train vectors for Input conditioning" 41 | ) 42 | parser.add_argument( 43 | '--input_eval', type=str, required=False, 44 | help="location of eval vectors for Input conditioning" 45 | ) 46 | parser.add_argument( 47 | '--input_test', type=str, required=False, 48 | help="location of test vectors for Input conditioning" 49 | ) 50 | parser.add_argument( 51 | '--input_adaptive_train', type=str, required=False, 52 | help="location of train vectors for InputAdaptive conditioning" 53 | ) 54 | parser.add_argument( 55 | '--input_adaptive_eval', type=str, required=False, 56 | help="location of eval vectors for InputAdaptive conditioning" 57 | ) 58 | parser.add_argument( 59 | '--input_adaptive_test', type=str, required=False, 60 | help="location test vectors for InputAdaptive conditioning" 61 | ) 62 | parser.add_argument( 63 | '--context_voc', type=str, required=False, 64 | help="location of context vocabulary file" 65 | ) 66 | parser.add_argument( 67 | '--ch_voc', type=str, required=False, 68 | help="location of CH vocabulary file" 69 | ) 70 | # LM data arguments 71 | parser.add_argument( 72 | '--train_lm', type=str, required=False, 73 | help="location of txt file train LM data" 74 | ) 75 | parser.add_argument( 76 | '--eval_lm', type=str, required=False, 77 | help="location of txt file eval LM data" 78 | ) 79 | parser.add_argument( 80 | '--test_lm', type=str, required=False, 81 | help="location of txt file test LM data" 82 | ) 83 | parser.add_argument( 84 | '--bptt', type=int, required=False, 85 | help="sequence length for BackPropThroughTime in LM pretraining" 86 | ) 87 | # Model parameters arguments 88 | parser.add_argument( 89 | '--nx', type=int, required=True, 90 | help="size of embeddings" 91 | ) 92 | parser.add_argument( 93 | '--nlayers', type=int, required=True, 94 | help="number of LSTM layers" 95 | ) 96 | parser.add_argument( 97 | '--nhid', type=int, required=True, 98 | help="size of hidden states" 99 | ) 100 | parser.add_argument( 101 | '--rnn_dropout', type=float, required=True, 102 | help="probability of RNN dropout" 103 | ) 104 | parser.add_argument( 105 | '--use_seed', dest="use_seed", action="store_true", 106 | help="whether to use Seed conditioning or not" 107 | ) 108 | parser.add_argument( 109 | '--use_input', dest="use_input", action="store_true", 110 | help="whether to use Input conditioning or not" 111 | ) 112 | parser.add_argument( 113 | '--use_input_adaptive', dest="use_input_adaptive", action="store_true", 114 | help="whether to use InputAdaptive 
conditioning or not" 115 | ) 116 | parser.add_argument( 117 | '--use_input_attention', dest="use_input_attention", 118 | action="store_true", 119 | help="whether to use InputAttention conditioning or not" 120 | ) 121 | parser.add_argument( 122 | '--n_attn_embsize', type=int, required=False, 123 | help="size of InputAttention embeddings" 124 | ) 125 | parser.add_argument( 126 | '--n_attn_hid', type=int, required=False, 127 | help="size of InputAttention linear layer" 128 | ) 129 | parser.add_argument( 130 | '--attn_dropout', type=float, required=False, 131 | help="probability of InputAttention dropout" 132 | ) 133 | parser.add_argument( 134 | '--attn_sparse', dest="attn_sparse", action="store_true", 135 | help="whether to use sparse embeddings in InputAttention or not" 136 | ) 137 | parser.add_argument( 138 | '--use_ch', dest="use_ch", action="store_true", 139 | help="whether to use CH conditioning or not" 140 | ) 141 | parser.add_argument( 142 | '--ch_emb_size', type=int, required=False, 143 | help="size of embeddings in CH conditioning" 144 | ) 145 | parser.add_argument( 146 | '--ch_feature_maps', type=int, required=False, nargs="+", 147 | help="list of feature map sizes in CH conditioning" 148 | ) 149 | parser.add_argument( 150 | '--ch_kernel_sizes', type=int, required=False, nargs="+", 151 | help="list of kernel sizes in CH conditioning" 152 | ) 153 | parser.add_argument( 154 | '--use_hidden', dest="use_hidden", action="store_true", 155 | help="whether to use Hidden conditioning or not" 156 | ) 157 | parser.add_argument( 158 | '--use_hidden_adaptive', dest="use_hidden_adaptive", 159 | action="store_true", 160 | help="whether to use HiddenAdaptive conditioning or not" 161 | ) 162 | parser.add_argument( 163 | '--use_hidden_attention', dest="use_hidden_attention", 164 | action="store_true", 165 | help="whether to use HiddenAttention conditioning or not" 166 | ) 167 | parser.add_argument( 168 | '--use_gated', dest="use_gated", action="store_true", 169 | help="whether to use Gated conditioning or not" 170 | ) 171 | parser.add_argument( 172 | '--use_gated_adaptive', dest="use_gated_adaptive", action="store_true", 173 | help="whether to use GatedAdaptive conditioning or not" 174 | ) 175 | parser.add_argument( 176 | '--use_gated_attention', dest="use_gated_attention", action="store_true", 177 | help="whether to use GatedAttention conditioning or not" 178 | ) 179 | # Training arguments 180 | parser.add_argument( 181 | '--lr', type=float, required=True, 182 | help="initial lr" 183 | ) 184 | parser.add_argument( 185 | "--decay_factor", type=float, required=True, 186 | help="factor to decay lr" 187 | ) 188 | parser.add_argument( 189 | '--decay_patience', type=int, required=True, 190 | help="after number of patience epochs - decay lr" 191 | ) 192 | parser.add_argument( 193 | '--num_epochs', type=int, required=True, 194 | help="number of epochs to train" 195 | ) 196 | parser.add_argument( 197 | '--batch_size', type=int, required=True, 198 | help="batch size" 199 | ) 200 | parser.add_argument( 201 | "--clip", type=float, required=True, 202 | help="value to clip norm of gradients to" 203 | ) 204 | parser.add_argument( 205 | "--random_seed", type=int, required=True, 206 | help="random seed" 207 | ) 208 | # Utility arguments 209 | parser.add_argument( 210 | "--exp_dir", type=str, required=True, 211 | help="where to save all stuff about training" 212 | ) 213 | parser.add_argument( 214 | "--w2v_weights", type=str, required=False, 215 | help="path to pretrained embeddings to init" 216 | ) 217 | 
parser.add_argument( 218 | "--fix_embeddings", dest="fix_embeddings", action="store_true", 219 | help="whether to update embedding matrix or not" 220 | ) 221 | parser.add_argument( 222 | "--fix_attn_embeddings", dest="fix_attn_embeddings", action="store_true", 223 | help="whether to update attention embedding matrix or not" 224 | ) 225 | parser.add_argument( 226 | "--lm_ckpt", type=str, required=False, 227 | help="path to pretrained language model weights" 228 | ) 229 | parser.add_argument( 230 | "--attn_ckpt", type=str, required=False, 231 | help="path to pretrained Attention module" 232 | ) 233 | # read args 234 | args = vars(parser.parse_args()) 235 | 236 | logfile = open(args["exp_dir"] + "training_log", "a") 237 | #import sys 238 | #logfile = sys.stdout 239 | 240 | if args["pretrain"]: 241 | assert args["train_lm"] is not None, "--train_lm is required if --pretrain" 242 | assert args["eval_lm"] is not None, "--eval_lm is required if --pretrain" 243 | assert args["test_lm"] is not None, "--test_lm is required if --pretrain" 244 | assert args["bptt"] is not None, "--bptt is required if --pretrain" 245 | 246 | train_dataset = LanguageModelingDataset( 247 | file=args["train_lm"], 248 | vocab_path=args["voc"], 249 | bptt=args["bptt"], 250 | ) 251 | train_dataloader = DataLoader( 252 | train_dataset, batch_size=args["batch_size"], 253 | collate_fn=LanguageModelingCollate 254 | ) 255 | eval_dataset = LanguageModelingDataset( 256 | file=args["eval_lm"], 257 | vocab_path=args["voc"], 258 | bptt=args["bptt"], 259 | ) 260 | eval_dataloader = DataLoader( 261 | eval_dataset, batch_size=args["batch_size"], 262 | collate_fn=LanguageModelingCollate 263 | ) 264 | else: 265 | assert args["train_defs"] is not None, ("--pretrain is False," 266 | " --train_defs is required") 267 | assert args["eval_defs"] is not None, ("--pretrain is False," 268 | " --eval_defs is required") 269 | assert args["test_defs"] is not None, ("--pretrain is False," 270 | " --test_defs is required") 271 | 272 | train_dataset = DefinitionModelingDataset( 273 | file=args["train_defs"], 274 | vocab_path=args["voc"], 275 | input_vectors_path=args["input_train"], 276 | input_adaptive_vectors_path=args["input_adaptive_train"], 277 | context_vocab_path=args["context_voc"], 278 | ch_vocab_path=args["ch_voc"], 279 | use_seed=args["use_seed"] 280 | ) 281 | train_dataloader = DataLoader( 282 | train_dataset, 283 | batch_size=args["batch_size"], 284 | collate_fn=DefinitionModelingCollate 285 | ) 286 | eval_dataset = DefinitionModelingDataset( 287 | file=args["eval_defs"], 288 | vocab_path=args["voc"], 289 | input_vectors_path=args["input_eval"], 290 | input_adaptive_vectors_path=args["input_adaptive_eval"], 291 | context_vocab_path=args["context_voc"], 292 | ch_vocab_path=args["ch_voc"], 293 | use_seed=args["use_seed"] 294 | ) 295 | eval_dataloader = DataLoader( 296 | eval_dataset, 297 | batch_size=args["batch_size"], 298 | collate_fn=DefinitionModelingCollate 299 | ) 300 | 301 | if args["use_input"] or args["use_hidden"] or args["use_gated"]: 302 | assert args["input_train"] is not None, ("--use_input or " 303 | "--use_hidden or " 304 | "--use_gated is used " 305 | "--input_train is required") 306 | assert args["input_eval"] is not None, ("--use_input or " 307 | "--use_hidden or " 308 | "--use_gated is used " 309 | "--input_eval is required") 310 | assert args["input_test"] is not None, ("--use_input or " 311 | "--use_hidden or " 312 | "--use_gated is used " 313 | "--input_test is required") 314 | args["input_dim"] = 
train_dataset.input_vectors.shape[1] 315 | 316 | if args["use_input_adaptive"] or args["use_hidden_adaptive"] or args["use_gated_adaptive"]: 317 | assert args["input_adaptive_train"] is not None, ("--use_input_adaptive or " 318 | "--use_hidden_adaptive or " 319 | "--use_gated_adaptive is used " 320 | "--input_adaptive_train is required") 321 | assert args["input_adaptive_eval"] is not None, ("--use_input_adaptive or " 322 | "--use_hidden_adaptive or " 323 | "--use_gated_adaptive is used " 324 | "--input_adaptive_eval is required") 325 | assert args["input_adaptive_test"] is not None, ("--use_input_adaptive or " 326 | "--use_hidden_adaptive or " 327 | "--use_gated_adaptive is used " 328 | "--input_adaptive_test is required") 329 | args["input_adaptive_dim"] = train_dataset.input_adaptive_vectors.shape[1] 330 | 331 | if args["use_input_attention"] or args["use_hidden_attention"] or args["use_gated_attention"]: 332 | assert args["context_voc"] is not None, ("--use_input_attention or " 333 | "--use_hidden_attention or " 334 | "--use_gated_attention is used " 335 | "--context_voc is required") 336 | assert args["n_attn_embsize"] is not None, ("--use_input_attention or " 337 | "--use_hidden_attention or " 338 | "--use_gated_attention is used " 339 | "--n_attn_embsize is required") 340 | assert args["n_attn_hid"] is not None, ("--use_input_attention or " 341 | "--use_hidden_attention or " 342 | "--use_gated_attention is used " 343 | "--n_attn_hid is required") 344 | assert args["attn_dropout"] is not None, ("--use_input_attention or " 345 | "--use_hidden_attention or " 346 | "--use_gated_attention is used " 347 | "--attn_dropout is required") 348 | 349 | args["n_attn_tokens"] = len(train_dataset.context_voc.tok2id) 350 | 351 | if args["use_ch"]: 352 | assert args["ch_voc"] is not None, ("--ch_voc is required " 353 | "if --use_ch") 354 | assert args["ch_emb_size"] is not None, ("--ch_emb_size is required " 355 | "if --use_ch") 356 | assert args["ch_feature_maps"] is not None, ("--ch_feature_maps is " 357 | "required if --use_ch") 358 | assert args["ch_kernel_sizes"] is not None, ("--ch_kernel_sizes is " 359 | "required if --use_ch") 360 | 361 | args["n_ch_tokens"] = len(train_dataset.ch_voc.tok2id) 362 | args["ch_maxlen"] = train_dataset.ch_voc.tok_maxlen + 2 363 | 364 | 365 | args["ntokens"] = len(train_dataset.voc.tok2id) 366 | 367 | np.random.seed(args["random_seed"]) 368 | torch.manual_seed(args["random_seed"]) 369 | if torch.cuda.is_available(): 370 | torch.cuda.manual_seed(args["random_seed"]) 371 | 372 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 373 | model = DefinitionModelingModel(args).to(device) 374 | scheduler = optim.lr_scheduler.ReduceLROnPlateau( 375 | optim.Adam( 376 | filter(lambda p: p.requires_grad, model.parameters()), lr=args["lr"] 377 | ), 378 | factor=args["decay_factor"], 379 | patience=args["decay_patience"] 380 | ) 381 | 382 | best_ppl = float("inf") 383 | for epoch in tqdm(range(args["num_epochs"]), file=logfile): 384 | train_epoch( 385 | train_dataloader, 386 | model, 387 | scheduler.optimizer, 388 | device, 389 | args["clip"], 390 | logfile 391 | ) 392 | eval_ppl = test(eval_dataloader, model, device, logfile) 393 | if eval_ppl < best_ppl: 394 | best_ppl = eval_ppl 395 | torch.save( 396 | {"state_dict": model.state_dict()}, 397 | args["exp_dir"] + "weights.pth" 398 | ) 399 | scheduler.step(metrics=eval_ppl) 400 | 401 | with open(args["exp_dir"] + "params.json", "w") as outfile: 402 | json.dump(args, outfile, indent=4) 403 | 
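
train.py ties everything together: it builds `LanguageModeling*` or `DefinitionModeling*` dataloaders depending on `--pretrain`, checks that every enabled conditioning comes with the data and hyperparameters it needs, trains `DefinitionModelingModel` with Adam plus `ReduceLROnPlateau`, and writes the best checkpoint (lowest eval perplexity) and `params.json` to `--exp_dir`. A hypothetical LM-pretraining invocation is sketched below; all paths and hyperparameter values are placeholders, not the settings used in the paper (note the trailing slash on `--exp_dir`, since the script concatenates file names onto it):

```
python train.py \
  --pretrain \
  --voc data/vocab.json \
  --train_lm data/wikitext-103/wiki.train.tokens \
  --eval_lm data/wikitext-103/wiki.valid.tokens \
  --test_lm data/wikitext-103/wiki.test.tokens \
  --bptt 35 \
  --nx 300 --nlayers 2 --nhid 300 --rnn_dropout 0.5 \
  --lr 0.001 --decay_factor 0.1 --decay_patience 1 \
  --num_epochs 10 --batch_size 64 --clip 5 \
  --random_seed 42 \
  --exp_dir experiments/lm_pretrain/
```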
-------------------------------------------------------------------------------- /train_attention_skipgram.py: -------------------------------------------------------------------------------- 1 | from source.attention_skipgram import AttentionSkipGram 2 | from source.utils import MultipleOptimizer 3 | import argparse 4 | import numpy as np 5 | import os.path 6 | from tqdm import tqdm 7 | from collections import Counter 8 | import json 9 | from source.datasets import Vocabulary 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | from itertools import chain 15 | 16 | parser = argparse.ArgumentParser( 17 | description='Script to train a AttentionSkipGram model' 18 | ) 19 | parser.add_argument( 20 | '--data', type=str, required=False, 21 | help="path to data" 22 | ) 23 | parser.add_argument( 24 | '--context_voc', type=str, required=True, 25 | help=("path to context voc for DefinitionModelingModel is necessary to " 26 | "save pretrained attention module, particulary embedding matrix") 27 | ) 28 | parser.add_argument( 29 | '--prepared', dest='prepared', action="store_true", 30 | help='whether to prepare data or use already prepared' 31 | ) 32 | parser.add_argument( 33 | "--window", type=int, required=True, 34 | help="window for AttentionSkipGram model" 35 | ) 36 | parser.add_argument( 37 | "--random_seed", type=int, required=True, 38 | help="random seed for training" 39 | ) 40 | parser.add_argument( 41 | "--sparse", dest="sparse", action="store_true", 42 | help="whether to use sparse embeddings or not" 43 | ) 44 | parser.add_argument( 45 | "--vec_dim", type=int, required=True, 46 | help="vector dim to train" 47 | ) 48 | parser.add_argument( 49 | "--attn_hid", type=int, required=True, 50 | help="hidden size in attention module" 51 | ) 52 | parser.add_argument( 53 | "--attn_dropout", type=float, required=True, 54 | help="dropout prob in attention module" 55 | ) 56 | parser.add_argument( 57 | "--lr", type=float, required=True, 58 | help="initial lr to use" 59 | ) 60 | parser.add_argument( 61 | "--batch_size", type=int, required=True, 62 | help="batch size to use" 63 | ) 64 | parser.add_argument( 65 | "--num_epochs", type=int, required=True, 66 | help="number of epochs to train" 67 | ) 68 | parser.add_argument( 69 | "--exp_dir", type=str, required=True, 70 | help="where to save weights, prepared data and logs" 71 | ) 72 | args = vars(parser.parse_args()) 73 | logfile = open(args["exp_dir"] + "training_log", "a") 74 | 75 | context_voc = Vocabulary() 76 | context_voc.load(args["context_voc"]) 77 | 78 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 79 | 80 | np.random.seed(args["random_seed"]) 81 | torch.manual_seed(args["random_seed"]) 82 | if torch.cuda.is_available(): 83 | torch.cuda.manual_seed(args["random_seed"]) 84 | 85 | if args["prepared"]: 86 | assert os.path.isfile(args["exp_dir"] + "data.npz"), ("prepared data " 87 | "does not exist") 88 | assert os.path.isfile(args["exp_dir"] + "voc.json"), ("prepared voc " 89 | "does not exist") 90 | 91 | tqdm.write("Loading data!", file=logfile) 92 | logfile.flush() 93 | 94 | data = np.load(args["exp_dir"] + "data.npz") 95 | words_idx = data['words_idx'] 96 | cnt_idx = data['cnt_idx'] 97 | freqs = data['freqs'] 98 | 99 | with open(args["exp_dir"] + "voc.json", 'r') as f: 100 | voc = json.load(f) 101 | 102 | word2id = voc['word2id'] 103 | id2word = voc['id2word'] 104 | 105 | else: 106 | assert args["data"] is not None, "--prepared False, provide --data" 107 | 108 
| tqdm.write("Preparing data!", file=logfile) 109 | logfile.flush() 110 | 111 | with open(args["data"], 'r') as f: 112 | data = f.read() 113 | 114 | data = data.lower().split() 115 | counter = Counter(data) 116 | word2id = {} 117 | id2word = {} 118 | i = 0 119 | words = [] 120 | counts = [] 121 | for w, c in counter.most_common(): 122 | words.append(w) 123 | counts.append(c) 124 | word2id[words[-1]] = i 125 | id2word[i] = w 126 | i += 1 127 | 128 | freqs = np.array(counts) 129 | freqs = freqs / freqs.sum() 130 | freqs = np.sqrt(freqs) 131 | freqs = freqs / freqs.sum() 132 | data = list(map(lambda w: word2id[w], data)) 133 | 134 | words_idx = np.zeros(len(data) - 2 * args["window"], dtype=np.int) 135 | cnt_idx = np.zeros( 136 | (len(data) - 2 * args["window"], 2 * args["window"]), dtype=np.int 137 | ) 138 | 139 | for i in tqdm(range(args["window"], len(data) - args["window"]), file=logfile): 140 | words_idx[i - args["window"]] = data[i] 141 | cnt_idx[i - args["window"]] = np.array( 142 | data[i - args["window"]:i] + data[i + 1:i + args["window"] + 1] 143 | ) 144 | 145 | np.savez( 146 | args["exp_dir"] + "data", 147 | words_idx=words_idx, 148 | cnt_idx=cnt_idx, 149 | freqs=freqs 150 | ) 151 | with open(args["exp_dir"] + "voc.json", 'w') as f: 152 | json.dump({'word2id': word2id, 'id2word': id2word}, f) 153 | 154 | tqdm.write("Data prepared and saved!", file=logfile) 155 | logfile.flush() 156 | 157 | 158 | def generate_neg(batch_size, negative=10): 159 | return np.random.choice(freqs.size, size=(batch_size, negative), p=freqs) 160 | 161 | 162 | def generate_batch(batch_size=128): 163 | shuffle = np.random.permutation(words_idx.shape[0]) 164 | words_idx_shuffled = words_idx[shuffle] 165 | cnt_idx_shuffled = cnt_idx[shuffle] 166 | for i in tqdm(range(0, words_idx.shape[0], batch_size), file=logfile): 167 | start = i 168 | end = min(i + batch_size, words_idx.shape[0]) 169 | words = words_idx_shuffled[start:end] 170 | context = cnt_idx_shuffled[start:end] 171 | neg = generate_neg(end - start) 172 | 173 | context = torch.from_numpy(context).to(device) 174 | words = torch.from_numpy(words).to(device) 175 | neg = torch.from_numpy(neg).to(device) 176 | 177 | yield words, context, neg 178 | 179 | del words_idx_shuffled 180 | del cnt_idx_shuffled 181 | 182 | tqdm.write("Initialising model!", file=logfile) 183 | logfile.flush() 184 | 185 | model = AttentionSkipGram( 186 | n_attn_tokens=len(word2id), 187 | n_attn_embsize=args["vec_dim"], 188 | n_attn_hid=args["attn_hid"], 189 | attn_dropout=args["attn_dropout"], 190 | sparse=args["sparse"] 191 | ).to(device) 192 | 193 | 194 | if args["sparse"]: 195 | optimizer = MultipleOptimizer( 196 | optim.SparseAdam(chain( 197 | model.emb0_lookup.embs.parameters(), 198 | model.emb1_lookup.parameters() 199 | ), lr=args["lr"]), 200 | optim.Adam(chain( 201 | model.emb0_lookup.ann.parameters(), 202 | model.emb0_lookup.a_linear.parameters() 203 | ), lr=args["lr"]) 204 | ) 205 | else: 206 | optimizer = optim.Adam(model.parameters(), lr=args["lr"]) 207 | 208 | tqdm.write("Start training!", file=logfile) 209 | logfile.flush() 210 | 211 | model.train() 212 | 213 | for _ in range(args["num_epochs"]): 214 | for w, c, n in generate_batch(batch_size=args["batch_size"]): 215 | optimizer.zero_grad() 216 | loss = model(w, c, n) 217 | loss.backward() 218 | optimizer.step() 219 | 220 | 221 | tqdm.write("Training ended! 
Saving model!", file=logfile) 222 | logfile.flush() 223 | 224 | state_dict = model.emb0_lookup.state_dict() 225 | initrange = 0.5 / args["vec_dim"] 226 | embs_weights = np.random.uniform( 227 | low=-initrange, 228 | high=initrange, 229 | size=(len(context_voc.tok2id), args["vec_dim"]), 230 | ).astype(np.float32) 231 | for word in context_voc.tok2id.keys(): 232 | if word in word2id: 233 | new_id = context_voc.tok2id[word] 234 | old_id = word2id[word] 235 | embs_weights[new_id] = state_dict["embs.weight"][old_id].cpu().numpy() 236 | 237 | state_dict["embs.weight"] = torch.from_numpy(embs_weights).to(device) 238 | 239 | torch.save( 240 | {"state_dict": state_dict}, 241 | args["exp_dir"] + "weights.pth" 242 | ) 243 | 244 | with open(args["exp_dir"] + "params.json", "w") as outfile: 245 | json.dump(args, outfile, indent=4) 246 | --------------------------------------------------------------------------------