├── LICENSE.txt
├── README.md
├── launch-experiment.sbatch
├── logs
│   └── .DS_Store
└── python
    ├── __init__.py
    ├── autotags.py
    ├── models
    │   ├── __init__.py
    │   ├── bilstm.py
    │   ├── cbow.py
    │   └── esim.py
    ├── predictions.py
    ├── requirements.txt
    ├── train_genre.py
    ├── train_mnli.py
    ├── train_snli.py
    └── util
        ├── __init__.py
        ├── blocks.py
        ├── data_processing.py
        ├── evaluate.py
        ├── logger.py
        └── parameters.py

--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
Copyright 2018, New York University

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Baseline Models for MultiNLI Corpus

This is the code we used to establish baselines for the MultiNLI corpus introduced in [A Broad-Coverage Challenge Corpus for Sentence Understanding through Inference](https://arxiv.org/pdf/1704.05426.pdf).

## Data
The MultiNLI and SNLI corpora are both distributed as JSON Lines and tab-separated value files. Both can be downloaded [here](https://www.nyu.edu/projects/bowman/multinli/).

## Models
We present three baseline neural network models, ranging from a bare-bones model (CBOW) to an elaborate model that has achieved state-of-the-art performance on the SNLI corpus (ESIM).

- Continuous Bag of Words (CBOW): in this model, each sentence is represented as the sum of the embedding representations of its words. This representation is passed to a deep, 3-layer MLP. The main code for this model is in [`cbow.py`](https://github.com/NYU-MLL/multiNLI/blob/master/python/models/cbow.py).
- Bi-directional LSTM: in this model, the average of the states of a bidirectional LSTM RNN is used as the sentence representation. The main code for this model is in [`bilstm.py`](https://github.com/NYU-MLL/multiNLI/blob/master/python/models/bilstm.py).
- Enhanced Sequential Inference Model (ESIM): this is our implementation of [Chen et al.'s (2017)](https://arxiv.org/pdf/1609.06038v2.pdf) ESIM, without ensembling with a TreeLSTM. The main code for this model is in [`esim.py`](https://github.com/NYU-MLL/multiNLI/blob/master/python/models/esim.py).

We use dropout for regularization in all three models.

## Training and Testing

### Training settings

The models can be trained in three different settings. Each setting has its own training script.
- To train a model only on SNLI data,
    - Use [`train_snli.py`](https://github.com/NYU-MLL/multiNLI/blob/master/python/train_snli.py).
    - Accuracy on SNLI's dev-set is used to do early stopping.

- To train a model on only MultiNLI or on a mixture of MultiNLI and SNLI data,
    - Use [`train_mnli.py`](https://github.com/NYU-MLL/multiNLI/blob/master/python/train_mnli.py).
    - The optional `alpha` flag determines what percentage of SNLI data is used in training. The default value for `alpha` is 0.0, which means the model will be trained on MultiNLI data only.
    - If `alpha` is set to a value greater than 0 (and less than 1), an `alpha` percentage of SNLI training data is randomly sampled at the beginning of each epoch (see the sketch after this list).
    - When using SNLI training data in this setting, we set `alpha` = 0.15.
    - Accuracy on MultiNLI's matched dev-set is used to do early stopping.

- To train a model on a single MultiNLI genre,
    - Use [`train_genre.py`](https://github.com/NYU-MLL/multiNLI/blob/master/python/train_genre.py).
    - To use this training setting, you must call the `genre` flag and set it to a valid training genre (`travel`, `fiction`, `slate`, `telephone`, `government`, or `snli`).
    - Accuracy on the dev-set for the chosen genre is used to do early stopping.
    - Additionally, logs created with this training setting contain evaluation statistics by genre.
    - You can also train a model on SNLI with this script if you desire genre-specific statistics in your logs.
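The per-epoch SNLI sub-sampling amounts to something like the following minimal sketch (the real logic lives in `train_mnli.py`; the function and variable names here are illustrative):

```python
import random

def epoch_training_set(training_mnli, training_snli, alpha):
    """Build one epoch's training set: all of MultiNLI plus a fresh
    random sample of alpha * |SNLI| examples from SNLI."""
    beta = int(len(training_snli) * alpha)  # e.g., alpha = 0.15
    return training_mnli + random.sample(training_snli, beta)
```

Because the sample is redrawn every epoch, the model eventually sees most of SNLI while each individual epoch stays dominated by MultiNLI.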
### Command line flags

To start training with any of the training scripts, there are a couple of required command-line flags and an array of optional flags. The code concerning all flags can be found in [`parameters.py`](https://github.com/NYU-MLL/multiNLI/blob/master/python/util/parameters.py). All the parameters set in `parameters.py` are printed to the log file every time the training script is launched.

Required flags,

- `model_type`: there are three model types in this repository, `cbow`, `bilstm`, and `esim`. You must state which model you want to use.
- `model_name`: this is your experiment name. This name will be used to prefix the log and checkpoint files.

Optional flags,

- `datapath`: path to your directory with MultiNLI and SNLI data. Default is set to "../data".
- `ckptpath`: path to the directory where you wish to store checkpoint files. Default is set to "../logs".
- `logpath`: path to the directory where you wish to store log files. Default is set to "../logs".
- `emb_to_load`: path to your directory with GloVe data. Default is set to "../data".
- `learning_rate`: the learning rate you wish to use during training. Default value is set to 0.0004.
- `keep_rate`: the hyper-parameter for the dropout rate, where `keep_rate` = 1 - dropout-rate. The default value is set to 0.5.
- `seq_length`: the maximum sequence length you wish to use. Default value is set to 50. Sentences shorter than `seq_length` are padded on the right; sentences longer than `seq_length` are truncated.
- `emb_train`: boolean flag that determines if the model updates word embeddings during training. If called, the word embeddings are updated.
- `alpha`: only used in the `train_mnli` scheme. Determines what percentage of SNLI training data to use in each epoch of training. Default value is set to 0.0 (which makes the model train on MultiNLI only).
- `genre`: only used in the `train_genre` scheme. Use this flag to set which single genre you wish to train on. Valid genres are `travel`, `fiction`, `slate`, `telephone`, `government`, or `snli`.
- `test`: boolean used to test a trained model. Call this flag if you wish to load a trained model and test it on the MultiNLI dev-sets* and SNLI test-set. When called, the best checkpoint will be used (see the section on checkpoints for more details).

*Dev-sets are currently used for testing on MultiNLI since the test-sets have not been released.

### Other parameters

Remaining parameters, like the size of the hidden layers, word embeddings, and minibatches, can be changed directly in `parameters.py`. The default hidden embedding and word embedding size is set to 300, and the minibatch size (`batch_size` in the code) is set to 32.

### Sample commands
To execute all of the following sample commands, you must be in the "python" folder,

- To train on SNLI data only, here is a sample command,

	`PYTHONPATH=$PYTHONPATH:. python train_snli.py cbow petModel-0 --keep_rate 0.9 --seq_length 25 --emb_train`

	where the `model_type` flag is set to `cbow` and can be swapped for `bilstm` or `esim`, and the `model_name` flag is set to `petModel-0` and can be changed to whatever you please.

- Similarly, to train on a mixture of MultiNLI and SNLI data, here is a sample command,

	`PYTHONPATH=$PYTHONPATH:. python train_mnli.py bilstm petModel-1 --keep_rate 0.9 --alpha 0.15 --emb_train`

	where 15% of SNLI training data is randomly sampled at the beginning of each epoch.

- To train on just the `travel` genre in MultiNLI data,

	`PYTHONPATH=$PYTHONPATH:. python train_genre.py esim petModel-2 --genre travel --emb_train`

### Testing models

#### On dev set,
To test a trained model, simply add the `test` flag to the command used for training. The best checkpoint will be loaded and used to evaluate the model's performance on the MultiNLI dev-sets, the SNLI test-set, and the dev-set for each genre in MultiNLI.

For example,

`PYTHONPATH=$PYTHONPATH:. python train_genre.py esim petModel-2 --genre travel --emb_train --test`

With the `test` flag, the `train_mnli.py` script will also generate a CSV of predictions for the unlabeled matched and mismatched test-sets.

#### Results for unlabeled test sets,
To get a CSV of predicted results for the unlabeled test sets, use `predictions.py`. This script requires the same flags as the training scripts: you must enter `model_type` and `model_name`, and the paths to the saved checkpoint and log files if they differ from the defaults (the default is `../logs` for both paths).

Here is a sample command,

`PYTHONPATH=$PYTHONPATH:. python predictions.py esim petModel-1 --alpha 0.15 --emb_train --logpath ../logs_keep --ckptpath ../logs_keep`

This script will create a CSV with two columns: pairID and gold_label.

### Checkpoints

We maintain two checkpoints: the most recent checkpoint and the best checkpoint. Every 500 steps, the most recent checkpoint is updated, and we test whether the dev-set accuracy has improved by at least 0.04%. If it has, the best checkpoint is updated (see the sketch below).
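The improvement test is relative rather than absolute; a minimal sketch of the check, mirroring the logic in the training scripts (the names here are illustrative):

```python
def improved_enough(best_dev_acc, dev_acc, threshold=0.04):
    """Relative improvement (in percent) of the current dev accuracy
    over the best one recorded so far."""
    improvement = 100 * (1 - best_dev_acc / dev_acc)
    return improvement > threshold
```

For example, going from 70.00% to 70.05% dev accuracy gives a relative improvement of about 0.07%, which clears the 0.04% bar and triggers a new best checkpoint.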
### Annotation Tags

The script used to compute the percentage of annotation tags, `autotags.py`, is available in this repository within the "python" subfolder. It takes a parsed corpus file (e.g., a dev-set file) and reports the percentage of each annotation tag in that file. You should also update the paths in the script to reflect your local file organization.

## License

Copyright 2018, New York University

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

--------------------------------------------------------------------------------
/launch-experiment.sbatch:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --partition=gpu
#SBATCH --gres=gpu:1
#SBATCH --mem=16000
#SBATCH -t24:00:00

# Make sure we have access to HPC-managed libraries.
module load tensorflow/python2.7/20170218

# Run.
PYTHONPATH=$PYTHONPATH:. python training_script model_type experiment_name --keep_rate 0.5 --learning_rate 0.0004 --alpha 0.13 --emb_train

# Available training_scripts: train_snli.py, train_mnli.py, train_genre.py
# Available model_types: cbow, bilstm, esim
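Assuming a standard Slurm setup, you would replace the `training_script`, `model_type`, and `experiment_name` placeholders in the run line above and submit the job with `sbatch launch-experiment.sbatch`; the `#SBATCH` directives then request one GPU and 16 GB of memory for up to 24 hours.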
--------------------------------------------------------------------------------
/logs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nyu-mll/multiNLI/c4078cad9f9b5d06a672f2edb69bde007a3567a9/logs/.DS_Store

--------------------------------------------------------------------------------
/python/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nyu-mll/multiNLI/c4078cad9f9b5d06a672f2edb69bde007a3567a9/python/__init__.py

--------------------------------------------------------------------------------
/python/autotags.py:
--------------------------------------------------------------------------------

######################################
#  MultiNLI Annotation Tags Script   #
#  By Adina Williams                 #
######################################

# How to use this script:
# You should have a version of MultiNLI downloaded.
# You should update the paths below to point to your local files.
# This can take the MultiNLI dev and test sets, matched and mismatched, as well as SNLI files.

import csv
import operator
from collections import Counter, defaultdict

tags_to_results = defaultdict(list)

def log(tag, is_correct, label):
    tags_to_results[tag].append((is_correct, label))

def find_1st_verb(str1):
    # Find the PTB coding of the first verb from the root of the sentence.
    findy = str1.find('(VB')
    if findy > 0:
        return str1[findy:].split()[0]
    else:
        return ''

def tense_match(str1, str2):
    # Test for tense match by finding the first verb tag of one parse and
    # checking for an occurrence of the same tag in the other parse.
    result = find_1st_verb(str1)
    if len(result) > 0:
        findy2 = str2.find(result)
        return findy2 > 0
    else:
        return False

##################
#    declare     #
#   your paths   #
##################

dev_m_path = '/Users/Adina/Documents/MultiGenreNLI/multinli/multinli_1.0_dev_matched.txt'
dev_mm_path = '/Users/Adina/Documents/MultiGenreNLI/multinli/multinli_1.0_dev_mismatched.txt'
test_m_path = '/Users/Adina/Documents/MultiGenreNLI/multinli/multinli_1.0_test_matched.txt'
test_mm_path = '/Users/Adina/Documents/MultiGenreNLI/multinli/multinli_1.0_test_mismatched.txt'
dev_path = '/Users/Adina/Documents/MultiGenreNLI/multinli/multinli_1.0_dev_all.txt'  # concatenated matched and mismatched
test_path = '/Users/Adina/Documents/MultiGenreNLI/mnli_0.9/multinli_0.9_test_all.txt'  # concatenated matched and mismatched
snli_test_path = '/Users/Adina/Documents/MultiGenreNLI/snli_1.0/snli_1.0_test.txt'
snli_dev_path = '/Users/Adina/Documents/MultiGenreNLI/snli_1.0/snli_1.0_dev.txt'
#train_path = '/Users/Adina/Documents/MultiGenreNLI/mnli_0.9/multinli_0.9_train.txt'

ptbtags = {"(MD": "modal", "(W": "WH", "(CD": "card", "(PRP": "pron", "(EX": "exist", "(IN": "prep", "(POS": "'s"}  # dict of interesting PTB tags to check, pulled directly from the PTB tagger
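# A quick illustration of the two helpers above (editor's example; the parse
# string is made up):
#   p1 = "(ROOT (S (NP (DT The) (NN dog)) (VP (VBD barked))))"
#   find_1st_verb(p1)  ->  "(VBD"
#   tense_match(p1, p2) is then True whenever p2 contains "(VBD" anywhere
#   after its first character.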
# For the dataset you want to see annotations for, specify its path above.
# This gets your data in a reasonable order for figuring out the annotations.
# This part also selects the PTB parses from the corpus. It works with the
# .txt distribution but not the jsonl distribution.

with open(dev_path, 'rbU') as csvfile:
    reader = csv.reader(csvfile, delimiter="\t")
    i = 0
    for row in reader:
        label = row[0]
        if label in ["entailment", "contradiction", "neutral"]:
            pairid = row[8]
            p1 = row[3]
            p2 = row[4]
            b1 = row[1]
            b2 = row[2]
            t1 = row[5]
            t2 = row[6]
            genre = row[9]
            b1t = b1.split()
            b2t = b2.split()
            sb1t = set(b1t)
            sb2t = set(b2t)
            parses = p1 + " " + p2
            correct = 'correct'  # this needs to be supplied from the model outputs

            log("label-" + label, correct, label)

            ####################
            # HAND CHOSEN TAGS #
            ####################
            # Linguistically   #
            # interesting      #
            # phenomena were   #
            # picked by Adina  #
            ####################

            # Just a note: results will be reported for both hypotheses and pairs (i.e., hypothesis or premise).

            ################
            #   NEG/DET    #
            ################
            # Note that these are substring tests on the parse strings.

            if "n't" in parses or "not" in parses or "none" in parses or "never" in parses or "neither" in parses or "nor" in parses:  # add in un- and non- :/
                log('neg-all', correct, label)
                if ("n't" in p2 or "not" in p2 or "none" in p2 or "never" in p2 or "neither" in p2 or "nor" in p2) and not ("n't" in p1 or "not" in p1 or "none" in p1 or "never" in p1 or "neither" in p1 or "nor" in p1):
                    log('neg-hyp-only', correct, label)

            if "a" in parses or "the" in parses or "these" in parses or "this" in parses or "those" in parses or "that" in parses:
                log('det-all', correct, label)
                if ("a" in p2 or "the" in p2 or "these" in p2 or "this" in p2 or "those" in p2 or "that" in p2) and not ("a" in p1 or "the" in p1 or "these" in p1 or "this" in p1 or "those" in p1 or "that" in p1):
                    log('det-hyp-only', correct, label)

            ##################
            #    PTB TAGS    #
            ##################
            for key in ptbtags:
                if key in parses:
                    log(ptbtags[key] + '_ptb_all', correct, label)
                    if (key in p2) and not (key in p1):
                        log(ptbtags[key] + '_ptb_hyp_only', correct, label)

            if ("(NNS" in p2) and ("(NNP" in p1):
                log('plural-premise-sing-hyp_ptb', correct, label)
            if ("(NNP" in p2) and ("(NNS" in p1):
                log('plural-hyp-sing-premise_ptb', correct, label)

            if tense_match(p1, p2):
                log('tense_match', correct, label)

            ###################
            #  interjects &   #  # we added some extra, potentially interesting things to check, but they didn't turn out to be interesting
            #  foreign words  #
            ###################

            if "(UH" in parses:
                log('interject-all_ptb', correct, label)
                if ("(UH" in p2) and not ("(UH" in p1):
                    log('interject-hyp-only_ptb', correct, label)

            if "(FW" in parses:
                log('foreign-all_ptb', correct, label)
                if ("(FW" in p2) and not ("(FW" in p1):
                    log('foreign-hyp-only_ptb', correct, label)

            ###################
            #  PTB modifiers  #
            ###################

            if "(JJ" in parses:
                log('adject-all_ptb', correct, label)
                if ("(JJ" in p2) and not ("(JJ" in p1):
                    log('adject-hyp-only_ptb', correct, label)

            if "(RB" in parses:
                log('adverb-all_ptb', correct, label)
                if ("(RB" in p2) and not ("(RB" in p1):
                    log('adverb-hyp-only_ptb', correct, label)

            if "(JJ" in parses or "(RB" in parses:
                log('adj/adv-all_ptb', correct, label)
                if ("(JJ" in p2 or "(RB" in p2) and not ("(JJ" in p1 or "(RB" in p1):
                    log('adj/adv-hyp-only_ptb', correct, label)
            # Modifiers are good examples of how additions/subtractions of single words result in neutral.

            # If the hypothesis (and premise) have -er/-est adjectives or adverbs in them,
            if "(RBR" in parses or "(RBS" in parses or "(JJR" in parses or "(JJS" in parses:
                log('er-est-all_ptb', correct, label)
                if ("(RBR" in p2 or "(RBS" in p2 or "(JJR" in p2 or "(JJS" in p2) and not ("(RBR" in p1 or "(RBS" in p1 or "(JJR" in p1 or "(JJS" in p1):
                    log('er-est-hyp-only_ptb', correct, label)

            #########################
            #  S-Root, length etc.  #
            #########################

            s1 = p1[0:8] == "(ROOT (S"
            s2 = p2[0:8] == "(ROOT (S"
            if s1 and s2:
                log('syn-S-S', correct, label)
            elif s1 or s2:
                log('syn-S-NP', correct, label)
            else:
                log('syn-NP-NP', correct, label)

            prem_len = len([word for word in b2.split() if word != '(' and word != ')'])
            if prem_len < 11:
                log('len-0-10', correct, label)
            elif prem_len < 15:
                log('len-11-14', correct, label)
            elif prem_len < 20:
                log('len-15-19', correct, label)
            else:
                log('len-20+', correct, label)

            if sb1t.issubset(sb2t):
                log('token-ins-only', correct, label)
            elif sb2t.issubset(sb1t):
                log('token-del-only', correct, label)

            if len(sb1t.difference(sb2t)) == 1 and len(sb2t.difference(sb1t)) == 1:
                log('token-single-sub-or-move', correct, label)

            if len(sb1t.union(sb2t)) > 0:
                overlap = float(len(sb1t.intersection(sb2t))) / len(sb1t.union(sb2t))
                if overlap > 0.6:
                    log('overlap-xhigh', correct, label)
                elif overlap > 0.37:
                    log('overlap-high', correct, label)
                elif overlap > 0.23:
                    log('overlap-mid', correct, label)
                elif overlap > 0.12:
                    log('overlap-low', correct, label)
                else:
                    log('overlap-xlow', correct, label)
            else:
                log('overlap-empty', correct, label)

            ##############
            #  GREPing   #
            ##############

            # for keyphrase in ["there are", "there is", "There are", "There is", "There's", "there's", "there were", "There were", "There was", "there was", "there will", "There will"]:
            #     if keyphrase in t2:
            #         log('template-thereis', correct, label)
            #         break

            # for keyphrase in ["can", "could", "may", "might", "must", "will", "would", "should"]:
            #     if keyphrase in p2 or keyphrase in p1:
            #         log('template-modals', correct, label)
            #         break

            for keyphrase in ["much", "enough", "more", "most", "every", "each", "less", "least", "no", "none", "some", "all", "any", "many", "few", "several"]:  # get a list from Anna's book, think more about it
                if keyphrase in p2 or keyphrase in p1:
                    log('template-quantifiers', correct, label)
                    break

            for keyphrase in ["know", "knew", "believe", "understood", "understand", "doubt", "notice", "contemplate", "consider", "wonder", "thought", "think", "suspect", "suppose", "recognize", "recognise", "forgot", "forget", "remember", "imagine", "meant", "agree", "mean", "disagree", "denied", "deny", "promise"]:
                if keyphrase in p2 or keyphrase in p1:
                    log('template-beliefVs', correct, label)
                    break

            # for keyphrase in ["love", "hate", "dislike", "annoy", "angry", "happy", "sad", "bliss", "blissful", "depress", "terrified", "terrify", "scare", "amuse", "suprise", "guilt", "fear", "afraid", "startle", "confuse", "baffle", "frustrate", "enfuriate", "rage", "befuddle", "fury", "furious", "elated", "elation", "joy", "joyous", "joyful", "enjoy", "relish"]:
            #     if keyphrase in p2 or keyphrase in p1:
            #         log('template-psychpreds', correct, label)
            #         break

            # for keyphrase in ['I', 'me', 'my', 'mine', 'we', 'our', 'ours', 'you', 'your', 'yours', "y'all", 'he', 'him', 'her', 'she', 'it', 'they', 'their', 'theirs', 'them']:
            #     if keyphrase in t2:
            #         log('template-pronouns', correct, label)
            #         break

            for keyphrase in ['if']:
                if keyphrase in p2 or keyphrase in p1:
                    log('template-if', correct, label)
                    break

            # for keyphrase in ["May I", "Mr.", "Mrs.", "Ms.", "Dr.", "excuse me", "Excuse me", "pardon me", "sorry", "Sorry", "I'm sorry", "I am sorry", "Pardon me", 'please', 'thank', 'thanks', 'Thanks', 'Thank', 'Please', "you're welcome", "You're welcome", "much obliged", "Much obliged"]:
            #     if keyphrase in p2 or keyphrase in p1:
            #         log('template-polite', correct, label)
            #         break

            for keyphrase in ["time", "January", "February", "March", "April", "May", "June", "July", "August", "September", "October", "November", "December", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday", "morning", "night", "tomorrow", "yesterday", "evening", "week", "weeks", "hours", "minutes", "seconds", "hour", "days", "years", "decades", "lifetime", "lifetimes", "epoch", "epochs", "day", "recent", "recently", "habitually", "whenever", "during", "while", "before", "after", "previously", "again", "often", "repeatedly", "frequently", "dusk", "dawn", "midnight", "afternoon", "when", "daybreak", "later", "earlier", "month", "year", "decade", "biweekly", "millenium", "midday", "daily", "weekly", "monthly", "yearly", "hourly", "fortnight", "now", "then"]:
                if keyphrase in p2 or keyphrase in p1:
                    log('template-timeterms', correct, label)
                    break

            for keyphrase in ["too", "anymore", "also", "as well", "again", "no longer", "start", "started", "starting", "stopping", "stop", "stopped", "regretting", "regret", "regretted", "realizing", "realize", "realized", "aware", "manage", "managed", "forgetting", "forget", "forgot", "began", "begin", "finish", "finished", "finishing", "ceasing", "cease", "ceased", "enter", "entered", "entering", "leaving", "leave", "left", "carry on", "carried on", "return", "returned", "returning", "restoring", "restore", "restored", "repeat", "repeated", "repeating", "another", "only", "coming back", "come back", "came back"]:
                if keyphrase in p2 or keyphrase in p1:
                    log('template-presupptrigs', correct, label)
                    break

            for keyphrase in ["although", "but", "yet", "despite", "however", "However", "Although", "But", "Yet", "Despite", "therefore", "Therefore", "Thus", "thus"]:
                if keyphrase in p2 or keyphrase in p1:
                    log('template-convo-pivot', correct, label)
                    break

            # for keyphrase in ["weight", "height", "age", "width", "length", "mother", "father", "sister", "brother", "aunt", "uncle", "cousin", "husband", "wife", "mom", "dad", "Mom", "Dad", "Mama", "Papa", "mama", "papa", "grandma", "grandpa", "nephew", "niece", "widow", "family", "kin", "bride", "spouse"]:
            #     if keyphrase in p2 or keyphrase in p1:
            #         log('template-relNs', correct, label)
            #         break

            # for keyphrase in ['who', 'what', 'why', 'when', 'how', "where", "which", "whose", "whether"]:
            #     if keyphrase in t2:
            #         log('template-WH', correct, label)
            #         break

            # for keyphrase in ['for', 'with', 'in', 'of', 'on', 'at', 'into', 'by', 'through', 'via', 'throughout', 'near', 'up', 'down', 'off', 'over', 'under', 'underneath', 'against', 'above', 'to', 'towards', 'toward', 'until', 'away', 'from', 'beneath', 'beside', 'within', 'without', 'upon', 'onto', 'aside', 'across', 'about', 'after', 'before', 'along', 'among', 'around', 'after', 'between', 'beyond', 'below']:
            #     if keyphrase in t2:
            #         log('template-prep', correct, label)
            #         break

            ###############################
            #  Too few to be interesting  #
            ###############################

            # if ("(RBR" in p2 or "(JJR" in p2) and ("(RBS" in p1 or "(JJS" in p1):
            #     log('er-hyp-est-premise_ptb', correct, label)
            #
            # if ("(RBS" in p2 or "(JJS" in p2) and ("(RBR" in p1 or "(JJR" in p1):
            #     log('est-hyp-er-premise_ptb', correct, label)
            #
            # if "(PDT" in parses:
            #     log("pre-det-all_ptb", correct, label)
            #     if ("(PDT" in p2) and not ("(PDT" in p1):
            #         log("pre-det-hyp-only_ptb", correct, label)
            # for keyphrase in ["at home", "at school", "home", "to school", "from school", "at church", "from church", "to church", "in jail", "in prison"]:
            #     if keyphrase in t2:
            #         log('template-PPincorp', correct, label)
            #         break
            # for keyphrase in ["more", "less", "than"]:
            #     if keyphrase in t2:
            #         log('template-more/less', correct, label)
            #         break
            # for keyphrase in ['fake', 'false', 'counterfeit', 'alleged', 'former', 'mock', "imitation"]:
            #     if keyphrase in t2:
            #         log('template-nonsubsectAdj', correct, label)
            #         break

            i += 1

# This will print your results to the terminal and create a summary .csv file in the current directory that saves them for you.

with open('snli_autoannotationTags_results.csv', 'w') as csvfile:  # you name your file here
    writer = csv.writer(csvfile, delimiter='\t')
    for tag in sorted(tags_to_results):
        correct = len([result[0] for result in tags_to_results[tag] if result[0]])
        counts = Counter([result[1] for result in tags_to_results[tag]])
        best_label, best_count = max(counts.iteritems(), key=operator.itemgetter(1))

        attempted = len(tags_to_results[tag])
        baseline = float(best_count) / attempted

        acc = float(correct) / attempted
        totalpercent = float(attempted) / i

        print tag, "\t", correct, "\t", attempted, "\t", acc, "\t", baseline, "\t", best_label, "\t", totalpercent
        writer.writerow([tag, correct, attempted, acc, baseline, best_label, totalpercent])
--------------------------------------------------------------------------------
/python/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nyu-mll/multiNLI/c4078cad9f9b5d06a672f2edb69bde007a3567a9/python/models/__init__.py

--------------------------------------------------------------------------------
/python/models/bilstm.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from util import blocks

class MyModel(object):
    def __init__(self, seq_length, emb_dim, hidden_dim, embeddings, emb_train):
        ## Define hyperparameters
        self.embedding_dim = emb_dim
        self.dim = hidden_dim
        self.sequence_length = seq_length

        ## Define the placeholders
        self.premise_x = tf.placeholder(tf.int32, [None, self.sequence_length])
        self.hypothesis_x = tf.placeholder(tf.int32, [None, self.sequence_length])
        self.y = tf.placeholder(tf.int32, [None])
        self.keep_rate_ph = tf.placeholder(tf.float32, [])

        ## Define parameters
        self.E = tf.Variable(embeddings, trainable=emb_train)

        self.W_mlp = tf.Variable(tf.random_normal([self.dim * 8, self.dim], stddev=0.1))
        self.b_mlp = tf.Variable(tf.random_normal([self.dim], stddev=0.1))

        self.W_cl = tf.Variable(tf.random_normal([self.dim, 3], stddev=0.1))
        self.b_cl = tf.Variable(tf.random_normal([3], stddev=0.1))

        ## Function for embedding lookup and dropout at embedding layer
        def emb_drop(x):
            emb = tf.nn.embedding_lookup(self.E, x)
            emb_drop = tf.nn.dropout(emb, self.keep_rate_ph)
            return emb_drop

        # Get lengths of unpadded sentences
        prem_seq_lengths, prem_mask = blocks.length(self.premise_x)
        hyp_seq_lengths, hyp_mask = blocks.length(self.hypothesis_x)

        ### BiLSTM layer ###
        premise_in = emb_drop(self.premise_x)
        hypothesis_in = emb_drop(self.hypothesis_x)

        premise_outs, c1 = blocks.biLSTM(premise_in, dim=self.dim, seq_len=prem_seq_lengths, name='premise')
        hypothesis_outs, c2 = blocks.biLSTM(hypothesis_in, dim=self.dim, seq_len=hyp_seq_lengths, name='hypothesis')

        premise_bi = tf.concat(premise_outs, axis=2)
        hypothesis_bi = tf.concat(hypothesis_outs, axis=2)

        #premise_final = blocks.last_output(premise_bi, prem_seq_lengths)
        #hypothesis_final = blocks.last_output(hypothesis_bi, hyp_seq_lengths)

        ### Mean pooling
        premise_sum = tf.reduce_sum(premise_bi, 1)
        premise_ave = tf.div(premise_sum, tf.expand_dims(tf.cast(prem_seq_lengths, tf.float32), -1))

        hypothesis_sum = tf.reduce_sum(hypothesis_bi, 1)
        hypothesis_ave = tf.div(hypothesis_sum, tf.expand_dims(tf.cast(hyp_seq_lengths, tf.float32), -1))

        ### Mou et al. concat layer ###
        diff = tf.subtract(premise_ave, hypothesis_ave)
        mul = tf.multiply(premise_ave, hypothesis_ave)
        h = tf.concat([premise_ave, hypothesis_ave, diff, mul], 1)

        # MLP layer
        h_mlp = tf.nn.relu(tf.matmul(h, self.W_mlp) + self.b_mlp)
        # Dropout applied to classifier
        h_drop = tf.nn.dropout(h_mlp, self.keep_rate_ph)

        # Get prediction
        self.logits = tf.matmul(h_drop, self.W_cl) + self.b_cl

        # Define the cost function
        self.total_cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.y, logits=self.logits))

--------------------------------------------------------------------------------
/python/models/cbow.py:
--------------------------------------------------------------------------------
import tensorflow as tf

class MyModel(object):
    def __init__(self, seq_length, emb_dim, hidden_dim, embeddings, emb_train):
        ## Define hyperparameters
        self.embedding_dim = emb_dim
        self.dim = hidden_dim
        self.sequence_length = seq_length

        ## Define placeholders
        self.premise_x = tf.placeholder(tf.int32, [None, self.sequence_length])
        self.hypothesis_x = tf.placeholder(tf.int32, [None, self.sequence_length])
        self.y = tf.placeholder(tf.int32, [None])
        self.keep_rate_ph = tf.placeholder(tf.float32, [])

        ## Define remaining parameters
        self.E = tf.Variable(embeddings, trainable=emb_train, name="emb")

        self.W_0 = tf.Variable(tf.random_normal([self.embedding_dim * 4, self.dim], stddev=0.1), name="w0")
        self.b_0 = tf.Variable(tf.random_normal([self.dim], stddev=0.1), name="b0")

        self.W_1 = tf.Variable(tf.random_normal([self.dim, self.dim], stddev=0.1), name="w1")
        self.b_1 = tf.Variable(tf.random_normal([self.dim], stddev=0.1), name="b1")

        self.W_2 = tf.Variable(tf.random_normal([self.dim, self.dim], stddev=0.1), name="w2")
        self.b_2 = tf.Variable(tf.random_normal([self.dim], stddev=0.1), name="b2")

        self.W_cl = tf.Variable(tf.random_normal([self.dim, 3], stddev=0.1), name="wcl")
        self.b_cl = tf.Variable(tf.random_normal([3], stddev=0.1), name="bcl")

        ## Calculate representations by CBOW method
        emb_premise = tf.nn.embedding_lookup(self.E, self.premise_x)
        emb_premise_drop = tf.nn.dropout(emb_premise, self.keep_rate_ph)

        emb_hypothesis = tf.nn.embedding_lookup(self.E, self.hypothesis_x)
        emb_hypothesis_drop = tf.nn.dropout(emb_hypothesis, self.keep_rate_ph)

        premise_rep = tf.reduce_sum(emb_premise_drop, 1)
        hypothesis_rep = tf.reduce_sum(emb_hypothesis_drop, 1)

        ## Combinations
        h_diff = premise_rep - hypothesis_rep
        h_mul = premise_rep * hypothesis_rep

        ### MLP
        mlp_input = tf.concat([premise_rep, hypothesis_rep, h_diff, h_mul], 1)
        h_1 = tf.nn.relu(tf.matmul(mlp_input, self.W_0) + self.b_0)
        h_2 = tf.nn.relu(tf.matmul(h_1, self.W_1) + self.b_1)
        h_3 = tf.nn.relu(tf.matmul(h_2, self.W_2) + self.b_2)
        h_drop = tf.nn.dropout(h_3, self.keep_rate_ph)

        # Get prediction
        self.logits = tf.matmul(h_drop, self.W_cl) + self.b_cl

        # Define the cost function
        self.total_cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.y, logits=self.logits))
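A note on the "Mou et al. concat layer" that appears in the model files: the premise and hypothesis vectors are combined into the heuristic matching features [u; v; u - v; u * v], i.e., the two sentence representations, their elementwise difference, and their elementwise product. A minimal NumPy sketch of the same operation (editor's illustration only):

import numpy as np

def matching_features(u, v):
    # Concatenation, difference, and elementwise product, mirroring
    # tf.concat([premise_ave, hypothesis_ave, diff, mul], 1) above.
    return np.concatenate([u, v, u - v, u * v], axis=-1)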
--------------------------------------------------------------------------------
/python/models/esim.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from util import blocks

class MyModel(object):
    def __init__(self, seq_length, emb_dim, hidden_dim, embeddings, emb_train):
        ## Define hyperparameters
        self.embedding_dim = emb_dim
        self.dim = hidden_dim
        self.sequence_length = seq_length

        ## Define the placeholders
        self.premise_x = tf.placeholder(tf.int32, [None, self.sequence_length])
        self.hypothesis_x = tf.placeholder(tf.int32, [None, self.sequence_length])
        self.y = tf.placeholder(tf.int32, [None])
        self.keep_rate_ph = tf.placeholder(tf.float32, [])

        ## Define parameters
        self.E = tf.Variable(embeddings, trainable=emb_train)

        self.W_mlp = tf.Variable(tf.random_normal([self.dim * 8, self.dim], stddev=0.1))
        self.b_mlp = tf.Variable(tf.random_normal([self.dim], stddev=0.1))

        self.W_cl = tf.Variable(tf.random_normal([self.dim, 3], stddev=0.1))
        self.b_cl = tf.Variable(tf.random_normal([3], stddev=0.1))

        ## Function for embedding lookup and dropout at embedding layer
        def emb_drop(x):
            emb = tf.nn.embedding_lookup(self.E, x)
            emb_drop = tf.nn.dropout(emb, self.keep_rate_ph)
            return emb_drop

        # Get lengths of unpadded sentences
        prem_seq_lengths, mask_prem = blocks.length(self.premise_x)
        hyp_seq_lengths, mask_hyp = blocks.length(self.hypothesis_x)

        ### First biLSTM layer ###

        premise_in = emb_drop(self.premise_x)
        hypothesis_in = emb_drop(self.hypothesis_x)

        premise_outs, c1 = blocks.biLSTM(premise_in, dim=self.dim, seq_len=prem_seq_lengths, name='premise')
        hypothesis_outs, c2 = blocks.biLSTM(hypothesis_in, dim=self.dim, seq_len=hyp_seq_lengths, name='hypothesis')

        premise_bi = tf.concat(premise_outs, axis=2)
        hypothesis_bi = tf.concat(hypothesis_outs, axis=2)

        premise_list = tf.unstack(premise_bi, axis=1)
        hypothesis_list = tf.unstack(hypothesis_bi, axis=1)

        ### Attention ###

        scores_all = []
        premise_attn = []
        alphas = []

        for i in range(self.sequence_length):

            scores_i_list = []
            for j in range(self.sequence_length):
                score_ij = tf.reduce_sum(tf.multiply(premise_list[i], hypothesis_list[j]), 1, keep_dims=True)
                scores_i_list.append(score_ij)

            scores_i = tf.stack(scores_i_list, axis=1)
            alpha_i = blocks.masked_softmax(scores_i, mask_hyp)
            a_tilde_i = tf.reduce_sum(tf.multiply(alpha_i, hypothesis_bi), 1)
            premise_attn.append(a_tilde_i)

            scores_all.append(scores_i)
            alphas.append(alpha_i)

        scores_stack = tf.stack(scores_all, axis=2)
        scores_list = tf.unstack(scores_stack, axis=1)

        hypothesis_attn = []
        betas = []
        for j in range(self.sequence_length):
            scores_j = scores_list[j]
            beta_j = blocks.masked_softmax(scores_j, mask_prem)
            b_tilde_j = tf.reduce_sum(tf.multiply(beta_j, premise_bi), 1)
            hypothesis_attn.append(b_tilde_j)

            betas.append(beta_j)

        # Make attention-weighted sentence representations into one tensor,
        premise_attns = tf.stack(premise_attn, axis=1)
        hypothesis_attns = tf.stack(hypothesis_attn, axis=1)

        # For making attention plots,
        self.alpha_s = tf.stack(alphas, axis=2)
        self.beta_s = tf.stack(betas, axis=2)

        ### Subcomponent Inference ###

        prem_diff = tf.subtract(premise_bi, premise_attns)
        prem_mul = tf.multiply(premise_bi, premise_attns)
        hyp_diff = tf.subtract(hypothesis_bi, hypothesis_attns)
        hyp_mul = tf.multiply(hypothesis_bi, hypothesis_attns)

        m_a = tf.concat([premise_bi, premise_attns, prem_diff, prem_mul], 2)
        m_b = tf.concat([hypothesis_bi, hypothesis_attns, hyp_diff, hyp_mul], 2)

        ### Inference Composition ###

        v1_outs, c3 = blocks.biLSTM(m_a, dim=self.dim, seq_len=prem_seq_lengths, name='v1')
        v2_outs, c4 = blocks.biLSTM(m_b, dim=self.dim, seq_len=hyp_seq_lengths, name='v2')

        v1_bi = tf.concat(v1_outs, axis=2)
        v2_bi = tf.concat(v2_outs, axis=2)

        ### Pooling Layer ###

        v_1_sum = tf.reduce_sum(v1_bi, 1)
        v_1_ave = tf.div(v_1_sum, tf.expand_dims(tf.cast(prem_seq_lengths, tf.float32), -1))

        v_2_sum = tf.reduce_sum(v2_bi, 1)
        v_2_ave = tf.div(v_2_sum, tf.expand_dims(tf.cast(hyp_seq_lengths, tf.float32), -1))

        v_1_max = tf.reduce_max(v1_bi, 1)
        v_2_max = tf.reduce_max(v2_bi, 1)

        v = tf.concat([v_1_ave, v_2_ave, v_1_max, v_2_max], 1)

        # MLP layer
        h_mlp = tf.nn.tanh(tf.matmul(v, self.W_mlp) + self.b_mlp)

        # Dropout applied to classifier
        h_drop = tf.nn.dropout(h_mlp, self.keep_rate_ph)

        # Get prediction
        self.logits = tf.matmul(h_drop, self.W_cl) + self.b_cl

        # Define the cost function
        self.total_cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.y, logits=self.logits))

--------------------------------------------------------------------------------
/python/predictions.py:
--------------------------------------------------------------------------------
"""
Script to generate a CSV file of predictions on the test data.
"""

import tensorflow as tf
import os
import importlib
import random
from util import logger
import util.parameters as params
from util.data_processing import *
from util.evaluate import *
import pickle

FIXED_PARAMETERS = params.load_parameters()
modname = FIXED_PARAMETERS["model_name"]
logpath = os.path.join(FIXED_PARAMETERS["log_path"], modname) + ".log"
logger = logger.Logger(logpath)

model = FIXED_PARAMETERS["model_type"]

module = importlib.import_module(".".join(['models', model]))
MyModel = getattr(module, 'MyModel')

# Logging parameter settings at each launch of the script.
# This will help ensure nothing goes awry in reloading a model and that we consistently use the same hyperparameter settings.
logger.Log("FIXED_PARAMETERS\n %s" % FIXED_PARAMETERS)

######################### LOAD DATA #############################

logger.Log("Loading data")
training_snli = load_nli_data(FIXED_PARAMETERS["training_snli"], snli=True)
dev_snli = load_nli_data(FIXED_PARAMETERS["dev_snli"], snli=True)
test_snli = load_nli_data(FIXED_PARAMETERS["test_snli"], snli=True)

training_mnli = load_nli_data(FIXED_PARAMETERS["training_mnli"])
dev_matched = load_nli_data(FIXED_PARAMETERS["dev_matched"])
dev_mismatched = load_nli_data(FIXED_PARAMETERS["dev_mismatched"])
test_matched = load_nli_data(FIXED_PARAMETERS["test_matched"])
test_mismatched = load_nli_data(FIXED_PARAMETERS["test_mismatched"])

dictpath = os.path.join(FIXED_PARAMETERS["log_path"], modname) + ".p"

if not os.path.isfile(dictpath):
    print "No dictionary found!"
    exit(1)

else:
    logger.Log("Loading dictionary from %s" % (dictpath))
    word_indices = pickle.load(open(dictpath, "rb"))
    logger.Log("Padding and indexifying sentences")
    sentences_to_padded_index_sequences(word_indices, [training_mnli, training_snli, dev_matched, dev_mismatched, dev_snli, test_snli, test_matched, test_mismatched])

loaded_embeddings = loadEmbedding_rand(FIXED_PARAMETERS["embedding_data_path"], word_indices)

class modelClassifier:
    def __init__(self, seq_length):
        ## Define hyperparameters
        self.learning_rate = FIXED_PARAMETERS["learning_rate"]
        self.display_epoch_freq = 1
        self.display_step_freq = 50
        self.embedding_dim = FIXED_PARAMETERS["word_embedding_dim"]
        self.dim = FIXED_PARAMETERS["hidden_embedding_dim"]
        self.batch_size = FIXED_PARAMETERS["batch_size"]
        self.emb_train = FIXED_PARAMETERS["emb_train"]
        self.keep_rate = FIXED_PARAMETERS["keep_rate"]
        self.sequence_length = FIXED_PARAMETERS["seq_length"]
        self.alpha = FIXED_PARAMETERS["alpha"]

        logger.Log("Building model from %s.py" % (model))
        self.model = MyModel(seq_length=self.sequence_length, emb_dim=self.embedding_dim, hidden_dim=self.dim, embeddings=loaded_embeddings, emb_train=self.emb_train)

        # Perform gradient descent with Adam
        self.optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.9, beta2=0.999).minimize(self.model.total_cost)

        # tf things: initialize variables and create placeholder for session
        logger.Log("Initializing variables")
        self.init = tf.global_variables_initializer()
        self.sess = None
        self.saver = tf.train.Saver()

    def get_minibatch(self, dataset, start_index, end_index):
        indices = range(start_index, end_index)
        premise_vectors = np.vstack([dataset[i]['sentence1_binary_parse_index_sequence'] for i in indices])
        hypothesis_vectors = np.vstack([dataset[i]['sentence2_binary_parse_index_sequence'] for i in indices])
        labels = [dataset[i]['label'] for i in indices]
        return premise_vectors, hypothesis_vectors, labels

    def classify(self, examples):
        # This classifies a list of examples.
        best_path = os.path.join(FIXED_PARAMETERS["ckpt_path"], modname) + ".ckpt_best"
        self.sess = tf.Session()
        self.sess.run(self.init)
        self.saver.restore(self.sess, best_path)
        logger.Log("Model restored from file: %s" % best_path)

        # np.empty(3) provides a dummy first row so np.vstack works below;
        # logits[1:] drops it again before taking the argmax.
        logits = np.empty(3)
        minibatch_premise_vectors, minibatch_hypothesis_vectors, minibatch_labels = self.get_minibatch(examples, 0, len(examples))
        feed_dict = {self.model.premise_x: minibatch_premise_vectors,
                     self.model.hypothesis_x: minibatch_hypothesis_vectors,
                     self.model.keep_rate_ph: 1.0}
        logit = self.sess.run(self.model.logits, feed_dict)
        logits = np.vstack([logits, logit])

        return np.argmax(logits[1:], axis=1)


classifier = modelClassifier(FIXED_PARAMETERS["seq_length"])

"""
Get CSVs of predictions.
"""

logger.Log("Creating CSV of predictions on matched test set: %s" % (modname + "_matched_predictions.csv"))
predictions_kaggle(classifier.classify, test_matched, FIXED_PARAMETERS["batch_size"], modname + "_dev_matched")

logger.Log("Creating CSV of predictions on mismatched test set: %s" % (modname + "_mismatched_predictions.csv"))
predictions_kaggle(classifier.classify, test_mismatched, FIXED_PARAMETERS["batch_size"], modname + "_dev_mismatched")
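# For reference: predictions_kaggle (from util/evaluate.py) writes the
# two-column Kaggle-style CSV described in the README. Schematically
# (editor's sketch; actual pairIDs and labels depend on your data and model):
#
#     pairID,gold_label
#     <pairID of example 1>,<predicted label of example 1>
#     <pairID of example 2>,<predicted label of example 2>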
--------------------------------------------------------------------------------
/python/requirements.txt:
--------------------------------------------------------------------------------
tensorflow==1.0.0

numpy==1.11.2

--------------------------------------------------------------------------------
/python/train_genre.py:
--------------------------------------------------------------------------------
"""
Training script to train a model on a single genre from MultiNLI or on SNLI data.
The logs created during this training scheme have genre-specific statistics.
"""

import tensorflow as tf
import os
import importlib
import random
from util import logger
import util.parameters as params
from util.data_processing import *
from util.evaluate import *

FIXED_PARAMETERS = params.load_parameters()
modname = FIXED_PARAMETERS["model_name"]
logpath = os.path.join(FIXED_PARAMETERS["log_path"], modname) + ".log"
logger = logger.Logger(logpath)

model = FIXED_PARAMETERS["model_type"]

module = importlib.import_module(".".join(['models', model]))
MyModel = getattr(module, 'MyModel')

# Logging parameter settings at each launch of the training script.
# This will help ensure nothing goes awry in reloading a model and that we consistently use the same hyperparameter settings.
logger.Log("FIXED_PARAMETERS\n %s" % FIXED_PARAMETERS)

######################### LOAD DATA #############################

logger.Log("Loading data")
genres = ['travel', 'fiction', 'slate', 'telephone', 'government', 'snli']

alpha = FIXED_PARAMETERS["alpha"]
genre = FIXED_PARAMETERS["genre"]

# TODO: make the script stop in parameters.py if the genre name is invalid.
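# A sketch of that validation (editor's illustration, assuming parameters.py
# builds its flags with argparse; the exact flag definition there may differ):
# passing `choices` to add_argument makes the parser itself reject an invalid
# genre at launch time:
#
#     parser.add_argument("--genre", type=str,
#                         choices=["travel", "fiction", "slate",
#                                  "telephone", "government", "snli"])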
if genre not in genres:
    logger.Log("Invalid genre")
    exit()
else:
    logger.Log("Training on %s genre" % (genre))

if genre == "snli":
    training_data = load_nli_data_genre(FIXED_PARAMETERS["training_snli"], genre, snli=True)
    beta = int(len(training_data) * alpha)
    training_data = random.sample(training_data, beta)
else:
    training_data = load_nli_data_genre(FIXED_PARAMETERS["training_mnli"], genre, snli=False)

dev_snli = load_nli_data(FIXED_PARAMETERS["dev_snli"], snli=True)
test_snli = load_nli_data(FIXED_PARAMETERS["test_snli"], snli=True)
dev_matched = load_nli_data(FIXED_PARAMETERS["dev_matched"])
dev_mismatched = load_nli_data(FIXED_PARAMETERS["dev_mismatched"])
test_matched = load_nli_data(FIXED_PARAMETERS["test_matched"])
test_mismatched = load_nli_data(FIXED_PARAMETERS["test_mismatched"])

if 'temp.jsonl' in FIXED_PARAMETERS["test_matched"]:
    # Removing temporary empty file that was created in parameters.py
    os.remove(FIXED_PARAMETERS["test_matched"])
    logger.Log("Created and removed empty file called temp.jsonl since test set is not available.")

dictpath = os.path.join(FIXED_PARAMETERS["log_path"], modname) + ".p"

if not os.path.isfile(dictpath):
    logger.Log("Building dictionary")
    word_indices = build_dictionary([training_data])
    logger.Log("Padding and indexifying sentences")
    sentences_to_padded_index_sequences(word_indices, [training_data, dev_matched, dev_mismatched, dev_snli, test_snli, test_matched, test_mismatched])
    pickle.dump(word_indices, open(dictpath, "wb"))

else:
    logger.Log("Loading dictionary from %s" % (dictpath))
    word_indices = pickle.load(open(dictpath, "rb"))
    logger.Log("Padding and indexifying sentences")
    sentences_to_padded_index_sequences(word_indices, [training_data, dev_matched, dev_mismatched, dev_snli, test_snli, test_matched, test_mismatched])

logger.Log("Loading embeddings")
loaded_embeddings = loadEmbedding_rand(FIXED_PARAMETERS["embedding_data_path"], word_indices)


class modelClassifier:
    def __init__(self, seq_length):
        ## Define hyperparameters
        self.learning_rate = FIXED_PARAMETERS["learning_rate"]
        self.display_epoch_freq = 1
        self.display_step_freq = 50
        self.embedding_dim = FIXED_PARAMETERS["word_embedding_dim"]
        self.dim = FIXED_PARAMETERS["hidden_embedding_dim"]
        self.batch_size = FIXED_PARAMETERS["batch_size"]
        self.emb_train = FIXED_PARAMETERS["emb_train"]
        self.keep_rate = FIXED_PARAMETERS["keep_rate"]
        self.sequence_length = FIXED_PARAMETERS["seq_length"]
        self.alpha = FIXED_PARAMETERS["alpha"]

        logger.Log("Building model from %s.py" % (model))
        self.model = MyModel(seq_length=self.sequence_length, emb_dim=self.embedding_dim, hidden_dim=self.dim, embeddings=loaded_embeddings, emb_train=self.emb_train)

        # Boolean stating that training has not been completed,
        self.completed = False

        # Perform gradient descent with Adam
        self.optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.9, beta2=0.999).minimize(self.model.total_cost)

        # tf things: initialize variables and create placeholder for session
        logger.Log("Initializing variables")
        self.init = tf.global_variables_initializer()
        self.sess = None
        self.saver = tf.train.Saver()

    def get_minibatch(self, dataset, start_index, end_index):
        indices = range(start_index, end_index)
        premise_vectors = np.vstack([dataset[i]['sentence1_binary_parse_index_sequence'] for i in indices])
        hypothesis_vectors = np.vstack([dataset[i]['sentence2_binary_parse_index_sequence'] for i in indices])
        genres = [dataset[i]['genre'] for i in indices]
        labels = [dataset[i]['label'] for i in indices]
        return premise_vectors, hypothesis_vectors, labels, genres

    def train(self, training_data, dev_mat, dev_mismat, dev_snli):
        self.sess = tf.Session()
        self.sess.run(self.init)

        self.step = 1
        self.epoch = 0
        self.best_dev = 0.
        self.best_mtrain_acc = 0.
        self.last_train_acc = [.001, .001, .001, .001, .001]
        self.best_step = 0

        # Restore the best checkpoint if it exists.
        # Also restore values for best dev-set accuracy and best training-set accuracy.
        ckpt_file = os.path.join(FIXED_PARAMETERS["ckpt_path"], modname) + ".ckpt"
        if os.path.isfile(ckpt_file + ".meta"):
            if os.path.isfile(ckpt_file + "_best.meta"):
                self.saver.restore(self.sess, (ckpt_file + "_best"))
                if genre == 'snli':
                    dev_acc, dev_cost_snli = evaluate_classifier(self.classify, dev_snli, self.batch_size)
                    self.best_dev = dev_acc
                else:
                    best_dev_mat, dev_cost_mat = evaluate_classifier_genre(self.classify, dev_mat, self.batch_size)
                    self.best_dev = best_dev_mat[genre]
                self.best_mtrain_acc, mtrain_cost = evaluate_classifier(self.classify, training_data[0:5000], self.batch_size)

                logger.Log("Restored best dev acc: %f\n Restored best train acc: %f" % (self.best_dev, self.best_mtrain_acc))

            self.saver.restore(self.sess, ckpt_file)
            logger.Log("Model restored from file: %s" % ckpt_file)

        ### Training cycle
        logger.Log("Training...")

        while True:
            random.shuffle(training_data)
            avg_cost = 0.
            total_batch = int(len(training_data) / self.batch_size)

            # Loop over all batches in epoch
            for i in range(total_batch):
                # Assemble a minibatch of the next B examples
                minibatch_premise_vectors, minibatch_hypothesis_vectors, minibatch_labels, minibatch_genres = self.get_minibatch(
                    training_data, self.batch_size * i, self.batch_size * (i + 1))

                # Run the optimizer to take a gradient step, and also fetch the value of the
                # cost function for logging
                feed_dict = {self.model.premise_x: minibatch_premise_vectors,
                             self.model.hypothesis_x: minibatch_hypothesis_vectors,
                             self.model.y: minibatch_labels,
                             self.model.keep_rate_ph: self.keep_rate}
                _, c = self.sess.run([self.optimizer, self.model.total_cost], feed_dict)

                # Since a single epoch can take ages for larger models (ESIM),
                # we'll print accuracy every 50 steps
                if self.step % self.display_step_freq == 0:
                    dev_acc_mat, dev_cost_mat = evaluate_classifier_genre(self.classify, dev_mat, self.batch_size)
                    if genre == 'snli':
                        dev_acc, dev_cost_snli = evaluate_classifier(self.classify, dev_snli, self.batch_size)
                    else:
                        dev_acc = dev_acc_mat[genre]
                    dev_acc_mismat, dev_cost_mismat = evaluate_classifier_genre(self.classify, dev_mismat, self.batch_size)
                    dev_acc_snli, dev_cost_snli = evaluate_classifier(self.classify, dev_snli, self.batch_size)
                    mtrain_acc, mtrain_cost = evaluate_classifier(self.classify, training_data[0:5000], self.batch_size)

                    logger.Log("Step: %i\t Dev-genre acc: %f\t Dev-mrest acc: %r\t Dev-mmrest acc: %r\t Dev-SNLI acc: %f\t Genre train acc: %f" % (self.step, dev_acc, dev_acc_mat, dev_acc_mismat, dev_acc_snli, mtrain_acc))
                    logger.Log("Step: %i\t Dev-matched cost: %f\t Dev-mismatched cost: %f\t Dev-SNLI cost: %f\t Genre train cost: %f" % (self.step, dev_cost_mat, dev_cost_mismat, dev_cost_snli, mtrain_cost))

                if self.step % 500 == 0:
                    self.saver.save(self.sess, ckpt_file)
                    # Relative improvement (in percent) of the current dev accuracy over the best one so far.
                    best_test = 100 * (1 - self.best_dev / dev_acc)
                    if best_test > 0.04:
                        self.saver.save(self.sess, ckpt_file + "_best")
                        self.best_dev = dev_acc
                        self.best_mtrain_acc = mtrain_acc
                        self.best_step = self.step
                        logger.Log("Checkpointing with new best dev accuracy: %f" % (self.best_dev))

                self.step += 1

                # Compute average loss
                avg_cost += c / (total_batch * self.batch_size)

            # Display some statistics about the epoch
            if self.epoch % self.display_epoch_freq == 0:
                logger.Log("Epoch: %i\t Avg. Cost: %f" % (self.epoch + 1, avg_cost))
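            # Editor's note on the early-stopping statistic computed below: it is
            # the relative spread of the last five epoch-level training accuracies.
            # For example, for accuracies [0.70, 0.71, 0.72, 0.73, 0.74]:
            #   sum / (5 * min) - 1 = 3.60 / 3.50 - 1 ~= 0.0286,
            # so progress ~= 28.6 after the x1000 scaling, and training continues.
            # Training stops once progress falls below 0.1 (training accuracy has
            # plateaued) or once no new best checkpoint has been saved for 10,000
            # steps.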
Cost: %f" %(self.epoch+1, avg_cost)) 208 | 209 | self.epoch += 1 210 | self.last_train_acc[(self.epoch % 5) - 1] = mtrain_acc 211 | 212 | # Early stopping 213 | progress = 1000 * (sum(self.last_train_acc)/(5 * min(self.last_train_acc)) - 1) 214 | 215 | if (progress < 0.1) or (self.step > self.best_step + 10000): 216 | logger.Log("Best matched-dev accuracy: %s" %(self.best_dev)) 217 | logger.Log("MultiNLI Train accuracy: %s" %(self.best_mtrain_acc)) 218 | self.completed = True 219 | break 220 | 221 | def restore(self, best=True): 222 | if True: 223 | path = os.path.join(FIXED_PARAMETERS["ckpt_path"], modname) + ".ckpt_best" 224 | else: 225 | path = os.path.join(FIXED_PARAMETERS["ckpt_path"], modname) + ".ckpt" 226 | self.sess = tf.Session() 227 | self.sess.run(self.init) 228 | self.saver.restore(self.sess, path) 229 | logger.Log("Model restored from file: %s" % path) 230 | 231 | def classify(self, examples): 232 | # This classifies a list of examples 233 | total_batch = int(len(examples) / self.batch_size) 234 | logits = np.empty(3) 235 | genres = [] 236 | for i in range(total_batch): 237 | minibatch_premise_vectors, minibatch_hypothesis_vectors, minibatch_labels, minibatch_genres = self.get_minibatch( 238 | examples, self.batch_size * i, self.batch_size * (i + 1)) 239 | feed_dict = {self.model.premise_x: minibatch_premise_vectors, 240 | self.model.hypothesis_x: minibatch_hypothesis_vectors, 241 | self.model.y: minibatch_labels, 242 | self.model.keep_rate_ph: 1.0} 243 | genres += minibatch_genres 244 | logit, cost = self.sess.run([self.model.logits, self.model.total_cost], feed_dict) 245 | logits = np.vstack([logits, logit]) 246 | 247 | return genres, np.argmax(logits[1:], axis=1), cost 248 | 249 | 250 | classifier = modelClassifier(FIXED_PARAMETERS["seq_length"]) 251 | 252 | """ 253 | Either train the model and then run it on the test-sets or 254 | load the best checkpoint and get accuracy on the test set. Default setting is to train the model. 
252 | """
253 | Either train the model and then run it on the test-sets or
254 | load the best checkpoint and get accuracy on the test set. Default setting is to train the model.
255 | """
256 | 
257 | test = params.train_or_test()
258 | 
259 | # While test-set isn't released, use dev-sets for testing
260 | test_matched = dev_matched
261 | test_mismatched = dev_mismatched
262 | 
263 | 
264 | if test == False:
265 |     classifier.train(training_data, dev_matched, dev_mismatched, dev_snli)
266 |     logger.Log("Dev acc on matched multiNLI: %s" %(evaluate_classifier(classifier.classify, \
267 |                test_matched, FIXED_PARAMETERS["batch_size"])[0]))
268 | 
269 |     logger.Log("Dev acc on mismatched multiNLI: %s" %(evaluate_classifier(classifier.classify, \
270 |                test_mismatched, FIXED_PARAMETERS["batch_size"])[0]))
271 | 
272 |     logger.Log("Test acc on SNLI: %s" %(evaluate_classifier(classifier.classify, \
273 |                test_snli, FIXED_PARAMETERS["batch_size"])[0]))
274 | else:
275 |     results, _ = evaluate_final(classifier.restore, classifier.classify, [test_matched, test_mismatched, test_snli], FIXED_PARAMETERS["batch_size"])
276 |     logger.Log("Acc on multiNLI matched dev-set: %s" %(results[0]))
277 |     logger.Log("Acc on multiNLI mismatched dev-set: %s" %(results[1]))
278 |     logger.Log("Acc on SNLI test set: %s" %(results[2]))
279 | 
280 |     # Results by genre,
281 |     logger.Log("Acc on matched genre dev-sets: %s" %(evaluate_classifier_genre(classifier.classify, test_matched, FIXED_PARAMETERS["batch_size"])[0]))
282 |     logger.Log("Acc on mismatched genres dev-sets: %s" %(evaluate_classifier_genre(classifier.classify, test_mismatched, FIXED_PARAMETERS["batch_size"])[0]))
283 | 
--------------------------------------------------------------------------------
/python/train_mnli.py:
--------------------------------------------------------------------------------
1 | """
2 | Training script to train a model on MultiNLI and, optionally, on SNLI data as well.
3 | The "alpha" hyperparameter, set in parameters.py, determines whether SNLI data is used in training.
4 | If alpha = 0, no SNLI data is used in training. If alpha > 0, then down-sampled SNLI data is used in training.
5 | """
6 | 
7 | import tensorflow as tf
8 | import os
9 | import importlib
10 | import random
11 | from util import logger
12 | import util.parameters as params
13 | from util.data_processing import *
14 | from util.evaluate import *
15 | 
16 | FIXED_PARAMETERS = params.load_parameters()
17 | modname = FIXED_PARAMETERS["model_name"]
18 | logpath = os.path.join(FIXED_PARAMETERS["log_path"], modname) + ".log"
19 | logger = logger.Logger(logpath)
20 | 
21 | model = FIXED_PARAMETERS["model_type"]
22 | 
23 | module = importlib.import_module(".".join(['models', model]))
24 | MyModel = getattr(module, 'MyModel')
25 | 
26 | # Logging parameter settings at each launch of training script
27 | # This will help ensure nothing goes awry in reloading a model and we consistently use the same hyperparameter settings.
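As the module docstring notes, `alpha` controls how much SNLI data gets mixed into each training epoch. A rough, self-contained sketch of the sampling arithmetic used later in `train()`; the corpus sizes here are stand-ins, roughly matching SNLI's ~550k and MultiNLI's ~393k training pairs:

```python
import random

alpha = 0.15                      # fraction of SNLI to mix in per epoch
train_snli = list(range(550000))  # stand-in for the SNLI training examples
train_mnli = list(range(393000))  # stand-in for the MultiNLI training examples

beta = int(alpha * len(train_snli))                    # SNLI examples to draw
training_data = train_mnli + random.sample(train_snli, beta)
random.shuffle(training_data)                          # fresh mix each epoch
print(len(training_data))                              # 393000 + 82500 = 475500
```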
28 | logger.Log("FIXED_PARAMETERS\n %s" % FIXED_PARAMETERS) 29 | 30 | 31 | ######################### LOAD DATA ############################# 32 | 33 | logger.Log("Loading data") 34 | training_snli = load_nli_data(FIXED_PARAMETERS["training_snli"], snli=True) 35 | dev_snli = load_nli_data(FIXED_PARAMETERS["dev_snli"], snli=True) 36 | test_snli = load_nli_data(FIXED_PARAMETERS["test_snli"], snli=True) 37 | 38 | training_mnli = load_nli_data(FIXED_PARAMETERS["training_mnli"]) 39 | dev_matched = load_nli_data(FIXED_PARAMETERS["dev_matched"]) 40 | dev_mismatched = load_nli_data(FIXED_PARAMETERS["dev_mismatched"]) 41 | test_matched = load_nli_data(FIXED_PARAMETERS["test_matched"]) 42 | test_mismatched = load_nli_data(FIXED_PARAMETERS["test_mismatched"]) 43 | 44 | if 'temp.jsonl' in FIXED_PARAMETERS["test_matched"]: 45 | # Removing temporary empty file that was created in parameters.py 46 | os.remove(FIXED_PARAMETERS["test_matched"]) 47 | logger.Log("Created and removed empty file called temp.jsonl since test set is not available.") 48 | 49 | dictpath = os.path.join(FIXED_PARAMETERS["log_path"], modname) + ".p" 50 | 51 | if not os.path.isfile(dictpath): 52 | logger.Log("Building dictionary") 53 | if FIXED_PARAMETERS["alpha"] == 0: 54 | word_indices = build_dictionary([training_mnli]) 55 | else: 56 | word_indices = build_dictionary([training_mnli, training_snli]) 57 | 58 | logger.Log("Padding and indexifying sentences") 59 | sentences_to_padded_index_sequences(word_indices, [training_mnli, training_snli, 60 | dev_matched, dev_mismatched, dev_snli, test_snli, 61 | test_matched, test_mismatched]) 62 | pickle.dump(word_indices, open(dictpath, "wb")) 63 | 64 | else: 65 | logger.Log("Loading dictionary from %s" % (dictpath)) 66 | word_indices = pickle.load(open(dictpath, "rb")) 67 | logger.Log("Padding and indexifying sentences") 68 | sentences_to_padded_index_sequences(word_indices, [training_mnli, training_snli, 69 | dev_matched, dev_mismatched, dev_snli, 70 | test_snli, test_matched, test_mismatched]) 71 | 72 | logger.Log("Loading embeddings") 73 | loaded_embeddings = loadEmbedding_rand(FIXED_PARAMETERS["embedding_data_path"], word_indices) 74 | 75 | 76 | class modelClassifier: 77 | def __init__(self, seq_length): 78 | ## Define hyperparameters 79 | self.learning_rate = FIXED_PARAMETERS["learning_rate"] 80 | self.display_epoch_freq = 1 81 | self.display_step_freq = 50 82 | self.embedding_dim = FIXED_PARAMETERS["word_embedding_dim"] 83 | self.dim = FIXED_PARAMETERS["hidden_embedding_dim"] 84 | self.batch_size = FIXED_PARAMETERS["batch_size"] 85 | self.emb_train = FIXED_PARAMETERS["emb_train"] 86 | self.keep_rate = FIXED_PARAMETERS["keep_rate"] 87 | self.sequence_length = FIXED_PARAMETERS["seq_length"] 88 | self.alpha = FIXED_PARAMETERS["alpha"] 89 | 90 | logger.Log("Building model from %s.py" %(model)) 91 | self.model = MyModel(seq_length=self.sequence_length, emb_dim=self.embedding_dim, 92 | hidden_dim=self.dim, embeddings=loaded_embeddings, 93 | emb_train=self.emb_train) 94 | 95 | # Perform gradient descent with Adam 96 | self.optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.9, beta2=0.999).minimize(self.model.total_cost) 97 | 98 | # Boolean stating that training has not been completed, 99 | self.completed = False 100 | 101 | # tf things: initialize variables and create placeholder for session 102 | logger.Log("Initializing variables") 103 | self.init = tf.global_variables_initializer() 104 | self.sess = None 105 | self.saver = tf.train.Saver() 106 | 107 | 108 | def 
get_minibatch(self, dataset, start_index, end_index):
109 |         indices = range(start_index, end_index)
110 |         premise_vectors = np.vstack([dataset[i]['sentence1_binary_parse_index_sequence'] for i in indices])
111 |         hypothesis_vectors = np.vstack([dataset[i]['sentence2_binary_parse_index_sequence'] for i in indices])
112 |         genres = [dataset[i]['genre'] for i in indices]
113 |         labels = [dataset[i]['label'] for i in indices]
114 |         return premise_vectors, hypothesis_vectors, labels, genres
115 | 
116 | 
117 |     def train(self, train_mnli, train_snli, dev_mat, dev_mismat, dev_snli):
118 |         self.sess = tf.Session()
119 |         self.sess.run(self.init)
120 | 
121 |         self.step = 0
122 |         self.epoch = 0
123 |         self.best_dev_mat = 0.
124 |         self.best_mtrain_acc = 0.
125 |         self.last_train_acc = [.001, .001, .001, .001, .001]
126 |         self.best_step = 0
127 | 
128 |         # Restore most recent checkpoint if it exists.
129 |         # Also restore values for best dev-set accuracy and best training-set accuracy
130 |         ckpt_file = os.path.join(FIXED_PARAMETERS["ckpt_path"], modname) + ".ckpt"
131 |         if os.path.isfile(ckpt_file + ".meta"):
132 |             if os.path.isfile(ckpt_file + "_best.meta"):
133 |                 self.saver.restore(self.sess, (ckpt_file + "_best"))
134 |                 self.best_dev_mat, dev_cost_mat = evaluate_classifier(self.classify, dev_mat, self.batch_size)
135 |                 best_dev_mismat, dev_cost_mismat = evaluate_classifier(self.classify, dev_mismat, self.batch_size)
136 |                 best_dev_snli, dev_cost_snli = evaluate_classifier(self.classify, dev_snli, self.batch_size)
137 |                 self.best_mtrain_acc, mtrain_cost = evaluate_classifier(self.classify, train_mnli[0:5000], self.batch_size)
138 |                 if self.alpha != 0.:
139 |                     self.best_strain_acc, strain_cost = evaluate_classifier(self.classify, train_snli[0:5000], self.batch_size)
140 |                     logger.Log("Restored best matched-dev acc: %f\n Restored best mismatched-dev acc: %f\n \
141 |                     Restored best SNLI-dev acc: %f\n Restored best MultiNLI train acc: %f\n \
142 |                     Restored best SNLI train acc: %f" %(self.best_dev_mat, best_dev_mismat, best_dev_snli,
143 |                     self.best_mtrain_acc, self.best_strain_acc))
144 |                 else:
145 |                     logger.Log("Restored best matched-dev acc: %f\n Restored best mismatched-dev acc: %f\n \
146 |                     Restored best SNLI-dev acc: %f\n Restored best MultiNLI train acc: %f"
147 |                     % (self.best_dev_mat, best_dev_mismat, best_dev_snli, self.best_mtrain_acc))
148 | 
149 |             self.saver.restore(self.sess, ckpt_file)
150 |             logger.Log("Model restored from file: %s" % ckpt_file)
151 | 
152 |         # Combine MultiNLI and SNLI data. Alpha has a default value of 0; to use SNLI data, a non-zero value must be passed as an argument.
153 |         beta = int(self.alpha * len(train_snli))
154 | 
155 |         ### Training cycle
156 |         logger.Log("Training...")
157 |         logger.Log("Model will use %s percent of SNLI data during training" %(self.alpha * 100))
158 | 
159 |         while True:
160 |             training_data = train_mnli + random.sample(train_snli, beta)
161 |             random.shuffle(training_data)
162 |             avg_cost = 0.
163 |             total_batch = int(len(training_data) / self.batch_size)
164 | 
165 |             # Loop over all batches in epoch
166 |             for i in range(total_batch):
167 |                 # Assemble a minibatch of the next B examples
168 |                 minibatch_premise_vectors, minibatch_hypothesis_vectors, minibatch_labels, minibatch_genres = self.get_minibatch(
169 |                     training_data, self.batch_size * i, self.batch_size * (i + 1))
170 | 
171 |                 # Run the optimizer to take a gradient step, and also fetch the value of the
172 |                 # cost function for logging
173 |                 feed_dict = {self.model.premise_x: minibatch_premise_vectors,
174 |                              self.model.hypothesis_x: minibatch_hypothesis_vectors,
175 |                              self.model.y: minibatch_labels,
176 |                              self.model.keep_rate_ph: self.keep_rate}
177 |                 _, c = self.sess.run([self.optimizer, self.model.total_cost], feed_dict)
178 | 
179 |                 # Since a single epoch can take ages for larger models (ESIM),
180 |                 # we'll print accuracy every 50 steps
181 |                 if self.step % self.display_step_freq == 0:
182 |                     dev_acc_mat, dev_cost_mat = evaluate_classifier(self.classify, dev_mat, self.batch_size)
183 |                     dev_acc_mismat, dev_cost_mismat = evaluate_classifier(self.classify, dev_mismat, self.batch_size)
184 |                     dev_acc_snli, dev_cost_snli = evaluate_classifier(self.classify, dev_snli, self.batch_size)
185 |                     mtrain_acc, mtrain_cost = evaluate_classifier(self.classify, train_mnli[0:5000], self.batch_size)
186 | 
187 |                     if self.alpha != 0.:
188 |                         strain_acc, strain_cost = evaluate_classifier(self.classify, train_snli[0:5000], self.batch_size)
189 |                         logger.Log("Step: %i\t Dev-matched acc: %f\t Dev-mismatched acc: %f\t \
190 |                         Dev-SNLI acc: %f\t MultiNLI train acc: %f\t SNLI train acc: %f"
191 |                         % (self.step, dev_acc_mat, dev_acc_mismat, dev_acc_snli, mtrain_acc, strain_acc))
192 |                         logger.Log("Step: %i\t Dev-matched cost: %f\t Dev-mismatched cost: %f\t \
193 |                         Dev-SNLI cost: %f\t MultiNLI train cost: %f\t SNLI train cost: %f"
194 |                         % (self.step, dev_cost_mat, dev_cost_mismat, dev_cost_snli, mtrain_cost, strain_cost))
195 |                     else:
196 |                         logger.Log("Step: %i\t Dev-matched acc: %f\t Dev-mismatched acc: %f\t \
197 |                         Dev-SNLI acc: %f\t MultiNLI train acc: %f" %(self.step, dev_acc_mat,
198 |                         dev_acc_mismat, dev_acc_snli, mtrain_acc))
199 |                         logger.Log("Step: %i\t Dev-matched cost: %f\t Dev-mismatched cost: %f\t \
200 |                         Dev-SNLI cost: %f\t MultiNLI train cost: %f" %(self.step, dev_cost_mat,
201 |                         dev_cost_mismat, dev_cost_snli, mtrain_cost))
202 | 
203 |                 if self.step % 500 == 0:
204 |                     self.saver.save(self.sess, ckpt_file)
205 |                     best_test = 100 * (1 - self.best_dev_mat / dev_acc_mat)
206 |                     if best_test > 0.04:
207 |                         self.saver.save(self.sess, ckpt_file + "_best")
208 |                         self.best_dev_mat = dev_acc_mat
209 |                         self.best_mtrain_acc = mtrain_acc
210 |                         if self.alpha != 0.:
211 |                             self.best_strain_acc = strain_acc
212 |                         self.best_step = self.step
213 |                         logger.Log("Checkpointing with new best matched-dev accuracy: %f" %(self.best_dev_mat))
214 | 
215 |                 self.step += 1
216 | 
217 |                 # Compute average loss
218 |                 avg_cost += c / (total_batch * self.batch_size)
219 | 
220 |             # Display some statistics about the epoch
221 |             if self.epoch % self.display_epoch_freq == 0:
222 |                 logger.Log("Epoch: %i\t Avg. Cost: %f" %(self.epoch+1, avg_cost))
223 | 
224 |             self.epoch += 1
225 |             self.last_train_acc[(self.epoch % 5) - 1] = mtrain_acc
226 | 
227 |             # Early stopping
228 |             progress = 1000 * (sum(self.last_train_acc)/(5 * min(self.last_train_acc)) - 1)
229 | 
230 |             if (progress < 0.1) or (self.step > self.best_step + 30000):
231 |                 logger.Log("Best matched-dev accuracy: %s" %(self.best_dev_mat))
232 |                 logger.Log("MultiNLI Train accuracy: %s" %(self.best_mtrain_acc))
233 |                 self.completed = True
234 |                 break
235 | 
236 |     def classify(self, examples):
237 |         # This classifies a list of examples
238 |         if (test == True) or (self.completed == True):
239 |             best_path = os.path.join(FIXED_PARAMETERS["ckpt_path"], modname) + ".ckpt_best"
240 |             self.sess = tf.Session()
241 |             self.sess.run(self.init)
242 |             self.saver.restore(self.sess, best_path)
243 |             logger.Log("Model restored from file: %s" % best_path)
244 | 
245 |         total_batch = int(len(examples) / self.batch_size)
246 |         logits = np.empty(3)
247 |         genres = []
248 |         for i in range(total_batch):
249 |             minibatch_premise_vectors, minibatch_hypothesis_vectors, minibatch_labels, minibatch_genres = self.get_minibatch(examples,
250 |                 self.batch_size * i, self.batch_size * (i + 1))
251 |             feed_dict = {self.model.premise_x: minibatch_premise_vectors,
252 |                          self.model.hypothesis_x: minibatch_hypothesis_vectors,
253 |                          self.model.y: minibatch_labels,
254 |                          self.model.keep_rate_ph: 1.0}
255 |             genres += minibatch_genres
256 |             logit, cost = self.sess.run([self.model.logits, self.model.total_cost], feed_dict)
257 |             logits = np.vstack([logits, logit])
258 | 
259 |         return genres, np.argmax(logits[1:], axis=1), cost
260 | 
261 |     def restore(self, best=True):
262 |         if best:
263 |             path = os.path.join(FIXED_PARAMETERS["ckpt_path"], modname) + ".ckpt_best"
264 |         else:
265 |             path = os.path.join(FIXED_PARAMETERS["ckpt_path"], modname) + ".ckpt"
266 |         self.sess = tf.Session()
267 |         self.sess.run(self.init)
268 |         self.saver.restore(self.sess, path)
269 |         logger.Log("Model restored from file: %s" % path)
270 | 
271 |     def classify(self, examples):
272 |         # This classifies a list of examples; note that this later definition shadows the classify() defined above.
273 |         total_batch = int(len(examples) / self.batch_size)
274 |         logits = np.empty(3)
275 |         genres = []
276 |         for i in range(total_batch):
277 |             minibatch_premise_vectors, minibatch_hypothesis_vectors, minibatch_labels, minibatch_genres = self.get_minibatch(examples,
278 |                 self.batch_size * i, self.batch_size * (i + 1))
279 |             feed_dict = {self.model.premise_x: minibatch_premise_vectors,
280 |                          self.model.hypothesis_x: minibatch_hypothesis_vectors,
281 |                          self.model.y: minibatch_labels,
282 |                          self.model.keep_rate_ph: 1.0}
283 |             genres += minibatch_genres
284 |             logit, cost = self.sess.run([self.model.logits, self.model.total_cost], feed_dict)
285 |             logits = np.vstack([logits, logit])
286 | 
287 |         return genres, np.argmax(logits[1:], axis=1), cost
288 | 
289 | 
290 | 
291 | classifier = modelClassifier(FIXED_PARAMETERS["seq_length"])
292 | 
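The checkpointing logic above writes a separate `_best` checkpoint only when matched-dev accuracy beats the previous best by more than 0.04% in relative terms. A worked instance of that threshold, with illustrative accuracies:

```python
best_dev_mat = 0.700  # previous best matched-dev accuracy
dev_acc_mat = 0.704   # matched-dev accuracy at the current step

# Same expression as in train() above: relative improvement, in percent.
best_test = 100 * (1 - best_dev_mat / dev_acc_mat)
print(round(best_test, 3))  # 0.568 > 0.04, so a new "_best" checkpoint is saved
```

On the very first pass, `best_dev_mat` is 0, so the expression evaluates to 100 and a `_best` checkpoint is always written.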
293 | """
294 | Either train the model and then run it on the test-sets or
295 | load the best checkpoint and get accuracy on the test set. Default setting is to train the model.
296 | """
297 | 
298 | test = params.train_or_test()
299 | 
300 | # While test-set isn't released, use dev-sets for testing
301 | #test_matched = dev_matched
302 | #test_mismatched = dev_mismatched
303 | print("ALL RESULTS ON TEST")
304 | 
305 | if test == False:
306 |     classifier.train(training_mnli, training_snli, dev_matched, dev_mismatched, dev_snli)
307 |     logger.Log("Acc on matched multiNLI dev-set: %s"
308 |                % (evaluate_classifier(classifier.classify, test_matched, FIXED_PARAMETERS["batch_size"])[0]))
309 |     logger.Log("Acc on mismatched multiNLI dev-set: %s"
310 |                % (evaluate_classifier(classifier.classify, test_mismatched, FIXED_PARAMETERS["batch_size"])[0]))
311 |     logger.Log("Acc on SNLI test-set: %s"
312 |                % (evaluate_classifier(classifier.classify, test_snli, FIXED_PARAMETERS["batch_size"])[0]))
313 | else:
314 |     results, bylength = evaluate_final(classifier.restore, classifier.classify,
315 |                                        [test_matched, test_mismatched, test_snli], FIXED_PARAMETERS["batch_size"])
316 |     logger.Log("Acc on multiNLI matched dev-set: %s" %(results[0]))
317 |     logger.Log("Acc on multiNLI mismatched dev-set: %s" %(results[1]))
318 |     logger.Log("Acc on SNLI test set: %s" %(results[2]))
319 | 
320 |     #dumppath = os.path.join("./", modname) + "_length.p"
321 |     #pickle.dump(bylength, open(dumppath, "wb"))
322 | 
323 |     # Results by genre,
324 |     logger.Log("Acc on matched genre dev-sets: %s"
325 |                % (evaluate_classifier_genre(classifier.classify, test_matched, FIXED_PARAMETERS["batch_size"])[0]))
326 |     logger.Log("Acc on mismatched genres dev-sets: %s"
327 |                % (evaluate_classifier_genre(classifier.classify, test_mismatched, FIXED_PARAMETERS["batch_size"])[0]))
328 | 
--------------------------------------------------------------------------------
/python/train_snli.py:
--------------------------------------------------------------------------------
1 | """
2 | Training script to train a model on only SNLI data. MultiNLI data is also loaded and indexified with the same vocabulary so that the trained model can be tested on MultiNLI data.
3 | """
4 | 
5 | import tensorflow as tf
6 | import os
7 | import importlib
8 | import random
9 | from util import logger
10 | import util.parameters as params
11 | from util.data_processing import *
12 | from util.evaluate import *
13 | 
14 | FIXED_PARAMETERS = params.load_parameters()
15 | modname = FIXED_PARAMETERS["model_name"]
16 | logpath = os.path.join(FIXED_PARAMETERS["log_path"], modname) + ".log"
17 | logger = logger.Logger(logpath)
18 | 
19 | model = FIXED_PARAMETERS["model_type"]
20 | 
21 | module = importlib.import_module(".".join(['models', model]))
22 | MyModel = getattr(module, 'MyModel')
23 | 
24 | # Logging parameter settings at each launch of training script
25 | # This will help ensure nothing goes awry in reloading a model and we consistently use the same hyperparameter settings.
26 | logger.Log("FIXED_PARAMETERS\n %s" % FIXED_PARAMETERS) 27 | 28 | ######################### LOAD DATA ############################# 29 | 30 | logger.Log("Loading data") 31 | training_snli = load_nli_data(FIXED_PARAMETERS["training_snli"], snli=True) 32 | dev_snli = load_nli_data(FIXED_PARAMETERS["dev_snli"], snli=True) 33 | test_snli = load_nli_data(FIXED_PARAMETERS["test_snli"], snli=True) 34 | 35 | training_mnli = load_nli_data(FIXED_PARAMETERS["training_mnli"]) 36 | dev_matched = load_nli_data(FIXED_PARAMETERS["dev_matched"]) 37 | dev_mismatched = load_nli_data(FIXED_PARAMETERS["dev_mismatched"]) 38 | test_matched = load_nli_data(FIXED_PARAMETERS["test_matched"]) 39 | test_mismatched = load_nli_data(FIXED_PARAMETERS["test_mismatched"]) 40 | 41 | if 'temp.jsonl' in FIXED_PARAMETERS["test_matched"]: 42 | # Removing temporary empty file that was created in parameters.py 43 | os.remove(FIXED_PARAMETERS["test_matched"]) 44 | logger.Log("Created and removed empty file called temp.jsonl since test set is not available.") 45 | 46 | dictpath = os.path.join(FIXED_PARAMETERS["log_path"], modname) + ".p" 47 | 48 | if not os.path.isfile(dictpath): 49 | logger.Log("Building dictionary") 50 | word_indices = build_dictionary([training_snli]) 51 | logger.Log("Padding and indexifying sentences") 52 | sentences_to_padded_index_sequences(word_indices, [training_snli, training_mnli, dev_matched, dev_mismatched, dev_snli, test_snli, test_matched, test_mismatched]) 53 | pickle.dump(word_indices, open(dictpath, "wb")) 54 | 55 | else: 56 | logger.Log("Loading dictionary from %s" % (dictpath)) 57 | word_indices = pickle.load(open(dictpath, "rb")) 58 | logger.Log("Padding and indexifying sentences") 59 | sentences_to_padded_index_sequences(word_indices, [training_mnli, training_snli, dev_matched, dev_mismatched, dev_snli, test_snli, test_matched, test_mismatched]) 60 | 61 | logger.Log("Loading embeddings") 62 | loaded_embeddings = loadEmbedding_rand(FIXED_PARAMETERS["embedding_data_path"], word_indices) 63 | 64 | class modelClassifier: 65 | def __init__(self, seq_length): 66 | ## Define hyperparameters 67 | self.learning_rate = FIXED_PARAMETERS["learning_rate"] 68 | self.display_epoch_freq = 1 69 | self.display_step_freq = 50 70 | self.embedding_dim = FIXED_PARAMETERS["word_embedding_dim"] 71 | self.dim = FIXED_PARAMETERS["hidden_embedding_dim"] 72 | self.batch_size = FIXED_PARAMETERS["batch_size"] 73 | self.emb_train = FIXED_PARAMETERS["emb_train"] 74 | self.keep_rate = FIXED_PARAMETERS["keep_rate"] 75 | self.sequence_length = FIXED_PARAMETERS["seq_length"] 76 | self.alpha = FIXED_PARAMETERS["alpha"] 77 | 78 | logger.Log("Building model from %s.py" %(model)) 79 | self.model = MyModel(seq_length=self.sequence_length, emb_dim=self.embedding_dim, hidden_dim=self.dim, embeddings=loaded_embeddings, emb_train=self.emb_train) 80 | 81 | # Perform gradient descent with Adam 82 | self.optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.9, beta2=0.999).minimize(self.model.total_cost) 83 | 84 | # Boolean stating that training has not been completed, 85 | self.completed = False 86 | 87 | # tf things: initialize variables and create placeholder for session 88 | logger.Log("Initializing variables") 89 | self.init = tf.global_variables_initializer() 90 | self.sess = None 91 | self.saver = tf.train.Saver() 92 | 93 | 94 | def get_minibatch(self, dataset, start_index, end_index): 95 | indices = range(start_index, end_index) 96 | premise_vectors = np.vstack([dataset[i]['sentence1_binary_parse_index_sequence'] for i 
in indices]) 97 | hypothesis_vectors = np.vstack([dataset[i]['sentence2_binary_parse_index_sequence'] for i in indices]) 98 | genres = [dataset[i]['genre'] for i in indices] 99 | labels = [dataset[i]['label'] for i in indices] 100 | return premise_vectors, hypothesis_vectors, labels, genres 101 | 102 | 103 | def train(self, train_mnli, train_snli, dev_mat, dev_mismat, dev_snli): 104 | self.sess = tf.Session() 105 | self.sess.run(self.init) 106 | 107 | self.step = 1 108 | self.epoch = 0 109 | self.best_dev_snli = 0. 110 | self.best_strain_acc = 0. 111 | self.last_train_acc = [.001, .001, .001, .001, .001] 112 | self.best_step = 0 113 | 114 | # Restore most recent checkpoint if it exists. 115 | # Also restore values for best dev-set accuracy and best training-set accuracy. 116 | ckpt_file = os.path.join(FIXED_PARAMETERS["ckpt_path"], modname) + ".ckpt" 117 | if os.path.isfile(ckpt_file + ".meta"): 118 | if os.path.isfile(ckpt_file + "_best.meta"): 119 | self.saver.restore(self.sess, (ckpt_file + "_best")) 120 | best_dev_mat, dev_cost_mat = evaluate_classifier(self.classify, dev_mat, self.batch_size) 121 | best_dev_mismat, dev_cost_mismat = evaluate_classifier(self.classify, dev_mismat, self.batch_size) 122 | self.best_dev_snli, dev_cost_snli = evaluate_classifier(self.classify, dev_snli, self.batch_size) 123 | self.best_strain_acc, strain_cost = evaluate_classifier(self.classify, train_snli[0:5000], self.batch_size) 124 | logger.Log("Restored best matched-dev acc: %f\n Restored best mismatched-dev acc: %f\n Restored best SNLI-dev acc: %f\n Restored best SNLI train acc: %f" %(best_dev_mat, best_dev_mismat, self.best_dev_snli, self.best_strain_acc)) 125 | 126 | self.saver.restore(self.sess, ckpt_file) 127 | logger.Log("Model restored from file: %s" % ckpt_file) 128 | 129 | training_data = train_snli 130 | 131 | ### Training cycle 132 | logger.Log("Training...") 133 | 134 | while True: 135 | random.shuffle(training_data) 136 | avg_cost = 0. 
137 |             total_batch = int(len(training_data) / self.batch_size)
138 | 
139 |             # Loop over all batches in epoch
140 |             for i in range(total_batch):
141 |                 # Assemble a minibatch of the next B examples
142 |                 minibatch_premise_vectors, minibatch_hypothesis_vectors, minibatch_labels, minibatch_genres = self.get_minibatch(
143 |                     training_data, self.batch_size * i, self.batch_size * (i + 1))
144 | 
145 |                 # Run the optimizer to take a gradient step, and also fetch the value of the
146 |                 # cost function for logging
147 |                 feed_dict = {self.model.premise_x: minibatch_premise_vectors,
148 |                              self.model.hypothesis_x: minibatch_hypothesis_vectors,
149 |                              self.model.y: minibatch_labels,
150 |                              self.model.keep_rate_ph: self.keep_rate}
151 |                 _, c = self.sess.run([self.optimizer, self.model.total_cost], feed_dict)
152 | 
153 |                 # Since a single epoch can take ages for larger models (ESIM),
154 |                 # we'll print accuracy every 50 steps
155 |                 if self.step % self.display_step_freq == 0:
156 |                     dev_acc_mat, dev_cost_mat = evaluate_classifier(self.classify, dev_mat, self.batch_size)
157 |                     dev_acc_mismat, dev_cost_mismat = evaluate_classifier(self.classify, dev_mismat, self.batch_size)
158 |                     dev_acc_snli, dev_cost_snli = evaluate_classifier(self.classify, dev_snli, self.batch_size)
159 |                     strain_acc, strain_cost = evaluate_classifier(self.classify, train_snli[0:5000], self.batch_size)
160 | 
161 |                     logger.Log("Step: %i\t Dev-matched acc: %f\t Dev-mismatched acc: %f\t Dev-SNLI acc: %f\t SNLI train acc: %f" %(self.step, dev_acc_mat, dev_acc_mismat, dev_acc_snli, strain_acc))
162 |                     logger.Log("Step: %i\t Dev-matched cost: %f\t Dev-mismatched cost: %f\t Dev-SNLI cost: %f\t SNLI train cost: %f" %(self.step, dev_cost_mat, dev_cost_mismat, dev_cost_snli, strain_cost))
163 | 
164 |                 if self.step % 500 == 0:
165 |                     self.saver.save(self.sess, ckpt_file)
166 |                     best_test = 100 * (1 - self.best_dev_snli / dev_acc_snli)
167 |                     if best_test > 0.04:
168 |                         self.saver.save(self.sess, ckpt_file + "_best")
169 |                         self.best_dev_snli = dev_acc_snli
170 |                         self.best_strain_acc = strain_acc
171 |                         self.best_step = self.step
172 |                         logger.Log("Checkpointing with new best SNLI-dev accuracy: %f" %(self.best_dev_snli))
173 | 
174 |                 self.step += 1
175 | 
176 |                 # Compute average loss
177 |                 avg_cost += c / (total_batch * self.batch_size)
178 | 
179 |             # Display some statistics about the epoch
180 |             if self.epoch % self.display_epoch_freq == 0:
181 |                 logger.Log("Epoch: %i\t Avg. Cost: %f" %(self.epoch+1, avg_cost))
182 | 
183 |             self.epoch += 1
184 |             self.last_train_acc[(self.epoch % 5) - 1] = strain_acc
185 | 
186 |             # Early stopping
187 |             progress = 1000 * (sum(self.last_train_acc)/(5 * min(self.last_train_acc)) - 1)
188 | 
189 |             if (progress < 0.1) or (self.step > self.best_step + 30000):
190 |                 logger.Log("Best snli-dev accuracy: %s" %(self.best_dev_snli))
191 |                 logger.Log("SNLI train accuracy: %s" %(self.best_strain_acc))
192 |                 self.completed = True
193 |                 break
194 | 
195 |     def restore(self, best=True):
196 |         if best:
197 |             path = os.path.join(FIXED_PARAMETERS["ckpt_path"], modname) + ".ckpt_best"
198 |         else:
199 |             path = os.path.join(FIXED_PARAMETERS["ckpt_path"], modname) + ".ckpt"
200 |         self.sess = tf.Session()
201 |         self.sess.run(self.init)
202 |         self.saver.restore(self.sess, path)
203 |         logger.Log("Model restored from file: %s" % path)
204 | 
205 |     def classify(self, examples):
206 |         # This classifies a list of examples
207 |         total_batch = int(len(examples) / self.batch_size)
208 |         logits = np.empty(3)
209 |         genres = []
210 |         for i in range(total_batch):
211 |             minibatch_premise_vectors, minibatch_hypothesis_vectors, minibatch_labels, minibatch_genres = self.get_minibatch(
212 |                 examples, self.batch_size * i, self.batch_size * (i + 1))
213 |             feed_dict = {self.model.premise_x: minibatch_premise_vectors,
214 |                          self.model.hypothesis_x: minibatch_hypothesis_vectors,
215 |                          self.model.y: minibatch_labels,
216 |                          self.model.keep_rate_ph: 1.0}
217 |             genres += minibatch_genres
218 |             logit, cost = self.sess.run([self.model.logits, self.model.total_cost], feed_dict)
219 |             logits = np.vstack([logits, logit])
220 | 
221 |         return genres, np.argmax(logits[1:], axis=1), cost
222 | 
223 | 
224 | classifier = modelClassifier(FIXED_PARAMETERS["seq_length"])
225 | 
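A side note on the `classify` implementations in these scripts: `logits` is seeded with `np.empty(3)` purely to give `np.vstack` an initial row of the right width; that uninitialized row is then dropped via `logits[1:]` before the argmax. A small, self-contained demonstration:

```python
import numpy as np

logits = np.empty(3)                 # uninitialized 3-wide seed row
batch_logits = np.array([[0.1, 0.7, 0.2],
                         [0.9, 0.05, 0.05]])
logits = np.vstack([logits, batch_logits])

predictions = np.argmax(logits[1:], axis=1)  # skip the seed row
print(predictions)                           # [1 0]
```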
226 | """
227 | Either train the model and then run it on the test-sets or
228 | load the best checkpoint and get accuracy on the test set. Default setting is to train the model.
229 | """
230 | 
231 | test = params.train_or_test()
232 | 
233 | # While test-set isn't released, use dev-sets for testing
234 | test_matched = dev_matched
235 | test_mismatched = dev_mismatched
236 | 
237 | 
238 | if test == False:
239 |     classifier.train(training_mnli, training_snli, dev_matched, dev_mismatched, dev_snli)
240 |     logger.Log("Acc on matched multiNLI dev-set: %s" %(evaluate_classifier(classifier.classify, test_matched, FIXED_PARAMETERS["batch_size"])[0]))
241 |     logger.Log("Acc on mismatched multiNLI dev-set: %s" %(evaluate_classifier(classifier.classify, test_mismatched, FIXED_PARAMETERS["batch_size"])[0]))
242 |     logger.Log("Acc on SNLI test-set: %s" %(evaluate_classifier(classifier.classify, test_snli, FIXED_PARAMETERS["batch_size"])[0]))
243 | else:
244 |     results, _ = evaluate_final(classifier.restore, classifier.classify, [test_matched, test_mismatched, test_snli], FIXED_PARAMETERS["batch_size"])
245 |     logger.Log("Acc on multiNLI matched dev-set: %s" %(results[0]))
246 |     logger.Log("Acc on multiNLI mismatched dev-set: %s" %(results[1]))
247 |     logger.Log("Acc on SNLI test set: %s" %(results[2]))
248 | 
249 |     # Results by genre,
250 |     logger.Log("Acc on matched genre dev-sets: %s" %(evaluate_classifier_genre(classifier.classify, test_matched, FIXED_PARAMETERS["batch_size"])[0]))
251 |     logger.Log("Acc on mismatched genres dev-sets: %s" %(evaluate_classifier_genre(classifier.classify, test_mismatched, FIXED_PARAMETERS["batch_size"])[0]))
252 | 
253 | 
--------------------------------------------------------------------------------
/python/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nyu-mll/multiNLI/c4078cad9f9b5d06a672f2edb69bde007a3567a9/python/util/__init__.py
--------------------------------------------------------------------------------
/python/util/blocks.py:
--------------------------------------------------------------------------------
1 | """
2 | 
3 | Functions and components that can be slotted into tensorflow models.
4 | 
5 | TODO: Write functions for various types of attention.
6 | 
7 | """
8 | 
9 | import tensorflow as tf
10 | 
11 | 
12 | def length(sequence):
13 |     """
14 |     Get the true length of sequences (without padding), and a mask marking the true-length positions within the max length.
15 | 
16 |     Input of shape: (batch_size, max_seq_length)
17 |     Output shapes,
18 |         length: (batch_size)
19 |         mask: (batch_size, max_seq_length, 1)
20 |     """
21 |     populated = tf.sign(tf.abs(sequence))
22 |     length = tf.cast(tf.reduce_sum(populated, axis=1), tf.int32)
23 |     mask = tf.cast(tf.expand_dims(populated, -1), tf.float32)
24 |     return length, mask
25 | 
26 | 
27 | 
28 | def biLSTM(inputs, dim, seq_len, name):
29 |     """
30 |     A Bi-Directional LSTM layer. Returns forward and backward hidden states as a tuple, and cell states as a tuple.
31 | 
32 |     Output of hidden states: [(batch_size, max_seq_length, hidden_dim), (batch_size, max_seq_length, hidden_dim)]
33 |     Same shape for cell states.
34 |     """
35 |     with tf.name_scope(name):
36 |         with tf.variable_scope('forward' + name):
37 |             lstm_fwd = tf.contrib.rnn.LSTMCell(num_units=dim)
38 |         with tf.variable_scope('backward' + name):
39 |             lstm_bwd = tf.contrib.rnn.LSTMCell(num_units=dim)
40 | 
41 |         hidden_states, cell_states = tf.nn.bidirectional_dynamic_rnn(cell_fw=lstm_fwd, cell_bw=lstm_bwd, inputs=inputs, sequence_length=seq_len, dtype=tf.float32, scope=name)
42 | 
43 |     return hidden_states, cell_states
44 | 
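A hedged usage sketch of how `length` and `biLSTM` can combine to encode a padded batch (variable names here are illustrative, not taken from this repo). It mirrors the "average of the states" description of the BiLSTM baseline rather than reproducing `bilstm.py` exactly; it assumes `premise_ids` is a `(batch, max_len)` int32 tensor of token indices with 0 as the padding index, and `premise_emb` its `(batch, max_len, emb_dim)` embedding lookup:

```python
prem_len, prem_mask = length(premise_ids)
(fw, bw), _ = biLSTM(premise_emb, dim=300, seq_len=prem_len, name='premise')

hidden = tf.concat([fw, bw], axis=2)           # (batch, max_len, 2 * dim)
summed = tf.reduce_sum(hidden * prem_mask, 1)  # zero out padded steps, then sum
premise_rep = summed / tf.cast(tf.expand_dims(prem_len, -1), tf.float32)  # masked mean
```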
45 | 
46 | def LSTM(inputs, dim, seq_len, name):
47 |     """
48 |     An LSTM layer. Returns hidden states and cell states as a tuple.
49 | 
50 |     Output shape of hidden states: (batch_size, max_seq_length, hidden_dim)
51 |     Same shape for cell states.
52 |     """
53 |     with tf.name_scope(name):
54 |         cell = tf.contrib.rnn.LSTMCell(num_units=dim)
55 |         hidden_states, cell_states = tf.nn.dynamic_rnn(cell, inputs=inputs, sequence_length=seq_len, dtype=tf.float32, scope=name)
56 | 
57 |     return hidden_states, cell_states
58 | 
59 | 
60 | def last_output(output, true_length):
61 |     """
62 |     Get the last hidden state from a dynamically unrolled RNN.
63 |     Input of shape (batch_size, max_seq_length, hidden_dim).
64 | 
65 |     true_length: Tensor of shape (batch_size). Such a tensor is given by the length() function.
66 |     Output of shape (batch_size, hidden_dim).
67 |     """
68 |     max_length = int(output.get_shape()[1])
69 |     length_mask = tf.expand_dims(tf.one_hot(true_length-1, max_length, on_value=1., off_value=0.), -1)
70 |     last_output = tf.reduce_sum(tf.multiply(output, length_mask), 1)
71 |     return last_output
72 | 
73 | 
74 | def masked_softmax(scores, mask):
75 |     """
76 |     Used to calculate a softmax score using the true sequence length (without padding), rather than the max sequence length.
77 | 
78 |     Input shape: (batch_size, max_seq_length, hidden_dim).
79 |     mask parameter: Tensor of shape (batch_size, max_seq_length). Such a mask is given by the length() function.
80 |     """
81 |     numerator = tf.exp(tf.subtract(scores, tf.reduce_max(scores, 1, keep_dims=True))) * mask
82 |     denominator = tf.reduce_sum(numerator, 1, keep_dims=True)
83 |     weights = tf.div(numerator, denominator)
84 |     return weights
85 | 
--------------------------------------------------------------------------------
/python/util/data_processing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import re
3 | import random
4 | import json
5 | import collections
6 | import parameters as params
7 | import pickle
8 | 
9 | FIXED_PARAMETERS = params.load_parameters()
10 | 
11 | LABEL_MAP = {
12 |     "entailment": 0,
13 |     "neutral": 1,
14 |     "contradiction": 2,
15 |     "hidden": 0
16 | }
17 | 
18 | PADDING = "<PAD>"
19 | UNKNOWN = "<UNK>"
20 | 
21 | def load_nli_data(path, snli=False):
22 |     """
23 |     Load MultiNLI or SNLI data.
24 |     If the "snli" parameter is set to True, a genre label of snli will be assigned to the data.
25 |     """
26 |     data = []
27 |     with open(path) as f:
28 |         for line in f:
29 |             loaded_example = json.loads(line)
30 |             if loaded_example["gold_label"] not in LABEL_MAP:
31 |                 continue
32 |             loaded_example["label"] = LABEL_MAP[loaded_example["gold_label"]]
33 |             if snli:
34 |                 loaded_example["genre"] = "snli"
35 |             data.append(loaded_example)
36 |         random.seed(1)
37 |         random.shuffle(data)
38 |     return data
39 | 
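For orientation, `load_nli_data` above keeps each JSON line whole and just adds a numeric `label` (and, for SNLI, a `genre`). A toy illustration of that mapping, reusing `LABEL_MAP`; the record is abridged and illustrative, since real MultiNLI/SNLI lines carry more fields (pairID, annotator_labels, plain sentences, full parses, and so on):

```python
import json

line = json.dumps({
    "gold_label": "entailment",
    "sentence1_binary_parse": "( ( A dog ) ( is running ) )",
    "sentence2_binary_parse": "( ( An animal ) ( is moving ) )",
    "genre": "fiction",
})

example = json.loads(line)
example["label"] = LABEL_MAP[example["gold_label"]]
print(example["label"], example["genre"])  # 0 fiction
```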
40 | def load_nli_data_genre(path, genre, snli=True):
41 |     """
42 |     Load a specific genre's examples from MultiNLI, or load SNLI data and assign a "snli" genre to the examples.
43 |     If the "snli" parameter is set to True, a genre label of "snli" is assigned to every example, overwriting the genre label for MultiNLI data.
44 |     """
45 |     data = []
46 |     j = 0
47 |     with open(path) as f:
48 |         for line in f:
49 |             loaded_example = json.loads(line)
50 |             if loaded_example["gold_label"] not in LABEL_MAP:
51 |                 continue
52 |             loaded_example["label"] = LABEL_MAP[loaded_example["gold_label"]]
53 |             if snli:
54 |                 loaded_example["genre"] = "snli"
55 |             if loaded_example["genre"] == genre:
56 |                 data.append(loaded_example)
57 |         random.seed(1)
58 |         random.shuffle(data)
59 |     return data
60 | 
61 | def tokenize(string):
62 |     string = re.sub(r'\(|\)', '', string)
63 |     return string.split()
64 | 
65 | def build_dictionary(training_datasets):
66 |     """
67 |     Extract vocabulary and build dictionary.
68 |     """
69 |     word_counter = collections.Counter()
70 |     for i, dataset in enumerate(training_datasets):
71 |         for example in dataset:
72 |             word_counter.update(tokenize(example['sentence1_binary_parse']))
73 |             word_counter.update(tokenize(example['sentence2_binary_parse']))
74 | 
75 |     vocabulary = set([word for word in word_counter])
76 |     vocabulary = list(vocabulary)
77 |     vocabulary = [PADDING, UNKNOWN] + vocabulary
78 | 
79 |     word_indices = dict(zip(vocabulary, range(len(vocabulary))))
80 | 
81 |     return word_indices
82 | 
83 | def sentences_to_padded_index_sequences(word_indices, datasets):
84 |     """
85 |     Annotate datasets with feature vectors, adding right-sided padding.
86 |     """
87 |     for i, dataset in enumerate(datasets):
88 |         for example in dataset:
89 |             for sentence in ['sentence1_binary_parse', 'sentence2_binary_parse']:
90 |                 example[sentence + '_index_sequence'] = np.zeros((FIXED_PARAMETERS["seq_length"]), dtype=np.int32)
91 | 
92 |                 token_sequence = tokenize(example[sentence])
93 |                 padding = FIXED_PARAMETERS["seq_length"] - len(token_sequence)
94 | 
95 |                 for i in range(FIXED_PARAMETERS["seq_length"]):
96 |                     if i >= len(token_sequence):
97 |                         index = word_indices[PADDING]
98 |                     else:
99 |                         if token_sequence[i] in word_indices:
100 |                             index = word_indices[token_sequence[i]]
101 |                         else:
102 |                             index = word_indices[UNKNOWN]
103 |                     example[sentence + '_index_sequence'][i] = index
104 | 
105 | 
106 | def loadEmbedding_zeros(path, word_indices):
107 |     """
108 |     Load GloVe embeddings. Initializing OOV words to vectors of zeros.
109 |     """
110 |     emb = np.zeros((len(word_indices), FIXED_PARAMETERS["word_embedding_dim"]), dtype='float32')
111 | 
112 |     with open(path, 'r') as f:
113 |         for i, line in enumerate(f):
114 |             if FIXED_PARAMETERS["embeddings_to_load"] != None:
115 |                 if i >= FIXED_PARAMETERS["embeddings_to_load"]:
116 |                     break
117 | 
118 |             s = line.split()
119 |             if s[0] in word_indices:
120 |                 emb[word_indices[s[0]], :] = np.asarray(s[1:])
121 | 
122 |     return emb
123 | 
124 | 
125 | def loadEmbedding_rand(path, word_indices):
126 |     """
127 |     Load GloVe embeddings. Doing a random normal initialization for OOV words.
128 |     """
129 |     n = len(word_indices)
130 |     m = FIXED_PARAMETERS["word_embedding_dim"]
131 |     emb = np.empty((n, m), dtype=np.float32)
132 | 
133 |     emb[:,:] = np.random.normal(size=(n,m))
134 | 
135 |     # Explicitly assign embeddings of <PAD> and <UNK> to be zeros.
136 | emb[0:2, :] = np.zeros((1,m), dtype="float32") 137 | 138 | with open(path, 'r') as f: 139 | for i, line in enumerate(f): 140 | if FIXED_PARAMETERS["embeddings_to_load"] != None: 141 | if i >= FIXED_PARAMETERS["embeddings_to_load"]: 142 | break 143 | 144 | s = line.split() 145 | if s[0] in word_indices: 146 | emb[word_indices[s[0]], :] = np.asarray(s[1:]) 147 | 148 | return emb 149 | 150 | -------------------------------------------------------------------------------- /python/util/evaluate.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import sys 3 | 4 | def evaluate_classifier(classifier, eval_set, batch_size): 5 | """ 6 | Function to get accuracy and cost of the model, evaluated on a chosen dataset. 7 | 8 | classifier: the model's classfier, it should return genres, logit values, and cost for a given minibatch of the evaluation dataset 9 | eval_set: the chosen evaluation set, for eg. the dev-set 10 | batch_size: the size of minibatches. 11 | """ 12 | correct = 0 13 | genres, hypotheses, cost = classifier(eval_set) 14 | cost = cost / batch_size 15 | full_batch = int(len(eval_set) / batch_size) * batch_size 16 | for i in range(full_batch): 17 | hypothesis = hypotheses[i] 18 | if hypothesis == eval_set[i]['label']: 19 | correct += 1 20 | return correct / float(len(eval_set)), cost 21 | 22 | def evaluate_classifier_genre(classifier, eval_set, batch_size): 23 | """ 24 | Function to get accuracy and cost of the model by genre, evaluated on a chosen dataset. It returns a dictionary of accuracies by genre and cost for the full evaluation dataset. 25 | 26 | classifier: the model's classfier, it should return genres, logit values, and cost for a given minibatch of the evaluation dataset 27 | eval_set: the chosen evaluation set, for eg. the dev-set 28 | batch_size: the size of minibatches. 29 | """ 30 | genres, hypotheses, cost = classifier(eval_set) 31 | correct = dict((genre,0) for genre in set(genres)) 32 | count = dict((genre,0) for genre in set(genres)) 33 | cost = cost / batch_size 34 | full_batch = int(len(eval_set) / batch_size) * batch_size 35 | 36 | for i in range(full_batch): 37 | hypothesis = hypotheses[i] 38 | genre = genres[i] 39 | if hypothesis == eval_set[i]['label']: 40 | correct[genre] += 1. 41 | count[genre] += 1. 42 | 43 | if genre != eval_set[i]['genre']: 44 | print 'welp!' 45 | 46 | accuracy = {k: correct[k]/count[k] for k in correct} 47 | 48 | return accuracy, cost 49 | 50 | def evaluate_classifier_bylength(classifier, eval_set, batch_size): 51 | """ 52 | Function to get accuracy and cost of the model by genre, evaluated on a chosen dataset. It returns a dictionary of accuracies by genre and cost for the full evaluation dataset. 53 | 54 | classifier: the model's classfier, it should return genres, logit values, and cost for a given minibatch of the evaluation dataset 55 | eval_set: the chosen evaluation set, for eg. the dev-set 56 | batch_size: the size of minibatches. 57 | """ 58 | genres, hypotheses, cost = classifier(eval_set) 59 | correct = dict((genre,0) for genre in set(genres)) 60 | count = dict((genre,0) for genre in set(genres)) 61 | cost = cost / batch_size 62 | full_batch = int(len(eval_set) / batch_size) * batch_size 63 | 64 | for i in range(full_batch): 65 | hypothesis = hypotheses[i] 66 | genre = genres[i] 67 | if hypothesis == eval_set[i]['label']: 68 | correct[genre] += 1. 69 | count[genre] += 1. 70 | 71 | if genre != eval_set[i]['genre']: 72 | print 'welp!' 
73 | 74 | accuracy = {k: correct[k]/count[k] for k in correct} 75 | 76 | return accuracy, cost 77 | 78 | def evaluate_final(restore, classifier, eval_sets, batch_size): 79 | """ 80 | Function to get percentage accuracy of the model, evaluated on a set of chosen datasets. 81 | 82 | restore: a function to restore a stored checkpoint 83 | classifier: the model's classfier, it should return genres, logit values, and cost for a given minibatch of the evaluation dataset 84 | eval_set: the chosen evaluation set, for eg. the dev-set 85 | batch_size: the size of minibatches. 86 | """ 87 | restore(best=True) 88 | percentages = [] 89 | length_results = [] 90 | for eval_set in eval_sets: 91 | bylength_prem = {} 92 | bylength_hyp = {} 93 | genres, hypotheses, cost = classifier(eval_set) 94 | correct = 0 95 | cost = cost / batch_size 96 | full_batch = int(len(eval_set) / batch_size) * batch_size 97 | 98 | for i in range(full_batch): 99 | hypothesis = hypotheses[i] 100 | 101 | length_1 = len(eval_set[i]['sentence1'].split()) 102 | length_2 = len(eval_set[i]['sentence2'].split()) 103 | if length_1 not in bylength_prem.keys(): 104 | bylength_prem[length_1] = [0,0] 105 | if length_2 not in bylength_hyp.keys(): 106 | bylength_hyp[length_2] = [0,0] 107 | 108 | bylength_prem[length_1][1] += 1 109 | bylength_hyp[length_2][1] += 1 110 | 111 | if hypothesis == eval_set[i]['label']: 112 | correct += 1 113 | bylength_prem[length_1][0] += 1 114 | bylength_hyp[length_2][0] += 1 115 | percentages.append(correct / float(len(eval_set))) 116 | length_results.append((bylength_prem, bylength_hyp)) 117 | return percentages, length_results 118 | 119 | 120 | def predictions_kaggle(classifier, eval_set, batch_size, name): 121 | """ 122 | Get comma-separated CSV of predictions. 123 | Output file has two columns: pairID, prediction 124 | """ 125 | INVERSE_MAP = { 126 | 0: "entailment", 127 | 1: "neutral", 128 | 2: "contradiction" 129 | } 130 | 131 | hypotheses = classifier(eval_set) 132 | predictions = [] 133 | 134 | for i in range(len(eval_set)): 135 | hypothesis = hypotheses[i] 136 | prediction = INVERSE_MAP[hypothesis] 137 | pairID = eval_set[i]["pairID"] 138 | predictions.append((pairID, prediction)) 139 | 140 | #predictions = sorted(predictions, key=lambda x: int(x[0])) 141 | 142 | f = open( name + '_predictions.csv', 'wb') 143 | w = csv.writer(f, delimiter = ',') 144 | w.writerow(['pairID','gold_label']) 145 | for example in predictions: 146 | w.writerow(example) 147 | f.close() 148 | -------------------------------------------------------------------------------- /python/util/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import sys 3 | import json 4 | 5 | class Logger(object): 6 | """ 7 | A logging that doesn't leave logs open between writes, so as to allow AFS synchronization. 8 | """ 9 | 10 | # Level constants 11 | DEBUG = 0 12 | INFO = 1 13 | WARNING = 2 14 | ERROR = 3 15 | 16 | def __init__(self, log_path=None, json_log_path=None, min_print_level=0, min_file_level=0): 17 | """ 18 | log_path: The full path for the log file to write. The file will be appended to if it exists. 19 | min_print_level: Only messages with level above this level will be printed to stderr. 20 | min_file_level: Only messages with level above this level will be written to disk. 
21 | """ 22 | self.log_path = log_path 23 | self.json_log_path = json_log_path 24 | self.min_print_level = min_print_level 25 | self.min_file_level = min_file_level 26 | 27 | def Log(self, message, level=INFO): 28 | if level >= self.min_print_level: 29 | # Write to STDERR 30 | sys.stderr.write("[%i] %s\n" % (level, message)) 31 | if self.log_path and level >= self.min_file_level: 32 | # Write to the log file then close it 33 | with open(self.log_path, 'a') as f: 34 | datetime_string = datetime.datetime.now().strftime( 35 | "%y-%m-%d %H:%M:%S") 36 | f.write("%s [%i] %s\n" % (datetime_string, level, message)) 37 | 38 | def LogJSON(self, message_obj, level=INFO): 39 | if self.json_log_path and level >= self.min_file_level: 40 | with open(self.json_log_path, 'w') as f: 41 | print >>f, json.dumps(message_obj) 42 | else: 43 | sys.stderr.write('WARNING: No JSON log filename.') 44 | 45 | -------------------------------------------------------------------------------- /python/util/parameters.py: -------------------------------------------------------------------------------- 1 | """ 2 | The hyperparameters for a model are defined here. Arguments like the type of model, model name, paths to data, logs etc. are also defined here. 3 | All paramters and arguments can be changed by calling flags in the command line. 4 | 5 | Required arguements are, 6 | model_type: which model you wish to train with. Valid model types: cbow, bilstm, and esim. 7 | model_name: the name assigned to the model being trained, this will prefix the name of the logs and checkpoint files. 8 | """ 9 | 10 | import argparse 11 | import io 12 | import os 13 | import json 14 | 15 | parser = argparse.ArgumentParser() 16 | 17 | models = ['esim','cbow', 'bilstm', 'lstm'] 18 | def types(s): 19 | options = [mod for mod in models if s in models] 20 | if len(options) == 1: 21 | return options[0] 22 | return s 23 | 24 | # Valid genres to train on. 25 | genres = ['travel', 'fiction', 'slate', 'telephone', 'government'] 26 | def subtypes(s): 27 | options = [mod for mod in genres if s in genres] 28 | if len(options) == 1: 29 | return options[0] 30 | return s 31 | 32 | parser.add_argument("model_type", choices=models, type=types, help="Give model type.") 33 | parser.add_argument("model_name", type=str, help="Give model name, this will name logs and checkpoints made. For example cbow, esim_test etc.") 34 | 35 | parser.add_argument("--datapath", type=str, default="../data") 36 | parser.add_argument("--ckptpath", type=str, default="../logs") 37 | parser.add_argument("--logpath", type=str, default="../logs") 38 | 39 | parser.add_argument("--emb_to_load", type=int, default=None, help="Number of embeddings to load. 
If None, all embeddings are loaded.") 40 | parser.add_argument("--learning_rate", type=float, default=0.0004, help="Learning rate for model") 41 | parser.add_argument("--keep_rate", type=float, default=0.5, help="Keep rate for dropout in the model") 42 | parser.add_argument("--seq_length", type=int, default=50, help="Max sequence length") 43 | parser.add_argument("--emb_train", action='store_true', help="Call if you want to make your word embeddings trainable.") 44 | 45 | parser.add_argument("--genre", type=str, help="Which genre to train on") 46 | parser.add_argument("--alpha", type=float, default=0., help="What percentage of SNLI data to use in training") 47 | 48 | parser.add_argument("--test", action='store_true', help="Call if you want to only test on the best checkpoint.") 49 | 50 | args = parser.parse_args() 51 | 52 | """ 53 | # Check if test sets are available. If not, create an empty file. 54 | test_matched = "{}/multinli_0.9/multinli_0.9_test_matched_unlabeled.jsonl".format(args.datapath) 55 | 56 | if os.path.isfile(test_matched): 57 | test_matched = "{}/multinli_0.9/multinli_0.9_test_matched_unlabeled.jsonl".format(args.datapath) 58 | test_mismatched = "{}/multinli_0.9/multinli_0.9_test_matched_unlabeled.jsonl".format(args.datapath) 59 | test_path = "{}/multinli_0.9/".format(args.datapath) 60 | else: 61 | test_path = "{}/multinli_0.9/".format(args.datapath) 62 | temp_file = os.path.join(test_path, "temp.jsonl") 63 | io.open(temp_file, "wb") 64 | test_matched = temp_file 65 | test_mismatched = temp_file 66 | """ 67 | # Check if test sets are available. If not, create an empty file. 68 | test_matched = "{}/multinli_0.9/multinli_0.9_test_matched.jsonl".format(args.datapath) 69 | 70 | if os.path.isfile(test_matched): 71 | test_matched = "{}/multinli_0.9/multinli_0.9_dev_matched.jsonl".format(args.datapath) #"{}/multinli_0.9/multinli_0.9_test_matched.jsonl".format(args.datapath) 72 | test_mismatched = "{}/multinli_0.9/multinli_0.9_dev_mismatched.jsonl".format(args.datapath) #"{}/multinli_0.9/multinli_0.9_test_mismatched.jsonl".format(args.datapath) 73 | test_path = "{}".format(args.datapath) 74 | else: 75 | test_path = "{}".format(args.datapath) 76 | temp_file = os.path.join(test_path, "temp.jsonl") 77 | io.open(temp_file, "wb") 78 | test_matched = temp_file 79 | test_mismatched = temp_file 80 | 81 | 82 | def load_parameters(): 83 | FIXED_PARAMETERS = { 84 | "model_type": args.model_type, 85 | "model_name": args.model_name, 86 | "training_mnli": "{}/multinli_0.9/multinli_0.9_train.jsonl".format(args.datapath), 87 | "dev_matched": "{}/multinli_0.9/multinli_0.9_dev_matched.jsonl".format(args.datapath), 88 | "dev_mismatched": "{}/multinli_0.9/multinli_0.9_dev_mismatched.jsonl".format(args.datapath), 89 | "test_matched": test_matched, 90 | "test_mismatched": test_mismatched, 91 | "training_snli": "{}/snli_1.0/snli_1.0_train.jsonl".format(args.datapath), 92 | "dev_snli": "{}/snli_1.0/snli_1.0_dev.jsonl".format(args.datapath), 93 | "test_snli": "{}/snli_1.0/snli_1.0_test.jsonl".format(args.datapath), 94 | "embedding_data_path": "{}/glove.840B.300d.txt".format(args.datapath), 95 | #"embedding_data_path": "{}/glove.6B.50d.txt".format(args.datapath), 96 | "log_path": "{}".format(args.logpath), 97 | "ckpt_path": "{}".format(args.ckptpath), 98 | "embeddings_to_load": args.emb_to_load, 99 | "word_embedding_dim": 300, 100 | "hidden_embedding_dim": 300, 101 | #"word_embedding_dim": 50, 102 | #"hidden_embedding_dim": 50, 103 | "seq_length": args.seq_length, 104 | "keep_rate": args.keep_rate, 105 | 
"batch_size": 32, 106 | "learning_rate": args.learning_rate, 107 | "emb_train": args.emb_train, 108 | "alpha": args.alpha, 109 | "genre": args.genre 110 | } 111 | 112 | return FIXED_PARAMETERS 113 | 114 | def train_or_test(): 115 | return args.test 116 | 117 | --------------------------------------------------------------------------------