├── .gitignore ├── LICENSE ├── README.md ├── config ├── preprocessing │ ├── bnli_preprocessing.json │ ├── mnli_preprocessing.json │ └── snli_preprocessing.json ├── testing │ └── mnli_testing.json └── training │ ├── mnli_training.json │ └── snli_training.json ├── data ├── checkpoints │ └── SNLI │ │ └── best.pth.tar └── dataset │ └── bnli.jsonl ├── esim.png ├── esim ├── __init__.py ├── data.py ├── layers.py ├── model.py └── utils.py ├── scripts ├── fetch_data.py ├── preprocessing │ ├── preprocess_bnli.py │ ├── preprocess_mnli.py │ └── preprocess_snli.py ├── testing │ ├── test_mnli.py │ └── test_snli.py └── training │ ├── train_mnli.py │ ├── train_snli.py │ └── utils.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # VSCode 107 | .vscode 108 | 109 | # Data 110 | data/dataset/snli_1.0 111 | data/dataset/mnli_1.0 112 | data/embeddings 113 | data/preprocessed 114 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ESIM - Enhanced Sequential Inference Model 2 | Implementation of the ESIM model for natural language inference with PyTorch 3 | 4 | This repository contains an implementation with PyTorch of the sequential model presented in the paper 5 | ["Enhanced LSTM for Natural Language Inference"](https://arxiv.org/pdf/1609.06038.pdf) by Chen et al. in 2016. 6 | 7 | The figure below illustrates a high-level view of the model's architecture. 8 | 9 | ![alt-text](./esim.png) 10 | 11 | This model was implemented in the context of [my master's thesis](https://coetaur0.github.io/projects/lean/) at the University of Geneva. 12 | 13 | ## How to 14 | ### Install the package 15 | To use the model defined in this repository, you will first need to install PyTorch on your machine by following the steps 16 | described on the package's [official page](https://pytorch.org/get-started/locally/) (this step is only necessary if you use 17 | Windows). 18 | Then, to install the dependencies necessary to run the model, simply execute the command `pip install --upgrade .` from within 19 | the cloned repository (at the root, and preferably inside of a [virtual environment](https://docs.python.org/3/tutorial/venv.html)). 20 | 21 | ### Fetch the data to train and test the model 22 | The *fetch_data.py* script located in the *scripts/* folder of this repository can be used to download some NLI dataset and 23 | pretrained word embeddings. By default, the script fetches the [SNLI](https://nlp.stanford.edu/projects/snli/) corpus and 24 | the [GloVe 840B 300d](https://nlp.stanford.edu/projects/glove/) embeddings. Other datasets can be downloaded by simply passing 25 | their URL as an argument to the script (for example, the [MultiNLI dataset](https://www.nyu.edu/projects/bowman/multinli/)). 26 | 27 | The script's usage is the following: 28 | ``` 29 | fetch_data.py [-h] [--dataset_url DATASET_URL] 30 | [--embeddings_url EMBEDDINGS_URL] 31 | [--target_dir TARGET_DIR] 32 | ``` 33 | where `target_dir` is the path to a directory where the downloaded data must be saved (defaults to *../data/*). 34 | 35 | For MultiNLI, the matched and mismatched test sets need to be manually downloaded from Kaggle and the corresponding .txt files 36 | copied in the *multinli_1.0* dataset folder. 37 | 38 | ### Preprocess the data 39 | Before the downloaded corpus and embeddings can be used in the ESIM model, they need to be preprocessed. This can be done with 40 | the *preprocess_\*.py* scripts in the *scripts/preprocessing* folder of this repository. The *preprocess_snli.py* script can be 41 | used to preprocess SNLI, *preprocess_mnli.py* to preprocess MultiNLI, and *preprocess_bnli.py* to preprocess the Breaking NLI 42 | (BNLI) dataset. Note that when calling the script for BNLI, the SNLI data should have been preprocessed first, so that the
44 | 45 | The scripts' usage is the following (replace the \* with *snli*, *mnli* or *bnli*): 46 | ``` 47 | preprocess_*.py [-h] [--config CONFIG] 48 | ``` 49 | where `config` is the path to a configuration file defining the parameters to be used for preprocessing. Default 50 | configuration files can be found in the *config/preprocessing* folder of this repository. 51 | 52 | ### Train the model 53 | The *train_\*.py* scripts in the *scripts/training* folder can be used to train the ESIM model on some training data and 54 | validate it on some validation data. 55 | 56 | The script's usage is the following (replace the \* with *snli* or *mnli*): 57 | ``` 58 | train_*.py [-h] [--config CONFIG] [--checkpoint CHECKPOINT] 59 | ``` 60 | where `config` is a configuration file (default ones are located in the *config/training* folder), and `checkpoint` is an 61 | optional checkpoint from which training can be resumed. Checkpoints are created by the script after each training epoch, with 62 | the name *esim_\*.pth.tar*, where '\*' indicates the epoch's number. 63 | 64 | ### Test the model 65 | The *test_\*.py* scripts in the *scripts/testing* folder can be used to test a pretrained ESIM model on some test data. 66 | 67 | To test on SNLI, use the *test_snli.py* script as follows: 68 | ``` 69 | test_snli.py [-h] test_data checkpoint 70 | ``` 71 | where `test_data` is the path to some preprocessed test set, and `checkpoint` is the path to a checkpoint produced by the 72 | *train_snli.py* script (either one of the checkpoints created after the training epochs, or the best model seen during 73 | training, which is saved in *data/checkpoints/SNLI/best.pth.tar* - the difference between the *esim_\*.pth.tar* files and 74 | *best.pth.tar* is that the latter cannot be used to resume training, as it doesn't contain the optimizer's state). 75 | 76 | The *test_snli.py* script can also be used on the Breaking NLI dataset with a model pretrained on SNLI. 77 | 78 | To test on MultiNLI, use the *test_mnli.py* script as follows: 79 | ``` 80 | test_mnli.py [-h] [--config CONFIG] checkpoint 81 | ``` 82 | where `config` is a configuration file (a default one is available in *config/testing*) and `checkpoint` is a checkpoint 83 | produced by the *train_mnli.py* script. 84 | 85 | The *test_mnli.py* script makes predictions on MultiNLI's matched and mismatched test sets and saves them in .csv files. 86 | To get the classification accuracy associated to the model's predictions, the .csv files it produces need to be submitted 87 | to the Kaggle competitions for MultiNLI. 88 | 89 | ## Results 90 | A model pre-trained on SNLI is made available in the *data/checkpoints/SNLI* folder of this repository. The model was trained 91 | with the parameters defined in the default configuration files provided in *config/*. 92 | To test it, simply execute `python test_snli.py ../../preprocessed/SNLI/test_data.pkl ../../data/checkpoints/best.pth.tar` 93 | from within the *scripts/testing* folder. 94 | 95 | The pretrained model achieves the following performance on the SNLI dataset: 96 | 97 | | Split | Accuracy (%) | 98 | |-------|--------------| 99 | | Train | 93.2 | 100 | | Dev | 88.4 | 101 | | Test | 88.0 | 102 | 103 | The results are in line with those presented in the paper by Chen et al. 104 | 105 | On the [Breaking NLI](https://github.com/BIU-NLP/Breaking_NLI) dataset, published by [Glockner et al. in 2018](https://arxiv.org/pdf/1805.02266.pdf), the model reaches **65.5%** accuracy, as reported in the paper. 
106 | 107 | On MultiNLI, the model reaches the following accuracy: 108 | 109 | | Split | Matched | Mismatched | 110 | |-------|---------|------------| 111 | | Dev | 77.0 % | 76.8 % | 112 | | Test | 76.6 % | 75.8 % | 113 | 114 | These results are slightly above what was reported by Williams et al. in their MultiNLI paper. 115 | -------------------------------------------------------------------------------- /config/preprocessing/bnli_preprocessing.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_file": "../../data/dataset/bnli.jsonl", 3 | "target_dir": "../../data/preprocessed/BNLI", 4 | "worddict": "../../data/preprocessed/SNLI/worddict.pkl", 5 | "labeldict": {"entailment": 0, 6 | "neutral": 1, 7 | "contradiction": 2} 8 | } -------------------------------------------------------------------------------- /config/preprocessing/mnli_preprocessing.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_dir": "../../data/dataset/multinli_1.0", 3 | "embeddings_file": "../../data/embeddings/glove.840B.300d.txt", 4 | "target_dir": "../../data/preprocessed/MNLI", 5 | 6 | "lowercase": false, 7 | "ignore_punctuation": false, 8 | "num_words": null, 9 | "stopwords": [], 10 | "labeldict": {"entailment": 0, 11 | "neutral": 1, 12 | "contradiction": 2}, 13 | "bos": "_BOS_", 14 | "eos": "_EOS_" 15 | } -------------------------------------------------------------------------------- /config/preprocessing/snli_preprocessing.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_dir": "../../data/dataset/snli_1.0", 3 | "embeddings_file": "../../data/embeddings/glove.840B.300d.txt", 4 | "target_dir": "../../data/preprocessed/SNLI", 5 | 6 | "lowercase": false, 7 | "ignore_punctuation": false, 8 | "num_words": null, 9 | "stopwords": [], 10 | "labeldict": {"entailment": 0, 11 | "neutral": 1, 12 | "contradiction": 2}, 13 | "bos": "_BOS_", 14 | "eos": "_EOS_" 15 | } -------------------------------------------------------------------------------- /config/testing/mnli_testing.json: -------------------------------------------------------------------------------- 1 | { 2 | "test_files": {"matched": "../../data/preprocessed/MNLI/matched_test_data.pkl", 3 | "mismatched": "../../data/preprocessed/MNLI/mismatched_test_data.pkl"}, 4 | "labeldict": {"entailment": 0, 5 | "neutral": 1, 6 | "contradiction": 2}, 7 | "output_dir": "../../data/results", 8 | "batch_size": 32 9 | } -------------------------------------------------------------------------------- /config/training/mnli_training.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_data": "../../data/preprocessed/MNLI/train_data.pkl", 3 | "valid_data": {"matched": "../../data/preprocessed/MNLI/matched_dev_data.pkl", 4 | "mismatched": "../../data/preprocessed/MNLI/mismatched_dev_data.pkl"}, 5 | "embeddings": "../../data/preprocessed/MNLI/embeddings.pkl", 6 | 7 | "target_dir": "../../data/checkpoints/MNLI", 8 | 9 | "hidden_size": 300, 10 | "dropout": 0.5, 11 | "num_classes": 3, 12 | 13 | "epochs": 64, 14 | "batch_size": 32, 15 | "lr": 0.0004, 16 | "patience": 5, 17 | "max_gradient_norm": 10.0 18 | } -------------------------------------------------------------------------------- /config/training/snli_training.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_data": "../../data/preprocessed/SNLI/train_data.pkl", 3 | 
"valid_data": "../../data/preprocessed/SNLI/dev_data.pkl", 4 | "embeddings": "../../data/preprocessed/SNLI/embeddings.pkl", 5 | 6 | "target_dir": "../../data/checkpoints/SNLI", 7 | 8 | "hidden_size": 300, 9 | "dropout": 0.5, 10 | "num_classes": 3, 11 | 12 | "epochs": 64, 13 | "batch_size": 32, 14 | "lr": 0.0004, 15 | "patience": 5, 16 | "max_gradient_norm": 10.0 17 | } -------------------------------------------------------------------------------- /data/checkpoints/SNLI/best.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coetaur0/ESIM/65611601ff9f17f76e1f246e8e46b5fc4bee13fc/data/checkpoints/SNLI/best.pth.tar -------------------------------------------------------------------------------- /esim.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coetaur0/ESIM/65611601ff9f17f76e1f246e8e46b5fc4bee13fc/esim.png -------------------------------------------------------------------------------- /esim/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coetaur0/ESIM/65611601ff9f17f76e1f246e8e46b5fc4bee13fc/esim/__init__.py -------------------------------------------------------------------------------- /esim/data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocessor and dataset definition for NLI. 3 | """ 4 | # Aurelien Coet, 2018. 5 | 6 | import string 7 | import torch 8 | import numpy as np 9 | 10 | from collections import Counter 11 | from torch.utils.data import Dataset 12 | 13 | 14 | class Preprocessor(object): 15 | """ 16 | Preprocessor class for Natural Language Inference datasets. 17 | 18 | The class can be used to read NLI datasets, build worddicts for them 19 | and transform their premises, hypotheses and labels into lists of 20 | integer indices. 21 | """ 22 | 23 | def __init__(self, 24 | lowercase=False, 25 | ignore_punctuation=False, 26 | num_words=None, 27 | stopwords=[], 28 | labeldict={}, 29 | bos=None, 30 | eos=None): 31 | """ 32 | Args: 33 | lowercase: A boolean indicating whether the words in the datasets 34 | being preprocessed must be lowercased or not. Defaults to 35 | False. 36 | ignore_punctuation: A boolean indicating whether punctuation must 37 | be ignored or not in the datasets preprocessed by the object. 38 | num_words: An integer indicating the number of words to use in the 39 | worddict of the object. If set to None, all the words in the 40 | data are kept. Defaults to None. 41 | stopwords: A list of words that must be ignored when building the 42 | worddict for a dataset. Defaults to an empty list. 43 | bos: A string indicating the symbol to use for the 'beginning of 44 | sentence' token in the data. If set to None, the token isn't 45 | used. Defaults to None. 46 | eos: A string indicating the symbol to use for the 'end of 47 | sentence' token in the data. If set to None, the token isn't 48 | used. Defaults to None. 49 | """ 50 | self.lowercase = lowercase 51 | self.ignore_punctuation = ignore_punctuation 52 | self.num_words = num_words 53 | self.stopwords = stopwords 54 | self.labeldict = labeldict 55 | self.bos = bos 56 | self.eos = eos 57 | 58 | def read_data(self, filepath): 59 | """ 60 | Read the premises, hypotheses and labels from some NLI dataset's 61 | file and return them in a dictionary. The file should be in the same 62 | form as SNLI's .txt files. 
63 | 64 | Args: 65 | filepath: The path to a file containing some premises, hypotheses 66 | and labels that must be read. The file should be formatted in 67 | the same way as the SNLI (and MultiNLI) dataset. 68 | 69 | Returns: 70 | A dictionary containing three lists, one for the premises, one for 71 | the hypotheses, and one for the labels in the input data. 72 | """ 73 | with open(filepath, "r", encoding="utf8") as input_data: 74 | ids, premises, hypotheses, labels = [], [], [], [] 75 | 76 | # Translation tables to remove parentheses and punctuation from 77 | # strings. 78 | parentheses_table = str.maketrans({"(": None, ")": None}) 79 | punct_table = str.maketrans({key: " " 80 | for key in string.punctuation}) 81 | 82 | # Ignore the headers on the first line of the file. 83 | next(input_data) 84 | 85 | for line in input_data: 86 | line = line.strip().split("\t") 87 | 88 | # Ignore sentences that have no gold label. 89 | if line[0] == "-": 90 | continue 91 | 92 | pair_id = line[7] 93 | premise = line[1] 94 | hypothesis = line[2] 95 | 96 | # Remove '(' and ')' from the premises and hypotheses. 97 | premise = premise.translate(parentheses_table) 98 | hypothesis = hypothesis.translate(parentheses_table) 99 | 100 | if self.lowercase: 101 | premise = premise.lower() 102 | hypothesis = hypothesis.lower() 103 | 104 | if self.ignore_punctuation: 105 | premise = premise.translate(punct_table) 106 | hypothesis = hypothesis.translate(punct_table) 107 | 108 | # Each premise and hypothesis is split into a list of words. 109 | premises.append([w for w in premise.rstrip().split() 110 | if w not in self.stopwords]) 111 | hypotheses.append([w for w in hypothesis.rstrip().split() 112 | if w not in self.stopwords]) 113 | labels.append(line[0]) 114 | ids.append(pair_id) 115 | 116 | return {"ids": ids, 117 | "premises": premises, 118 | "hypotheses": hypotheses, 119 | "labels": labels} 120 | 121 | def build_worddict(self, data): 122 | """ 123 | Build a dictionary associating words to unique integer indices for 124 | some dataset. The worddict can then be used to transform the words 125 | in datasets to their indices. 126 | 127 | Args: 128 | data: A dictionary containing the premises, hypotheses and 129 | labels of some NLI dataset, in the format returned by the 130 | 'read_data' method of the Preprocessor class. 131 | """ 132 | words = [] 133 | [words.extend(sentence) for sentence in data["premises"]] 134 | [words.extend(sentence) for sentence in data["hypotheses"]] 135 | 136 | counts = Counter(words) 137 | num_words = self.num_words 138 | if self.num_words is None: 139 | num_words = len(counts) 140 | 141 | self.worddict = {} 142 | 143 | # Special indices are used for padding, out-of-vocabulary words, and 144 | # beginning and end of sentence tokens. 145 | self.worddict["_PAD_"] = 0 146 | self.worddict["_OOV_"] = 1 147 | 148 | offset = 2 149 | if self.bos: 150 | self.worddict["_BOS_"] = 2 151 | offset += 1 152 | if self.eos: 153 | self.worddict["_EOS_"] = 3 154 | offset += 1 155 | 156 | for i, word in enumerate(counts.most_common(num_words)): 157 | self.worddict[word[0]] = i + offset 158 | 159 | if self.labeldict == {}: 160 | label_names = set(data["labels"]) 161 | self.labeldict = {label_name: i 162 | for i, label_name in enumerate(label_names)} 163 | 164 | def words_to_indices(self, sentence): 165 | """ 166 | Transform the words in a sentence to their corresponding integer 167 | indices. 168 | 169 | Args: 170 | sentence: A list of words that must be transformed to indices. 
171 | 172 | Returns: 173 | A list of indices. 174 | """ 175 | indices = [] 176 | # Include the beggining of sentence token at the start of the sentence 177 | # if one is defined. 178 | if self.bos: 179 | indices.append(self.worddict["_BOS_"]) 180 | 181 | for word in sentence: 182 | if word in self.worddict: 183 | index = self.worddict[word] 184 | else: 185 | # Words absent from 'worddict' are treated as a special 186 | # out-of-vocabulary word (OOV). 187 | index = self.worddict["_OOV_"] 188 | indices.append(index) 189 | # Add the end of sentence token at the end of the sentence if one 190 | # is defined. 191 | if self.eos: 192 | indices.append(self.worddict["_EOS_"]) 193 | 194 | return indices 195 | 196 | def indices_to_words(self, indices): 197 | """ 198 | Transform the indices in a list to their corresponding words in 199 | the object's worddict. 200 | 201 | Args: 202 | indices: A list of integer indices corresponding to words in 203 | the Preprocessor's worddict. 204 | 205 | Returns: 206 | A list of words. 207 | """ 208 | return [list(self.worddict.keys())[list(self.worddict.values()) 209 | .index(i)] 210 | for i in indices] 211 | 212 | def transform_to_indices(self, data): 213 | """ 214 | Transform the words in the premises and hypotheses of a dataset, as 215 | well as their associated labels, to integer indices. 216 | 217 | Args: 218 | data: A dictionary containing lists of premises, hypotheses 219 | and labels, in the format returned by the 'read_data' 220 | method of the Preprocessor class. 221 | 222 | Returns: 223 | A dictionary containing the transformed premises, hypotheses and 224 | labels. 225 | """ 226 | transformed_data = {"ids": [], 227 | "premises": [], 228 | "hypotheses": [], 229 | "labels": []} 230 | 231 | for i, premise in enumerate(data["premises"]): 232 | # Ignore sentences that have a label for which no index was 233 | # defined in 'labeldict'. 234 | label = data["labels"][i] 235 | if label not in self.labeldict and label != "hidden": 236 | continue 237 | 238 | transformed_data["ids"].append(data["ids"][i]) 239 | 240 | if label == "hidden": 241 | transformed_data["labels"].append(-1) 242 | else: 243 | transformed_data["labels"].append(self.labeldict[label]) 244 | 245 | indices = self.words_to_indices(premise) 246 | transformed_data["premises"].append(indices) 247 | 248 | indices = self.words_to_indices(data["hypotheses"][i]) 249 | transformed_data["hypotheses"].append(indices) 250 | 251 | return transformed_data 252 | 253 | def build_embedding_matrix(self, embeddings_file): 254 | """ 255 | Build an embedding matrix with pretrained weights for object's 256 | worddict. 257 | 258 | Args: 259 | embeddings_file: A file containing pretrained word embeddings. 260 | 261 | Returns: 262 | A numpy matrix of size (num_words+n_special_tokens, embedding_dim) 263 | containing pretrained word embeddings (the +n_special_tokens is for 264 | the padding and out-of-vocabulary tokens, as well as BOS and EOS if 265 | they're used). 266 | """ 267 | # Load the word embeddings in a dictionnary. 268 | embeddings = {} 269 | with open(embeddings_file, "r", encoding="utf8") as input_data: 270 | for line in input_data: 271 | line = line.split() 272 | 273 | try: 274 | # Check that the second element on the line is the start 275 | # of the embedding and not another word. Necessary to 276 | # ignore multiple word lines. 
277 | float(line[1]) 278 | word = line[0] 279 | if word in self.worddict: 280 | embeddings[word] = line[1:] 281 | 282 | # Ignore lines corresponding to multiple words separated 283 | # by spaces. 284 | except ValueError: 285 | continue 286 | 287 | num_words = len(self.worddict) 288 | embedding_dim = len(list(embeddings.values())[0]) 289 | embedding_matrix = np.zeros((num_words, embedding_dim)) 290 | 291 | # Actual building of the embedding matrix. 292 | missed = 0 293 | for word, i in self.worddict.items(): 294 | if word in embeddings: 295 | embedding_matrix[i] = np.array(embeddings[word], dtype=float) 296 | else: 297 | if word == "_PAD_": 298 | continue 299 | missed += 1 300 | # Out of vocabulary words are initialised with random gaussian 301 | # samples. 302 | embedding_matrix[i] = np.random.normal(size=(embedding_dim)) 303 | print("Missed words: ", missed) 304 | 305 | return embedding_matrix 306 | 307 | 308 | class NLIDataset(Dataset): 309 | """ 310 | Dataset class for Natural Language Inference datasets. 311 | 312 | The class can be used to read preprocessed datasets where the premises, 313 | hypotheses and labels have been transformed to unique integer indices 314 | (this can be done with the 'preprocess_data' script in the 'scripts' 315 | folder of this repository). 316 | """ 317 | 318 | def __init__(self, 319 | data, 320 | padding_idx=0, 321 | max_premise_length=None, 322 | max_hypothesis_length=None): 323 | """ 324 | Args: 325 | data: A dictionary containing the preprocessed premises, 326 | hypotheses and labels of some dataset. 327 | padding_idx: An integer indicating the index being used for the 328 | padding token in the preprocessed data. Defaults to 0. 329 | max_premise_length: An integer indicating the maximum length 330 | accepted for the sequences in the premises. If set to None, 331 | the length of the longest premise in 'data' is used. 332 | Defaults to None. 333 | max_hypothesis_length: An integer indicating the maximum length 334 | accepted for the sequences in the hypotheses. If set to None, 335 | the length of the longest hypothesis in 'data' is used. 336 | Defaults to None. 
337 | """ 338 | self.premises_lengths = [len(seq) for seq in data["premises"]] 339 | self.max_premise_length = max_premise_length 340 | if self.max_premise_length is None: 341 | self.max_premise_length = max(self.premises_lengths) 342 | 343 | self.hypotheses_lengths = [len(seq) for seq in data["hypotheses"]] 344 | self.max_hypothesis_length = max_hypothesis_length 345 | if self.max_hypothesis_length is None: 346 | self.max_hypothesis_length = max(self.hypotheses_lengths) 347 | 348 | self.num_sequences = len(data["premises"]) 349 | 350 | self.data = {"ids": [], 351 | "premises": torch.ones((self.num_sequences, 352 | self.max_premise_length), 353 | dtype=torch.long) * padding_idx, 354 | "hypotheses": torch.ones((self.num_sequences, 355 | self.max_hypothesis_length), 356 | dtype=torch.long) * padding_idx, 357 | "labels": torch.tensor(data["labels"], dtype=torch.long)} 358 | 359 | for i, premise in enumerate(data["premises"]): 360 | self.data["ids"].append(data["ids"][i]) 361 | end = min(len(premise), self.max_premise_length) 362 | self.data["premises"][i][:end] = torch.tensor(premise[:end]) 363 | 364 | hypothesis = data["hypotheses"][i] 365 | end = min(len(hypothesis), self.max_hypothesis_length) 366 | self.data["hypotheses"][i][:end] = torch.tensor(hypothesis[:end]) 367 | 368 | def __len__(self): 369 | return self.num_sequences 370 | 371 | def __getitem__(self, index): 372 | return {"id": self.data["ids"][index], 373 | "premise": self.data["premises"][index], 374 | "premise_length": min(self.premises_lengths[index], 375 | self.max_premise_length), 376 | "hypothesis": self.data["hypotheses"][index], 377 | "hypothesis_length": min(self.hypotheses_lengths[index], 378 | self.max_hypothesis_length), 379 | "label": self.data["labels"][index]} 380 | -------------------------------------------------------------------------------- /esim/layers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Definition of custom layers for the ESIM model. 3 | """ 4 | # Aurelien Coet, 2018. 5 | 6 | import torch.nn as nn 7 | 8 | from .utils import sort_by_seq_lens, masked_softmax, weighted_sum 9 | 10 | 11 | # Class widely inspired from: 12 | # https://github.com/allenai/allennlp/blob/master/allennlp/modules/input_variational_dropout.py 13 | class RNNDropout(nn.Dropout): 14 | """ 15 | Dropout layer for the inputs of RNNs. 16 | 17 | Apply the same dropout mask to all the elements of the same sequence in 18 | a batch of sequences of size (batch, sequences_length, embedding_dim). 19 | """ 20 | 21 | def forward(self, sequences_batch): 22 | """ 23 | Apply dropout to the input batch of sequences. 24 | 25 | Args: 26 | sequences_batch: A batch of sequences of vectors that will serve 27 | as input to an RNN. 28 | Tensor of size (batch, sequences_length, emebdding_dim). 29 | 30 | Returns: 31 | A new tensor on which dropout has been applied. 32 | """ 33 | ones = sequences_batch.data.new_ones(sequences_batch.shape[0], 34 | sequences_batch.shape[-1]) 35 | dropout_mask = nn.functional.dropout(ones, self.p, self.training, 36 | inplace=False) 37 | return dropout_mask.unsqueeze(1) * sequences_batch 38 | 39 | 40 | class Seq2SeqEncoder(nn.Module): 41 | """ 42 | RNN taking variable length padded sequences of vectors as input and 43 | encoding them into padded sequences of vectors of the same length. 44 | 45 | This module is useful to handle batches of padded sequences of vectors 46 | that have different lengths and that need to be passed through a RNN. 
47 | The sequences are sorted in descending order of their lengths, packed, 48 | passed through the RNN, and the resulting sequences are then padded and 49 | permuted back to the original order of the input sequences. 50 | """ 51 | 52 | def __init__(self, 53 | rnn_type, 54 | input_size, 55 | hidden_size, 56 | num_layers=1, 57 | bias=True, 58 | dropout=0.0, 59 | bidirectional=False): 60 | """ 61 | Args: 62 | rnn_type: The type of RNN to use as encoder in the module. 63 | Must be a class inheriting from torch.nn.RNNBase 64 | (such as torch.nn.LSTM for example). 65 | input_size: The number of expected features in the input of the 66 | module. 67 | hidden_size: The number of features in the hidden state of the RNN 68 | used as encoder by the module. 69 | num_layers: The number of recurrent layers in the encoder of the 70 | module. Defaults to 1. 71 | bias: If False, the encoder does not use bias weights b_ih and 72 | b_hh. Defaults to True. 73 | dropout: If non-zero, introduces a dropout layer on the outputs 74 | of each layer of the encoder except the last one, with dropout 75 | probability equal to 'dropout'. Defaults to 0.0. 76 | bidirectional: If True, the encoder of the module is bidirectional. 77 | Defaults to False. 78 | """ 79 | assert issubclass(rnn_type, nn.RNNBase),\ 80 | "rnn_type must be a class inheriting from torch.nn.RNNBase" 81 | 82 | super(Seq2SeqEncoder, self).__init__() 83 | 84 | self.rnn_type = rnn_type 85 | self.input_size = input_size 86 | self.hidden_size = hidden_size 87 | self.num_layers = num_layers 88 | self.bias = bias 89 | self.dropout = dropout 90 | self.bidirectional = bidirectional 91 | 92 | self._encoder = rnn_type(input_size, 93 | hidden_size, 94 | num_layers=num_layers, 95 | bias=bias, 96 | batch_first=True, 97 | dropout=dropout, 98 | bidirectional=bidirectional) 99 | 100 | def forward(self, sequences_batch, sequences_lengths): 101 | """ 102 | Args: 103 | sequences_batch: A batch of variable length sequences of vectors. 104 | The batch is assumed to be of size 105 | (batch, sequence, vector_dim). 106 | sequences_lengths: A 1D tensor containing the sizes of the 107 | sequences in the input batch. 108 | 109 | Returns: 110 | reordered_outputs: The outputs (hidden states) of the encoder for 111 | the sequences in the input batch, in the same order. 112 | """ 113 | sorted_batch, sorted_lengths, _, restoration_idx =\ 114 | sort_by_seq_lens(sequences_batch, sequences_lengths) 115 | packed_batch = nn.utils.rnn.pack_padded_sequence(sorted_batch, 116 | sorted_lengths, 117 | batch_first=True) 118 | 119 | outputs, _ = self._encoder(packed_batch, None) 120 | 121 | outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, 122 | batch_first=True) 123 | reordered_outputs = outputs.index_select(0, restoration_idx) 124 | 125 | return reordered_outputs 126 | 127 | 128 | class SoftmaxAttention(nn.Module): 129 | """ 130 | Attention layer taking premises and hypotheses encoded by an RNN as input 131 | and computing the soft attention between their elements. 132 | 133 | The dot product of the encoded vectors in the premises and hypotheses is 134 | first computed. The softmax of the result is then used in a weighted sum 135 | of the vectors of the premises for each element of the hypotheses, and 136 | conversely for the elements of the premises. 
137 | """ 138 | 139 | def forward(self, 140 | premise_batch, 141 | premise_mask, 142 | hypothesis_batch, 143 | hypothesis_mask): 144 | """ 145 | Args: 146 | premise_batch: A batch of sequences of vectors representing the 147 | premises in some NLI task. The batch is assumed to have the 148 | size (batch, sequences, vector_dim). 149 | premise_mask: A mask for the sequences in the premise batch, to 150 | ignore padding data in the sequences during the computation of 151 | the attention. 152 | hypothesis_batch: A batch of sequences of vectors representing the 153 | hypotheses in some NLI task. The batch is assumed to have the 154 | size (batch, sequences, vector_dim). 155 | hypothesis_mask: A mask for the sequences in the hypotheses batch, 156 | to ignore padding data in the sequences during the computation 157 | of the attention. 158 | 159 | Returns: 160 | attended_premises: The sequences of attention vectors for the 161 | premises in the input batch. 162 | attended_hypotheses: The sequences of attention vectors for the 163 | hypotheses in the input batch. 164 | """ 165 | # Dot product between premises and hypotheses in each sequence of 166 | # the batch. 167 | similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1) 168 | .contiguous()) 169 | 170 | # Softmax attention weights. 171 | prem_hyp_attn = masked_softmax(similarity_matrix, hypothesis_mask) 172 | hyp_prem_attn = masked_softmax(similarity_matrix.transpose(1, 2) 173 | .contiguous(), 174 | premise_mask) 175 | 176 | # Weighted sums of the hypotheses for the the premises attention, 177 | # and vice-versa for the attention of the hypotheses. 178 | attended_premises = weighted_sum(hypothesis_batch, 179 | prem_hyp_attn, 180 | premise_mask) 181 | attended_hypotheses = weighted_sum(premise_batch, 182 | hyp_prem_attn, 183 | hypothesis_mask) 184 | 185 | return attended_premises, attended_hypotheses 186 | -------------------------------------------------------------------------------- /esim/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Definition of the ESIM model. 3 | """ 4 | # Aurelien Coet, 2018. 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from .layers import RNNDropout, Seq2SeqEncoder, SoftmaxAttention 10 | from .utils import get_mask, replace_masked 11 | 12 | 13 | class ESIM(nn.Module): 14 | """ 15 | Implementation of the ESIM model presented in the paper "Enhanced LSTM for 16 | Natural Language Inference" by Chen et al. 17 | """ 18 | 19 | def __init__(self, 20 | vocab_size, 21 | embedding_dim, 22 | hidden_size, 23 | embeddings=None, 24 | padding_idx=0, 25 | dropout=0.5, 26 | num_classes=3, 27 | device="cpu"): 28 | """ 29 | Args: 30 | vocab_size: The size of the vocabulary of embeddings in the model. 31 | embedding_dim: The dimension of the word embeddings. 32 | hidden_size: The size of all the hidden layers in the network. 33 | embeddings: A tensor of size (vocab_size, embedding_dim) containing 34 | pretrained word embeddings. If None, word embeddings are 35 | initialised randomly. Defaults to None. 36 | padding_idx: The index of the padding token in the premises and 37 | hypotheses passed as input to the model. Defaults to 0. 38 | dropout: The dropout rate to use between the layers of the network. 39 | A dropout rate of 0 corresponds to using no dropout at all. 40 | Defaults to 0.5. 41 | num_classes: The number of classes in the output of the network. 42 | Defaults to 3. 43 | device: The name of the device on which the model is being 44 | executed. 
Defaults to 'cpu'. 45 | """ 46 | super(ESIM, self).__init__() 47 | 48 | self.vocab_size = vocab_size 49 | self.embedding_dim = embedding_dim 50 | self.hidden_size = hidden_size 51 | self.num_classes = num_classes 52 | self.dropout = dropout 53 | self.device = device 54 | 55 | self._word_embedding = nn.Embedding(self.vocab_size, 56 | self.embedding_dim, 57 | padding_idx=padding_idx, 58 | _weight=embeddings) 59 | 60 | if self.dropout: 61 | self._rnn_dropout = RNNDropout(p=self.dropout) 62 | # self._rnn_dropout = nn.Dropout(p=self.dropout) 63 | 64 | self._encoding = Seq2SeqEncoder(nn.LSTM, 65 | self.embedding_dim, 66 | self.hidden_size, 67 | bidirectional=True) 68 | 69 | self._attention = SoftmaxAttention() 70 | 71 | self._projection = nn.Sequential(nn.Linear(4*2*self.hidden_size, 72 | self.hidden_size), 73 | nn.ReLU()) 74 | 75 | self._composition = Seq2SeqEncoder(nn.LSTM, 76 | self.hidden_size, 77 | self.hidden_size, 78 | bidirectional=True) 79 | 80 | self._classification = nn.Sequential(nn.Dropout(p=self.dropout), 81 | nn.Linear(2*4*self.hidden_size, 82 | self.hidden_size), 83 | nn.Tanh(), 84 | nn.Dropout(p=self.dropout), 85 | nn.Linear(self.hidden_size, 86 | self.num_classes)) 87 | 88 | # Initialize all weights and biases in the model. 89 | self.apply(_init_esim_weights) 90 | 91 | def forward(self, 92 | premises, 93 | premises_lengths, 94 | hypotheses, 95 | hypotheses_lengths): 96 | """ 97 | Args: 98 | premises: A batch of variable length sequences of word indices 99 | representing premises. The batch is assumed to be of size 100 | (batch, premises_length). 101 | premises_lengths: A 1D tensor containing the lengths of the 102 | premises in 'premises'. 103 | hypotheses: A batch of variable length sequences of word indices 104 | representing hypotheses. The batch is assumed to be of size 105 | (batch, hypotheses_length). 106 | hypotheses_lengths: A 1D tensor containing the lengths of the 107 | hypotheses in 'hypotheses'. 108 | 109 | Returns: 110 | logits: A tensor of size (batch, num_classes) containing the 111 | logits for each output class of the model. 112 | probabilities: A tensor of size (batch, num_classes) containing 113 | the probabilities of each output class in the model.
114 | """ 115 | premises_mask = get_mask(premises, premises_lengths).to(self.device) 116 | hypotheses_mask = get_mask(hypotheses, hypotheses_lengths)\ 117 | .to(self.device) 118 | 119 | embedded_premises = self._word_embedding(premises) 120 | embedded_hypotheses = self._word_embedding(hypotheses) 121 | 122 | if self.dropout: 123 | embedded_premises = self._rnn_dropout(embedded_premises) 124 | embedded_hypotheses = self._rnn_dropout(embedded_hypotheses) 125 | 126 | encoded_premises = self._encoding(embedded_premises, 127 | premises_lengths) 128 | encoded_hypotheses = self._encoding(embedded_hypotheses, 129 | hypotheses_lengths) 130 | 131 | attended_premises, attended_hypotheses =\ 132 | self._attention(encoded_premises, premises_mask, 133 | encoded_hypotheses, hypotheses_mask) 134 | 135 | enhanced_premises = torch.cat([encoded_premises, 136 | attended_premises, 137 | encoded_premises - attended_premises, 138 | encoded_premises * attended_premises], 139 | dim=-1) 140 | enhanced_hypotheses = torch.cat([encoded_hypotheses, 141 | attended_hypotheses, 142 | encoded_hypotheses - 143 | attended_hypotheses, 144 | encoded_hypotheses * 145 | attended_hypotheses], 146 | dim=-1) 147 | 148 | projected_premises = self._projection(enhanced_premises) 149 | projected_hypotheses = self._projection(enhanced_hypotheses) 150 | 151 | if self.dropout: 152 | projected_premises = self._rnn_dropout(projected_premises) 153 | projected_hypotheses = self._rnn_dropout(projected_hypotheses) 154 | 155 | v_ai = self._composition(projected_premises, premises_lengths) 156 | v_bj = self._composition(projected_hypotheses, hypotheses_lengths) 157 | 158 | v_a_avg = torch.sum(v_ai * premises_mask.unsqueeze(1) 159 | .transpose(2, 1), dim=1)\ 160 | / torch.sum(premises_mask, dim=1, keepdim=True) 161 | v_b_avg = torch.sum(v_bj * hypotheses_mask.unsqueeze(1) 162 | .transpose(2, 1), dim=1)\ 163 | / torch.sum(hypotheses_mask, dim=1, keepdim=True) 164 | 165 | v_a_max, _ = replace_masked(v_ai, premises_mask, -1e7).max(dim=1) 166 | v_b_max, _ = replace_masked(v_bj, hypotheses_mask, -1e7).max(dim=1) 167 | 168 | v = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1) 169 | 170 | logits = self._classification(v) 171 | probabilities = nn.functional.softmax(logits, dim=-1) 172 | 173 | return logits, probabilities 174 | 175 | 176 | def _init_esim_weights(module): 177 | """ 178 | Initialise the weights of the ESIM model. 179 | """ 180 | if isinstance(module, nn.Linear): 181 | nn.init.xavier_uniform_(module.weight.data) 182 | nn.init.constant_(module.bias.data, 0.0) 183 | 184 | elif isinstance(module, nn.LSTM): 185 | nn.init.xavier_uniform_(module.weight_ih_l0.data) 186 | nn.init.orthogonal_(module.weight_hh_l0.data) 187 | nn.init.constant_(module.bias_ih_l0.data, 0.0) 188 | nn.init.constant_(module.bias_hh_l0.data, 0.0) 189 | hidden_size = module.bias_hh_l0.data.shape[0] // 4 190 | module.bias_hh_l0.data[hidden_size:(2*hidden_size)] = 1.0 191 | 192 | if (module.bidirectional): 193 | nn.init.xavier_uniform_(module.weight_ih_l0_reverse.data) 194 | nn.init.orthogonal_(module.weight_hh_l0_reverse.data) 195 | nn.init.constant_(module.bias_ih_l0_reverse.data, 0.0) 196 | nn.init.constant_(module.bias_hh_l0_reverse.data, 0.0) 197 | module.bias_hh_l0_reverse.data[hidden_size:(2*hidden_size)] = 1.0 198 | -------------------------------------------------------------------------------- /esim/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for the ESIM model. 
3 | """ 4 | # Aurelien Coet, 2018. 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | # Code widely inspired from: 11 | # https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py. 12 | def sort_by_seq_lens(batch, sequences_lengths, descending=True): 13 | """ 14 | Sort a batch of padded variable length sequences by their length. 15 | 16 | Args: 17 | batch: A batch of padded variable length sequences. The batch should 18 | have the dimensions (batch_size x max_sequence_length x *). 19 | sequences_lengths: A tensor containing the lengths of the sequences in the 20 | input batch. The tensor should be of size (batch_size). 21 | descending: A boolean value indicating whether to sort the sequences 22 | by their lengths in descending order. Defaults to True. 23 | 24 | Returns: 25 | sorted_batch: A tensor containing the input batch reordered by 26 | sequences lengths. 27 | sorted_seq_lens: A tensor containing the sorted lengths of the 28 | sequences in the input batch. 29 | sorting_idx: A tensor containing the indices used to permute the input 30 | batch in order to get 'sorted_batch'. 31 | restoration_idx: A tensor containing the indices that can be used to 32 | restore the order of the sequences in 'sorted_batch' so that it 33 | matches the input batch. 34 | """ 35 | sorted_seq_lens, sorting_index =\ 36 | sequences_lengths.sort(0, descending=descending) 37 | 38 | sorted_batch = batch.index_select(0, sorting_index) 39 | 40 | idx_range =\ 41 | sequences_lengths.new_tensor(torch.arange(0, len(sequences_lengths))) 42 | _, reverse_mapping = sorting_index.sort(0, descending=False) 43 | restoration_index = idx_range.index_select(0, reverse_mapping) 44 | 45 | return sorted_batch, sorted_seq_lens, sorting_index, restoration_index 46 | 47 | 48 | def get_mask(sequences_batch, sequences_lengths): 49 | """ 50 | Get the mask for a batch of padded variable length sequences. 51 | 52 | Args: 53 | sequences_batch: A batch of padded variable length sequences 54 | containing word indices. Must be a 2-dimensional tensor of size 55 | (batch, sequence). 56 | sequences_lengths: A tensor containing the lengths of the sequences in 57 | 'sequences_batch'. Must be of size (batch). 58 | 59 | Returns: 60 | A mask of size (batch, max_sequence_length), where max_sequence_length 61 | is the length of the longest sequence in the batch. 62 | """ 63 | batch_size = sequences_batch.size()[0] 64 | max_length = torch.max(sequences_lengths) 65 | mask = torch.ones(batch_size, max_length, dtype=torch.float) 66 | mask[sequences_batch[:, :max_length] == 0] = 0.0 67 | return mask 68 | 69 | 70 | # Code widely inspired from: 71 | # https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py. 72 | def masked_softmax(tensor, mask): 73 | """ 74 | Apply a masked softmax on the last dimension of a tensor. 75 | The input tensor and mask should be of size (batch, *, sequence_length). 76 | 77 | Args: 78 | tensor: The tensor on which the softmax function must be applied along 79 | the last dimension. 80 | mask: A mask of the same size as the tensor with 0s in the positions of 81 | the values that must be masked and 1s everywhere else. 82 | 83 | Returns: 84 | A tensor of the same size as the inputs containing the result of the 85 | softmax. 86 | """ 87 | tensor_shape = tensor.size() 88 | reshaped_tensor = tensor.view(-1, tensor_shape[-1]) 89 | 90 | # Reshape the mask so it matches the size of the input tensor. 
91 | while mask.dim() < tensor.dim(): 92 | mask = mask.unsqueeze(1) 93 | mask = mask.expand_as(tensor).contiguous().float() 94 | reshaped_mask = mask.view(-1, mask.size()[-1]) 95 | 96 | result = nn.functional.softmax(reshaped_tensor * reshaped_mask, dim=-1) 97 | result = result * reshaped_mask 98 | # 1e-13 is added to avoid divisions by zero. 99 | result = result / (result.sum(dim=-1, keepdim=True) + 1e-13) 100 | 101 | return result.view(*tensor_shape) 102 | 103 | 104 | # Code widely inspired from: 105 | # https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py. 106 | def weighted_sum(tensor, weights, mask): 107 | """ 108 | Apply a weighted sum on the vectors along the last dimension of 'tensor', 109 | and mask the vectors in the result with 'mask'. 110 | 111 | Args: 112 | tensor: A tensor of vectors on which a weighted sum must be applied. 113 | weights: The weights to use in the weighted sum. 114 | mask: A mask to apply on the result of the weighted sum. 115 | 116 | Returns: 117 | A new tensor containing the result of the weighted sum after the mask 118 | has been applied on it. 119 | """ 120 | weighted_sum = weights.bmm(tensor) 121 | 122 | while mask.dim() < weighted_sum.dim(): 123 | mask = mask.unsqueeze(1) 124 | mask = mask.transpose(-1, -2) 125 | mask = mask.expand_as(weighted_sum).contiguous().float() 126 | 127 | return weighted_sum * mask 128 | 129 | 130 | # Code inspired from: 131 | # https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py. 132 | def replace_masked(tensor, mask, value): 133 | """ 134 | Replace all the values of the vectors in 'tensor' that are masked in 135 | 'mask' with 'value'. 136 | 137 | Args: 138 | tensor: The tensor in which the masked vectors must have their values 139 | replaced. 140 | mask: A mask indicating the vectors which must have their values 141 | replaced. 142 | value: The value to place in the masked vectors of 'tensor'. 143 | 144 | Returns: 145 | A new tensor of the same size as 'tensor' where the values of the 146 | vectors masked in 'mask' were replaced by 'value'. 147 | """ 148 | mask = mask.unsqueeze(1).transpose(2, 1) 149 | reverse_mask = 1.0 - mask 150 | values_to_add = value * reverse_mask 151 | return tensor * mask + values_to_add 152 | 153 | 154 | def correct_predictions(output_probabilities, targets): 155 | """ 156 | Compute the number of predictions that match some target classes in the 157 | output of a model. 158 | 159 | Args: 160 | output_probabilities: A tensor of probabilities for different output 161 | classes. 162 | targets: The indices of the actual target classes. 163 | 164 | Returns: 165 | The number of correct predictions in 'output_probabilities'. 166 | """ 167 | _, out_classes = output_probabilities.max(dim=1) 168 | correct = (out_classes == targets).sum() 169 | return correct.item() 170 | -------------------------------------------------------------------------------- /scripts/fetch_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Fetch datasets and pretrained word embeddings for the ESIM model. 3 | 4 | By default, the script downloads the following: 5 | - The SNLI corpus; 6 | - GloVe word embeddings (840B - 300d). 7 | """ 8 | # Aurelien Coet, 2018. 9 | 10 | import os 11 | import argparse 12 | import zipfile 13 | import wget 14 | 15 | 16 | def download(url, targetdir): 17 | """ 18 | Download a file and save it in some target directory. 19 | 20 | Args: 21 | url: The url from which the file must be downloaded.
22 | targetdir: The path to the directory where the file must be saved. 23 | 24 | Returns: 25 | The path to the downloaded file. 26 | """ 27 | print("* Downloading data from {}...".format(url)) 28 | filepath = os.path.join(targetdir, url.split('/')[-1]) 29 | wget.download(url, filepath) 30 | return filepath 31 | 32 | 33 | def unzip(filepath): 34 | """ 35 | Extract the data from a zipped file and delete the archive. 36 | 37 | Args: 38 | filepath: The path to the zipped file. 39 | """ 40 | print("\n* Extracting: {}...".format(filepath)) 41 | dirpath = os.path.dirname(filepath) 42 | with zipfile.ZipFile(filepath) as zf: 43 | for name in zf.namelist(): 44 | # Ignore useless files in archives. 45 | if "__MACOSX" in name or\ 46 | ".DS_Store" in name or\ 47 | "Icon" in name: 48 | continue 49 | zf.extract(name, dirpath) 50 | # Delete the archive once the data has been extracted. 51 | os.remove(filepath) 52 | 53 | 54 | def download_unzip(url, targetdir): 55 | """ 56 | Download and unzip data from some url and save it in a target directory. 57 | 58 | Args: 59 | url: The url to download the data from. 60 | targetdir: The target directory in which to download and unzip the 61 | data. 62 | """ 63 | filepath = os.path.join(targetdir, url.split('/')[-1]) 64 | target = os.path.join(targetdir, 65 | ".".join((url.split('/')[-1]).split('.')[:-1])) 66 | 67 | if not os.path.exists(targetdir): 68 | print("* Creating target directory {}...".format(targetdir)) 69 | os.makedirs(targetdir) 70 | 71 | # Skip download and unzipping if the unzipped data is already available. 72 | if os.path.exists(target) or os.path.exists(target + ".txt"): 73 | print("* Found unzipped data in {}, skipping download and unzip..." 74 | .format(targetdir)) 75 | # Skip downloading if the zipped data is already available. 76 | elif os.path.exists(filepath): 77 | print("* Found zipped data in {} - skipping download..." 78 | .format(targetdir)) 79 | unzip(filepath) 80 | # Download and unzip otherwise. 81 | else: 82 | unzip(download(url, targetdir)) 83 | 84 | 85 | if __name__ == "__main__": 86 | # Default data. 87 | snli_url = "https://nlp.stanford.edu/projects/snli/snli_1.0.zip" 88 | glove_url = "http://www-nlp.stanford.edu/data/glove.840B.300d.zip" 89 | 90 | parser = argparse.ArgumentParser(description='Download the SNLI dataset') 91 | parser.add_argument("--dataset_url", 92 | default=snli_url, 93 | help="URL of the dataset to download") 94 | parser.add_argument("--embeddings_url", 95 | default=glove_url, 96 | help="URL of the pretrained embeddings to download") 97 | parser.add_argument("--target_dir", 98 | default=os.path.join("..", "data"), 99 | help="Path to a directory where data must be saved") 100 | args = parser.parse_args() 101 | 102 | if not os.path.exists(args.target_dir): 103 | os.makedirs(args.target_dir) 104 | 105 | print(20*"=", "Fetching the dataset:", 20*'=') 106 | download_unzip(args.dataset_url, os.path.join(args.target_dir, "dataset")) 107 | 108 | print(20*"=", "Fetching the word embeddings:", 20*"=") 109 | download_unzip(args.embeddings_url, 110 | os.path.join(args.target_dir, "embeddings")) 111 | -------------------------------------------------------------------------------- /scripts/preprocessing/preprocess_bnli.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess the Breaking NLI data set. 
3 | """ 4 | 5 | import os 6 | import json 7 | import pickle 8 | import argparse 9 | 10 | from nltk import word_tokenize 11 | from esim.data import Preprocessor 12 | 13 | 14 | def jsonl_to_txt(input_file, output_file): 15 | """ 16 | Transform the Breaking NLI data from a jsonl file to .txt for 17 | further processing. 18 | 19 | Args: 20 | input_file: The path to the Breaking NLI data set in jsonl format. 21 | output_file: The path to the .txt file where the tranformed data must 22 | be saved. 23 | """ 24 | with open(input_file, 'r') as input_f, open(output_file, 'w') as output_f: 25 | output_f.write("label\tsentence1\tsentence2\t\t\t\t\t\tpairID\n") 26 | 27 | for line in input_f: 28 | data = json.loads(line) 29 | 30 | # Sentences in the Breaking NLI data set aren't distributed in the 31 | # form of binary parses, so we must tokenise them with nltk. 32 | sentence1 = word_tokenize(data['sentence1']) 33 | sentence1 = " ".join(sentence1) 34 | sentence2 = word_tokenize(data['sentence2']) 35 | sentence2 = " ".join(sentence2) 36 | 37 | # The 5 tabs between sentence 2 and the pairID are added to 38 | # follow the same structure as the txt files in SNLI and MNLI. 39 | output_f.write(data['gold_label'] + "\t" + sentence1 + "\t" + 40 | sentence2 + "\t\t\t\t\t" + str(data['pairID']) + 41 | "\n") 42 | 43 | 44 | def preprocess_BNLI_data(input_file, 45 | targetdir, 46 | worddict, 47 | labeldict): 48 | """ 49 | Preprocess the BNLI data set so it can be used to test a model trained 50 | on SNLI. 51 | 52 | Args: 53 | inputdir: The path to the file containing the Breaking NLI (BNLI) data. 54 | target_dir: The path to the directory where the preprocessed Breaking 55 | NLI data must be saved. 56 | worddict: The path to the pickled worddict used for preprocessing the 57 | training data on which models were trained before being tested on 58 | BNLI. 59 | labeldict: The dict of labels used for the training data on which 60 | models were trained before being tested on BNLI. 
61 | """ 62 | if not os.path.exists(targetdir): 63 | os.makedirs(targetdir) 64 | 65 | output_file = os.path.join(targetdir, "bnli.txt") 66 | 67 | print(20*"=", " Preprocessing Breaking NLI data set ", 20*"=") 68 | print("\t* Tranforming jsonl data to txt...") 69 | jsonl_to_txt(input_file, output_file) 70 | 71 | preprocessor = Preprocessor(labeldict=labeldict) 72 | 73 | with open(worddict, 'rb') as pkl: 74 | wdict = pickle.load(pkl) 75 | preprocessor.worddict = wdict 76 | 77 | print("\t* Reading txt data...") 78 | data = preprocessor.read_data(output_file) 79 | 80 | print("\t* Transforming words in premises and hypotheses to indices...") 81 | transformed_data = preprocessor.transform_to_indices(data) 82 | 83 | print("\t* Saving result...") 84 | with open(os.path.join(targetdir, "bnli_data.pkl"), 'wb') as pkl_file: 85 | pickle.dump(transformed_data, pkl_file) 86 | 87 | 88 | if __name__ == "__main__": 89 | default_config = "../../config/preprocessing/bnli_preprocessing.json" 90 | 91 | parser = argparse.ArgumentParser(description="Preprocess the Breaking\ 92 | NLI (BNLI) dataset") 93 | parser.add_argument("--config", 94 | default=default_config, 95 | help="Path to a configuration file for preprocessing BNLI") 96 | args = parser.parse_args() 97 | 98 | script_dir = os.path.dirname(os.path.realpath(__file__)) 99 | 100 | if args.config == default_config: 101 | config_path = os.path.join(script_dir, args.config) 102 | else: 103 | config_path = args.config 104 | 105 | with open(os.path.normpath(config_path), "r") as cfg_file: 106 | config = json.load(cfg_file) 107 | 108 | preprocess_BNLI_data( 109 | os.path.normpath(os.path.join(script_dir, config["data_file"])), 110 | os.path.normpath(os.path.join(script_dir, config["target_dir"])), 111 | os.path.normpath(os.path.join(script_dir, config["worddict"])), 112 | config["labeldict"] 113 | ) 114 | -------------------------------------------------------------------------------- /scripts/preprocessing/preprocess_mnli.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess the MultiNLI dataset and word embeddings to be used by the 3 | ESIM model. 4 | """ 5 | # Aurelien Coet, 2019. 6 | 7 | import os 8 | import pickle 9 | import argparse 10 | import fnmatch 11 | import json 12 | 13 | from esim.data import Preprocessor 14 | 15 | 16 | def preprocess_MNLI_data(inputdir, 17 | embeddings_file, 18 | targetdir, 19 | lowercase=False, 20 | ignore_punctuation=False, 21 | num_words=None, 22 | stopwords=[], 23 | labeldict={}, 24 | bos=None, 25 | eos=None): 26 | """ 27 | Preprocess the data from the MultiNLI corpus so it can be used by the 28 | ESIM model. 29 | Compute a worddict from the train set, and transform the words in 30 | the sentences of the corpus to their indices, as well as the labels. 31 | Build an embedding matrix from pretrained word vectors. 32 | The preprocessed data is saved in pickled form in some target directory. 33 | 34 | Args: 35 | inputdir: The path to the directory containing the NLI corpus. 36 | embeddings_file: The path to the file containing the pretrained 37 | word vectors that must be used to build the embedding matrix. 38 | targetdir: The path to the directory where the preprocessed data 39 | must be saved. 40 | lowercase: Boolean value indicating whether to lowercase the premises 41 | and hypotheseses in the input data. Defautls to False. 42 | ignore_punctuation: Boolean value indicating whether to remove 43 | punctuation from the input data. Defaults to False. 
44 | num_words: Integer value indicating the size of the vocabulary to use 45 | for the word embeddings. If set to None, all words are kept. 46 | Defaults to None. 47 | stopwords: A list of words that must be ignored when preprocessing 48 | the data. Defaults to an empty list. 49 | bos: A string indicating the symbol to use for beginning of sentence 50 | tokens. If set to None, bos tokens aren't used. Defaults to None. 51 | eos: A string indicating the symbol to use for end of sentence tokens. 52 | If set to None, eos tokens aren't used. Defaults to None. 53 | """ 54 | if not os.path.exists(targetdir): 55 | os.makedirs(targetdir) 56 | 57 | # Retrieve the train, dev and test data files from the dataset directory. 58 | train_file = "" 59 | matched_dev_file = "" 60 | mismatched_dev_file = "" 61 | matched_test_file = "" 62 | mismatched_test_file = "" 63 | for file in os.listdir(inputdir): 64 | if fnmatch.fnmatch(file, "*_train.txt"): 65 | train_file = file 66 | elif fnmatch.fnmatch(file, "*_dev_matched.txt"): 67 | matched_dev_file = file 68 | elif fnmatch.fnmatch(file, "*_dev_mismatched.txt"): 69 | mismatched_dev_file = file 70 | elif fnmatch.fnmatch(file, "*_test_matched_unlabeled.txt"): 71 | matched_test_file = file 72 | elif fnmatch.fnmatch(file, "*_test_mismatched_unlabeled.txt"): 73 | mismatched_test_file = file 74 | 75 | # -------------------- Train data preprocessing -------------------- # 76 | preprocessor = Preprocessor(lowercase=lowercase, 77 | ignore_punctuation=ignore_punctuation, 78 | num_words=num_words, 79 | stopwords=stopwords, 80 | labeldict=labeldict, 81 | bos=bos, 82 | eos=eos) 83 | 84 | print(20*"=", " Preprocessing train set ", 20*"=") 85 | print("\t* Reading data...") 86 | data = preprocessor.read_data(os.path.join(inputdir, train_file)) 87 | 88 | print("\t* Computing worddict and saving it...") 89 | preprocessor.build_worddict(data) 90 | with open(os.path.join(targetdir, "worddict.pkl"), "wb") as pkl_file: 91 | pickle.dump(preprocessor.worddict, pkl_file) 92 | 93 | print("\t* Transforming words in premises and hypotheses to indices...") 94 | transformed_data = preprocessor.transform_to_indices(data) 95 | print("\t* Saving result...") 96 | with open(os.path.join(targetdir, "train_data.pkl"), "wb") as pkl_file: 97 | pickle.dump(transformed_data, pkl_file) 98 | 99 | # -------------------- Validation data preprocessing -------------------- # 100 | print(20*"=", " Preprocessing dev sets ", 20*"=") 101 | print("\t* Reading matched dev data...") 102 | data = preprocessor.read_data(os.path.join(inputdir, matched_dev_file)) 103 | 104 | print("\t* Transforming words in premises and hypotheses to indices...") 105 | transformed_data = preprocessor.transform_to_indices(data) 106 | print("\t* Saving result...") 107 | with open(os.path.join(targetdir, "matched_dev_data.pkl"), "wb") as pkl_file: 108 | pickle.dump(transformed_data, pkl_file) 109 | 110 | print("\t* Reading mismatched dev data...") 111 | data = preprocessor.read_data(os.path.join(inputdir, mismatched_dev_file)) 112 | 113 | print("\t* Transforming words in premises and hypotheses to indices...") 114 | transformed_data = preprocessor.transform_to_indices(data) 115 | print("\t* Saving result...") 116 | with open(os.path.join(targetdir, "mismatched_dev_data.pkl"), "wb") as pkl_file: 117 | pickle.dump(transformed_data, pkl_file) 118 | 119 | # -------------------- Test data preprocessing -------------------- # 120 | print(20*"=", " Preprocessing test sets ", 20*"=") 121 | print("\t* Reading matched test data...") 122 | data = 
preprocessor.read_data(os.path.join(inputdir, matched_test_file)) 123 | 124 | print("\t* Transforming words in premises and hypotheses to indices...") 125 | transformed_data = preprocessor.transform_to_indices(data) 126 | print("\t* Saving result...") 127 | with open(os.path.join(targetdir, "matched_test_data.pkl"), "wb") as pkl_file: 128 | pickle.dump(transformed_data, pkl_file) 129 | 130 | print("\t* Reading mismatched test data...") 131 | data = preprocessor.read_data(os.path.join(inputdir, mismatched_test_file)) 132 | 133 | print("\t* Transforming words in premises and hypotheses to indices...") 134 | transformed_data = preprocessor.transform_to_indices(data) 135 | print("\t* Saving result...") 136 | with open(os.path.join(targetdir, "mismatched_test_data.pkl"), "wb") as pkl_file: 137 | pickle.dump(transformed_data, pkl_file) 138 | 139 | # -------------------- Embeddings preprocessing -------------------- # 140 | print(20*"=", " Preprocessing embeddings ", 20*"=") 141 | print("\t* Building embedding matrix and saving it...") 142 | embed_matrix = preprocessor.build_embedding_matrix(embeddings_file) 143 | with open(os.path.join(targetdir, "embeddings.pkl"), "wb") as pkl_file: 144 | pickle.dump(embed_matrix, pkl_file) 145 | 146 | 147 | if __name__ == "__main__": 148 | default_config = "../../config/preprocessing/mnli_preprocessing.json" 149 | 150 | parser = argparse.ArgumentParser(description="Preprocess the MultiNLI dataset") 151 | parser.add_argument("--config", 152 | default=default_config, 153 | help="Path to a configuration file for preprocessing MultiNLI") 154 | args = parser.parse_args() 155 | 156 | script_dir = os.path.dirname(os.path.realpath(__file__)) 157 | 158 | if args.config == default_config: 159 | config_path = os.path.join(script_dir, args.config) 160 | else: 161 | config_path = args.config 162 | 163 | with open(os.path.normpath(config_path), 'r') as cfg_file: 164 | config = json.load(cfg_file) 165 | 166 | preprocess_MNLI_data( 167 | os.path.normpath(os.path.join(script_dir, config["data_dir"])), 168 | os.path.normpath(os.path.join(script_dir, config["embeddings_file"])), 169 | os.path.normpath(os.path.join(script_dir, config["target_dir"])), 170 | lowercase=config["lowercase"], 171 | ignore_punctuation=config["ignore_punctuation"], 172 | num_words=config["num_words"], 173 | stopwords=config["stopwords"], 174 | labeldict=config["labeldict"], 175 | bos=config["bos"], 176 | eos=config["eos"] 177 | ) 178 | -------------------------------------------------------------------------------- /scripts/preprocessing/preprocess_snli.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocess the SNLI dataset and word embeddings to be used by the ESIM model. 3 | """ 4 | # Aurelien Coet, 2018. 5 | 6 | import os 7 | import pickle 8 | import argparse 9 | import fnmatch 10 | import json 11 | 12 | from esim.data import Preprocessor 13 | 14 | 15 | def preprocess_SNLI_data(inputdir, 16 | embeddings_file, 17 | targetdir, 18 | lowercase=False, 19 | ignore_punctuation=False, 20 | num_words=None, 21 | stopwords=[], 22 | labeldict={}, 23 | bos=None, 24 | eos=None): 25 | """ 26 | Preprocess the data from the SNLI corpus so it can be used by the 27 | ESIM model. 28 | Compute a worddict from the train set, and transform the words in 29 | the sentences of the corpus to their indices, as well as the labels. 30 | Build an embedding matrix from pretrained word vectors. 31 | The preprocessed data is saved in pickled form in some target directory. 
32 | 33 | Args: 34 | inputdir: The path to the directory containing the NLI corpus. 35 | embeddings_file: The path to the file containing the pretrained 36 | word vectors that must be used to build the embedding matrix. 37 | targetdir: The path to the directory where the preprocessed data 38 | must be saved. 39 | lowercase: Boolean value indicating whether to lowercase the premises 40 | and hypotheseses in the input data. Defautls to False. 41 | ignore_punctuation: Boolean value indicating whether to remove 42 | punctuation from the input data. Defaults to False. 43 | num_words: Integer value indicating the size of the vocabulary to use 44 | for the word embeddings. If set to None, all words are kept. 45 | Defaults to None. 46 | stopwords: A list of words that must be ignored when preprocessing 47 | the data. Defaults to an empty list. 48 | bos: A string indicating the symbol to use for beginning of sentence 49 | tokens. If set to None, bos tokens aren't used. Defaults to None. 50 | eos: A string indicating the symbol to use for end of sentence tokens. 51 | If set to None, eos tokens aren't used. Defaults to None. 52 | """ 53 | if not os.path.exists(targetdir): 54 | os.makedirs(targetdir) 55 | 56 | # Retrieve the train, dev and test data files from the dataset directory. 57 | train_file = "" 58 | dev_file = "" 59 | test_file = "" 60 | for file in os.listdir(inputdir): 61 | if fnmatch.fnmatch(file, "*_train.txt"): 62 | train_file = file 63 | elif fnmatch.fnmatch(file, "*_dev.txt"): 64 | dev_file = file 65 | elif fnmatch.fnmatch(file, "*_test.txt"): 66 | test_file = file 67 | 68 | # -------------------- Train data preprocessing -------------------- # 69 | preprocessor = Preprocessor(lowercase=lowercase, 70 | ignore_punctuation=ignore_punctuation, 71 | num_words=num_words, 72 | stopwords=stopwords, 73 | labeldict=labeldict, 74 | bos=bos, 75 | eos=eos) 76 | 77 | print(20*"=", " Preprocessing train set ", 20*"=") 78 | print("\t* Reading data...") 79 | data = preprocessor.read_data(os.path.join(inputdir, train_file)) 80 | 81 | print("\t* Computing worddict and saving it...") 82 | preprocessor.build_worddict(data) 83 | with open(os.path.join(targetdir, "worddict.pkl"), "wb") as pkl_file: 84 | pickle.dump(preprocessor.worddict, pkl_file) 85 | 86 | print("\t* Transforming words in premises and hypotheses to indices...") 87 | transformed_data = preprocessor.transform_to_indices(data) 88 | print("\t* Saving result...") 89 | with open(os.path.join(targetdir, "train_data.pkl"), "wb") as pkl_file: 90 | pickle.dump(transformed_data, pkl_file) 91 | 92 | # -------------------- Validation data preprocessing -------------------- # 93 | print(20*"=", " Preprocessing dev set ", 20*"=") 94 | print("\t* Reading data...") 95 | data = preprocessor.read_data(os.path.join(inputdir, dev_file)) 96 | 97 | print("\t* Transforming words in premises and hypotheses to indices...") 98 | transformed_data = preprocessor.transform_to_indices(data) 99 | print("\t* Saving result...") 100 | with open(os.path.join(targetdir, "dev_data.pkl"), "wb") as pkl_file: 101 | pickle.dump(transformed_data, pkl_file) 102 | 103 | # -------------------- Test data preprocessing -------------------- # 104 | print(20*"=", " Preprocessing test set ", 20*"=") 105 | print("\t* Reading data...") 106 | data = preprocessor.read_data(os.path.join(inputdir, test_file)) 107 | 108 | print("\t* Transforming words in premises and hypotheses to indices...") 109 | transformed_data = preprocessor.transform_to_indices(data) 110 | print("\t* Saving result...") 111 
| with open(os.path.join(targetdir, "test_data.pkl"), "wb") as pkl_file: 112 | pickle.dump(transformed_data, pkl_file) 113 | 114 | # -------------------- Embeddings preprocessing -------------------- # 115 | print(20*"=", " Preprocessing embeddings ", 20*"=") 116 | print("\t* Building embedding matrix and saving it...") 117 | embed_matrix = preprocessor.build_embedding_matrix(embeddings_file) 118 | with open(os.path.join(targetdir, "embeddings.pkl"), "wb") as pkl_file: 119 | pickle.dump(embed_matrix, pkl_file) 120 | 121 | 122 | if __name__ == "__main__": 123 | default_config = "../../config/preprocessing/snli_preprocessing.json" 124 | 125 | parser = argparse.ArgumentParser(description="Preprocess the SNLI dataset") 126 | parser.add_argument( 127 | "--config", 128 | default=default_config, 129 | help="Path to a configuration file for preprocessing SNLI" 130 | ) 131 | args = parser.parse_args() 132 | 133 | script_dir = os.path.dirname(os.path.realpath(__file__)) 134 | 135 | if args.config == default_config: 136 | config_path = os.path.join(script_dir, args.config) 137 | else: 138 | config_path = args.config 139 | 140 | with open(os.path.normpath(config_path), "r") as cfg_file: 141 | config = json.load(cfg_file) 142 | 143 | preprocess_SNLI_data( 144 | os.path.normpath(os.path.join(script_dir, config["data_dir"])), 145 | os.path.normpath(os.path.join(script_dir, config["embeddings_file"])), 146 | os.path.normpath(os.path.join(script_dir, config["target_dir"])), 147 | lowercase=config["lowercase"], 148 | ignore_punctuation=config["ignore_punctuation"], 149 | num_words=config["num_words"], 150 | stopwords=config["stopwords"], 151 | labeldict=config["labeldict"], 152 | bos=config["bos"], 153 | eos=config["eos"] 154 | ) 155 | -------------------------------------------------------------------------------- /scripts/testing/test_mnli.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test the ESIM model on the preprocessed MultiNLI dataset. 3 | """ 4 | # Aurelien Coet, 2019. 5 | 6 | import os 7 | import pickle 8 | import argparse 9 | import torch 10 | import json 11 | 12 | from torch.utils.data import DataLoader 13 | from esim.data import NLIDataset 14 | from esim.model import ESIM 15 | 16 | 17 | def predict(model, dataloader, labeldict): 18 | """ 19 | Predict the labels of an unlabelled test set with a pretrained model. 20 | 21 | Args: 22 | model: The torch module which must be used to make predictions. 23 | dataloader: A DataLoader object to iterate over some dataset. 24 | labeldict: A dictionary associating labels to integer values. 25 | 26 | Returns: 27 | A dictionary associating pair ids to predicted labels. 28 | """ 29 | # Switch the model to eval mode. 30 | model.eval() 31 | device = model.device 32 | 33 | # Revert the labeldict to associate integers to labels. 34 | labels = {index: label for label, index in labeldict.items()} 35 | predictions = {} 36 | 37 | # Deactivate autograd for evaluation. 38 | with torch.no_grad(): 39 | for batch in dataloader: 40 | 41 | # Move input and output data to the GPU if one is used. 
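# Note that the pair ids are not moved to 'device': they are never fed to
# the model and are only used below as keys of the 'predictions' dictionary.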
42 | ids = batch["id"] 43 | premises = batch['premise'].to(device) 44 | premises_lengths = batch['premise_length'].to(device) 45 | hypotheses = batch['hypothesis'].to(device) 46 | hypotheses_lengths = batch['hypothesis_length'].to(device) 47 | 48 | _, probs = model(premises, 49 | premises_lengths, 50 | hypotheses, 51 | hypotheses_lengths) 52 | 53 | _, preds = probs.max(dim=1) 54 | 55 | for i, pair_id in enumerate(ids): 56 | predictions[pair_id] = labels[int(preds[i])] 57 | 58 | return predictions 59 | 60 | 61 | def main(test_files, pretrained_file, labeldict, output_dir, batch_size=32): 62 | """ 63 | Test the ESIM model with pretrained weights on the MultiNLI dataset. 64 | 65 | Args: 66 | test_files: The paths to the preprocessed matched and mismatched MNLI 67 | test sets. 68 | pretrained_file: The path to a checkpoint produced by the 69 | 'train_mnli' script. 70 | labeldict: A dictionary associating labels (classes) to integer values. 71 | output_dir: The path to a directory where the predictions of the model 72 | must be saved. 73 | batch_size: The size of the batches used for testing. Defaults to 32. 74 | """ 75 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 76 | 77 | print(20 * "=", " Preparing for testing ", 20 * "=") 78 | 79 | output_dir = os.path.normpath(output_dir) 80 | if not os.path.exists(output_dir): 81 | os.makedirs(output_dir) 82 | 83 | checkpoint = torch.load(pretrained_file) 84 | 85 | # Retrieve model parameters from the checkpoint. 86 | vocab_size = checkpoint['model']['_word_embedding.weight'].size(0) 87 | embedding_dim = checkpoint['model']['_word_embedding.weight'].size(1) 88 | hidden_size = checkpoint['model']['_projection.0.weight'].size(0) 89 | num_classes = checkpoint['model']['_classification.4.weight'].size(0) 90 | 91 | print("\t* Loading test data...") 92 | with open(os.path.normpath(test_files["matched"]), 'rb') as pkl: 93 | matched_test_data = NLIDataset(pickle.load(pkl)) 94 | with open(os.path.normpath(test_files["mismatched"]), 'rb') as pkl: 95 | mismatched_test_data = NLIDataset(pickle.load(pkl)) 96 | 97 | matched_test_loader = DataLoader(matched_test_data, 98 | shuffle=False, 99 | batch_size=batch_size) 100 | mismatched_test_loader = DataLoader(mismatched_test_data, 101 | shuffle=False, 102 | batch_size=batch_size) 103 | 104 | print("\t* Building model...") 105 | model = ESIM(vocab_size, 106 | embedding_dim, 107 | hidden_size, 108 | num_classes=num_classes, 109 | device=device).to(device) 110 | 111 | model.load_state_dict(checkpoint['model']) 112 | 113 | print(20 * "=", 114 | " Prediction on MNLI with ESIM model on device: {} ".format(device), 115 | 20 * "=") 116 | 117 | print("\t* Prediction for matched test set...") 118 | predictions = predict(model, matched_test_loader, labeldict) 119 | 120 | with open(os.path.join(output_dir, "matched_predictions.csv"), 'w') as output_f: 121 | output_f.write("pairID,gold_label\n") 122 | for pair_id in predictions: 123 | output_f.write(pair_id+","+predictions[pair_id]+"\n") 124 | 125 | print("\t* Prediction for mismatched test set...") 126 | predictions = predict(model, mismatched_test_loader, labeldict) 127 | 128 | with open(os.path.join(output_dir, "mismatched_predictions.csv"), 'w') as output_f: 129 | output_f.write("pairID,gold_label\n") 130 | for pair_id in predictions: 131 | output_f.write(pair_id+","+predictions[pair_id]+"\n") 132 | 133 | 134 | if __name__ == "__main__": 135 | parser = argparse.ArgumentParser(description='Test the ESIM model on\ 136 | the MNLI matched and mismatched 
test sets')
137 |     parser.add_argument('checkpoint',
138 |                         help="Path to a checkpoint with a pretrained model")
139 |     parser.add_argument('--config', default='../config/testing/mnli_testing.json',
140 |                         help='Path to a configuration file')
141 |     args = parser.parse_args()
142 | 
143 |     with open(os.path.normpath(args.config), 'r') as config_file:
144 |         config = json.load(config_file)
145 | 
146 |     main(config['test_files'],
147 |          args.checkpoint,
148 |          config['labeldict'],
149 |          config['output_dir'],
150 |          config['batch_size'])
151 | 
-------------------------------------------------------------------------------- /scripts/testing/test_snli.py: --------------------------------------------------------------------------------
1 | """
2 | Test the ESIM model on some preprocessed dataset.
3 | """
4 | # Aurelien Coet, 2018.
5 | 
6 | import time
7 | import pickle
8 | import argparse
9 | import torch
10 | 
11 | from torch.utils.data import DataLoader
12 | from esim.data import NLIDataset
13 | from esim.model import ESIM
14 | from esim.utils import correct_predictions
15 | 
16 | 
17 | def test(model, dataloader):
18 |     """
19 |     Test the accuracy of a model on some labelled test dataset.
20 | 
21 |     Args:
22 |         model: The torch module on which testing must be performed.
23 |         dataloader: A DataLoader object to iterate over some dataset.
24 | 
25 |     Returns:
26 |         batch_time: The average time to predict the classes of a batch.
27 |         total_time: The total time to process the whole dataset.
28 |         accuracy: The accuracy of the model on the input data.
29 |     """
30 |     # Switch the model to eval mode.
31 |     model.eval()
32 |     device = model.device
33 | 
34 |     time_start = time.time()
35 |     batch_time = 0.0
36 |     accuracy = 0.0
37 | 
38 |     # Deactivate autograd for evaluation.
39 |     with torch.no_grad():
40 |         for batch in dataloader:
41 |             batch_start = time.time()
42 | 
43 |             # Move input and output data to the GPU if one is used.
44 |             premises = batch["premise"].to(device)
45 |             premises_lengths = batch["premise_length"].to(device)
46 |             hypotheses = batch["hypothesis"].to(device)
47 |             hypotheses_lengths = batch["hypothesis_length"].to(device)
48 |             labels = batch["label"].to(device)
49 | 
50 |             _, probs = model(premises,
51 |                              premises_lengths,
52 |                              hypotheses,
53 |                              hypotheses_lengths)
54 | 
55 |             accuracy += correct_predictions(probs, labels)
56 |             batch_time += time.time() - batch_start
57 | 
58 |     batch_time /= len(dataloader)
59 |     total_time = time.time() - time_start
60 |     accuracy /= (len(dataloader.dataset))
61 | 
62 |     return batch_time, total_time, accuracy
63 | 
64 | 
65 | def main(test_file, pretrained_file, batch_size=32):
66 |     """
67 |     Test the ESIM model with pretrained weights on some dataset.
68 | 
69 |     Args:
70 |         test_file: The path to a file containing preprocessed NLI data.
71 |         pretrained_file: The path to a checkpoint produced by one of the
72 |             training scripts.
73 |         batch_size: The size of the batches used for testing. Defaults to 32.
74 | 
75 |     Note:
76 |         The model's vocabulary size, embedding dimension, hidden size and
77 |         number of classes are not passed as arguments: they are retrieved
78 |         directly from the checkpoint, so they always match the values used
79 |         during training.
80 | 
81 | """ 82 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 83 | 84 | print(20 * "=", " Preparing for testing ", 20 * "=") 85 | 86 | checkpoint = torch.load(pretrained_file) 87 | 88 | # Retrieving model parameters from checkpoint. 89 | vocab_size = checkpoint["model"]["_word_embedding.weight"].size(0) 90 | embedding_dim = checkpoint["model"]['_word_embedding.weight'].size(1) 91 | hidden_size = checkpoint["model"]["_projection.0.weight"].size(0) 92 | num_classes = checkpoint["model"]["_classification.4.weight"].size(0) 93 | 94 | print("\t* Loading test data...") 95 | with open(test_file, "rb") as pkl: 96 | test_data = NLIDataset(pickle.load(pkl)) 97 | 98 | test_loader = DataLoader(test_data, shuffle=False, batch_size=batch_size) 99 | 100 | print("\t* Building model...") 101 | model = ESIM(vocab_size, 102 | embedding_dim, 103 | hidden_size, 104 | num_classes=num_classes, 105 | device=device).to(device) 106 | 107 | model.load_state_dict(checkpoint["model"]) 108 | 109 | print(20 * "=", 110 | " Testing ESIM model on device: {} ".format(device), 111 | 20 * "=") 112 | batch_time, total_time, accuracy = test(model, test_loader) 113 | 114 | print("-> Average batch processing time: {:.4f}s, total test time:\ 115 | {:.4f}s, accuracy: {:.4f}%".format(batch_time, total_time, (accuracy*100))) 116 | 117 | 118 | if __name__ == "__main__": 119 | parser = argparse.ArgumentParser(description="Test the ESIM model on\ 120 | some dataset") 121 | parser.add_argument("test_data", 122 | help="Path to a file containing preprocessed test data") 123 | parser.add_argument("checkpoint", 124 | help="Path to a checkpoint with a pretrained model") 125 | parser.add_argument("--batch_size", type=int, default=32, 126 | help="Batch size to use during testing") 127 | args = parser.parse_args() 128 | 129 | main(args.test_data, 130 | args.checkpoint, 131 | args.batch_size) 132 | -------------------------------------------------------------------------------- /scripts/training/train_mnli.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train the ESIM model on the preprocessed MultiNLI dataset. 3 | """ 4 | # Aurelien Coet, 2019. 5 | 6 | import os 7 | import argparse 8 | import pickle 9 | import torch 10 | import json 11 | 12 | import matplotlib.pyplot as plt 13 | import torch.nn as nn 14 | 15 | from torch.utils.data import DataLoader 16 | from esim.data import NLIDataset 17 | from esim.model import ESIM 18 | from utils import train, validate 19 | 20 | 21 | def main(train_file, 22 | valid_files, 23 | embeddings_file, 24 | target_dir, 25 | hidden_size=300, 26 | dropout=0.5, 27 | num_classes=3, 28 | epochs=64, 29 | batch_size=32, 30 | lr=0.0004, 31 | patience=5, 32 | max_grad_norm=10.0, 33 | checkpoint=None): 34 | """ 35 | Train the ESIM model on the SNLI dataset. 36 | 37 | Args: 38 | train_file: A path to some preprocessed data that must be used 39 | to train the model. 40 | valid_files: A dict containing the paths to the preprocessed matched 41 | and mismatched datasets that must be used to validate the model. 42 | embeddings_file: A path to some preprocessed word embeddings that 43 | must be used to initialise the model. 44 | target_dir: The path to a directory where the trained model must 45 | be saved. 46 | hidden_size: The size of the hidden layers in the model. Defaults 47 | to 300. 48 | dropout: The dropout rate to use in the model. Defaults to 0.5. 49 | num_classes: The number of classes in the output of the model. 50 | Defaults to 3. 
51 | epochs: The maximum number of epochs for training. Defaults to 64. 52 | batch_size: The size of the batches for training. Defaults to 32. 53 | lr: The learning rate for the optimizer. Defaults to 0.0004. 54 | patience: The patience to use for early stopping. Defaults to 5. 55 | checkpoint: A checkpoint from which to continue training. If None, 56 | training starts from scratch. Defaults to None. 57 | """ 58 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 59 | 60 | print(20 * "=", " Preparing for training ", 20 * "=") 61 | 62 | if not os.path.exists(target_dir): 63 | os.makedirs(target_dir) 64 | 65 | # -------------------- Data loading ------------------- # 66 | print("\t* Loading training data...") 67 | with open(train_file, "rb") as pkl: 68 | train_data = NLIDataset(pickle.load(pkl)) 69 | 70 | train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size) 71 | 72 | print("\t* Loading validation data...") 73 | with open(os.path.normpath(valid_files["matched"]), "rb") as pkl: 74 | matched_valid_data = NLIDataset(pickle.load(pkl)) 75 | 76 | with open(os.path.normpath(valid_files["mismatched"]), "rb") as pkl: 77 | mismatched_valid_data = NLIDataset(pickle.load(pkl)) 78 | 79 | matched_valid_loader = DataLoader(matched_valid_data, 80 | shuffle=False, 81 | batch_size=batch_size) 82 | mismatched_valid_loader = DataLoader(mismatched_valid_data, 83 | shuffle=False, 84 | batch_size=batch_size) 85 | 86 | # -------------------- Model definition ------------------- # 87 | print('\t* Building model...') 88 | with open(embeddings_file, "rb") as pkl: 89 | embeddings = torch.tensor(pickle.load(pkl), dtype=torch.float)\ 90 | .to(device) 91 | 92 | model = ESIM(embeddings.shape[0], 93 | embeddings.shape[1], 94 | hidden_size, 95 | embeddings=embeddings, 96 | dropout=dropout, 97 | num_classes=num_classes, 98 | device=device).to(device) 99 | 100 | # -------------------- Preparation for training ------------------- # 101 | criterion = nn.CrossEntropyLoss() 102 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 103 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 104 | mode="max", 105 | factor=0.5, 106 | patience=0) 107 | 108 | best_score = 0.0 109 | start_epoch = 1 110 | 111 | # Data for loss curves plot. 112 | epochs_count = [] 113 | train_losses = [] 114 | matched_valid_losses = [] 115 | mismatched_valid_losses = [] 116 | 117 | # Continuing training from a checkpoint if one was given as argument. 118 | if checkpoint: 119 | checkpoint = torch.load(checkpoint) 120 | start_epoch = checkpoint["epoch"] + 1 121 | best_score = checkpoint["best_score"] 122 | 123 | print("\t* Training will continue on existing model from epoch {}..." 124 | .format(start_epoch)) 125 | 126 | model.load_state_dict(checkpoint["model"]) 127 | optimizer.load_state_dict(checkpoint["optimizer"]) 128 | epochs_count = checkpoint["epochs_count"] 129 | train_losses = checkpoint["train_losses"] 130 | matched_valid_losses = checkpoint["match_valid_losses"] 131 | mismatched_valid_losses = checkpoint["mismatch_valid_losses"] 132 | 133 | # Compute loss and accuracy before starting (or resuming) training. 
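# Validating once before the first (or resumed) epoch gives a baseline: a
# freshly initialised model should score near chance (roughly 33% with the
# default three-class label set), while a restored checkpoint should roughly
# reproduce its previously reported matched and mismatched accuracies.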
134 | _, valid_loss, valid_accuracy = validate(model, 135 | matched_valid_loader, 136 | criterion) 137 | print("\t* Validation loss before training on matched data: {:.4f}, accuracy: {:.4f}%" 138 | .format(valid_loss, (valid_accuracy*100))) 139 | 140 | _, valid_loss, valid_accuracy = validate(model, 141 | mismatched_valid_loader, 142 | criterion) 143 | print("\t* Validation loss before training on mismatched data: {:.4f}, accuracy: {:.4f}%" 144 | .format(valid_loss, (valid_accuracy*100))) 145 | 146 | # -------------------- Training epochs ------------------- # 147 | print("\n", 148 | 20 * "=", 149 | "Training ESIM model on device: {}".format(device), 150 | 20 * "=") 151 | 152 | patience_counter = 0 153 | for epoch in range(start_epoch, epochs+1): 154 | epochs_count.append(epoch) 155 | 156 | print("* Training epoch {}:".format(epoch)) 157 | epoch_time, epoch_loss, epoch_accuracy = train(model, 158 | train_loader, 159 | optimizer, 160 | criterion, 161 | epoch, 162 | max_grad_norm) 163 | 164 | train_losses.append(epoch_loss) 165 | print("-> Training time: {:.4f}s, loss = {:.4f}, accuracy: {:.4f}%\n" 166 | .format(epoch_time, epoch_loss, (epoch_accuracy*100))) 167 | 168 | print("* Validation for epoch {} on matched data:".format(epoch)) 169 | epoch_time, epoch_loss, epoch_accuracy = validate(model, 170 | matched_valid_loader, 171 | criterion) 172 | matched_valid_losses.append(epoch_loss) 173 | print("-> Valid. time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%" 174 | .format(epoch_time, epoch_loss, (epoch_accuracy*100))) 175 | 176 | print("* Validation for epoch {} on mismatched data:".format(epoch)) 177 | epoch_time, epoch_loss, mis_epoch_accuracy = validate(model, 178 | mismatched_valid_loader, 179 | criterion) 180 | mismatched_valid_losses.append(epoch_loss) 181 | print("-> Valid. time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%\n" 182 | .format(epoch_time, epoch_loss, (mis_epoch_accuracy*100))) 183 | 184 | # Update the optimizer's learning rate with the scheduler. 185 | scheduler.step(epoch_accuracy) 186 | 187 | # Early stopping on validation accuracy. 188 | if epoch_accuracy < best_score: 189 | patience_counter += 1 190 | else: 191 | best_score = epoch_accuracy 192 | patience_counter = 0 193 | # Save the best model. The optimizer is not saved to avoid having 194 | # a checkpoint file that is too heavy to be shared. To resume 195 | # training from the best model, use the 'esim_*.pth.tar' 196 | # checkpoints instead. 197 | torch.save({"epoch": epoch, 198 | "model": model.state_dict(), 199 | "best_score": best_score, 200 | "epochs_count": epochs_count, 201 | "train_losses": train_losses, 202 | "match_valid_losses": matched_valid_losses, 203 | "mismatch_valid_losses": mismatched_valid_losses}, 204 | os.path.join(target_dir, "best.pth.tar")) 205 | 206 | # Save the model at each epoch. 207 | torch.save({"epoch": epoch, 208 | "model": model.state_dict(), 209 | "best_score": best_score, 210 | "optimizer": optimizer.state_dict(), 211 | "epochs_count": epochs_count, 212 | "train_losses": train_losses, 213 | "match_valid_losses": matched_valid_losses, 214 | "mismatch_valid_losses": mismatched_valid_losses}, 215 | os.path.join(target_dir, "esim_{}.pth.tar".format(epoch))) 216 | 217 | if patience_counter >= patience: 218 | print("-> Early stopping: patience limit reached, stopping...") 219 | break 220 | 221 | # Plotting of the loss curves for the train and validation sets. 
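# plt.show() opens an interactive window and blocks until it is closed. On a
# headless machine, one possible variant (a sketch, not what this script
# does) is to select a non-interactive backend and save the figure instead:
#   import matplotlib
#   matplotlib.use("Agg")   # must run before pyplot is imported
#   ...
#   plt.savefig(os.path.join(target_dir, "mnli_loss_curves.png"))  # example name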
222 | plt.figure() 223 | plt.plot(epochs_count, train_losses, "-r") 224 | plt.plot(epochs_count, matched_valid_losses, "-b") 225 | plt.plot(epochs_count, mismatched_valid_losses, "-g") 226 | plt.xlabel("epoch") 227 | plt.ylabel("loss") 228 | plt.legend(["Training loss", 229 | "Validation loss (matched set)", 230 | "Validation loss (mismatched set)"]) 231 | plt.title("Cross entropy loss") 232 | plt.show() 233 | 234 | 235 | if __name__ == "__main__": 236 | default_config = "../../config/training/mnli_training.json" 237 | 238 | parser = argparse.ArgumentParser(description="Train the ESIM model on MultiNLI") 239 | parser.add_argument("--config", 240 | default=default_config, 241 | help="Path to a json configuration file") 242 | parser.add_argument("--checkpoint", 243 | default=None, 244 | help="Path to a checkpoint file to resume training") 245 | args = parser.parse_args() 246 | 247 | script_dir = os.path.dirname(os.path.realpath(__file__)) 248 | 249 | if args.config == default_config: 250 | config_path = os.path.join(script_dir, args.config) 251 | else: 252 | config_path = args.config 253 | 254 | with open(os.path.normpath(config_path), "r") as config_file: 255 | config = json.load(config_file) 256 | 257 | main(os.path.normpath(os.path.join(script_dir, config["train_data"])), 258 | config["valid_data"], 259 | os.path.normpath(os.path.join(script_dir, config["embeddings"])), 260 | os.path.normpath(os.path.join(script_dir, config["target_dir"])), 261 | config["hidden_size"], 262 | config["dropout"], 263 | config["num_classes"], 264 | config["epochs"], 265 | config["batch_size"], 266 | config["lr"], 267 | config["patience"], 268 | config["max_gradient_norm"], 269 | args.checkpoint) 270 | -------------------------------------------------------------------------------- /scripts/training/train_snli.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train the ESIM model on the preprocessed SNLI dataset. 3 | """ 4 | # Aurelien Coet, 2018. 5 | 6 | import os 7 | import argparse 8 | import pickle 9 | import torch 10 | import json 11 | 12 | import matplotlib.pyplot as plt 13 | import torch.nn as nn 14 | 15 | from torch.utils.data import DataLoader 16 | from esim.data import NLIDataset 17 | from esim.model import ESIM 18 | from utils import train, validate 19 | 20 | 21 | def main(train_file, 22 | valid_file, 23 | embeddings_file, 24 | target_dir, 25 | hidden_size=300, 26 | dropout=0.5, 27 | num_classes=3, 28 | epochs=64, 29 | batch_size=32, 30 | lr=0.0004, 31 | patience=5, 32 | max_grad_norm=10.0, 33 | checkpoint=None): 34 | """ 35 | Train the ESIM model on the SNLI dataset. 36 | 37 | Args: 38 | train_file: A path to some preprocessed data that must be used 39 | to train the model. 40 | valid_file: A path to some preprocessed data that must be used 41 | to validate the model. 42 | embeddings_file: A path to some preprocessed word embeddings that 43 | must be used to initialise the model. 44 | target_dir: The path to a directory where the trained model must 45 | be saved. 46 | hidden_size: The size of the hidden layers in the model. Defaults 47 | to 300. 48 | dropout: The dropout rate to use in the model. Defaults to 0.5. 49 | num_classes: The number of classes in the output of the model. 50 | Defaults to 3. 51 | epochs: The maximum number of epochs for training. Defaults to 64. 52 | batch_size: The size of the batches for training. Defaults to 32. 53 | lr: The learning rate for the optimizer. Defaults to 0.0004. 
54 | patience: The patience to use for early stopping. Defaults to 5. 55 | checkpoint: A checkpoint from which to continue training. If None, 56 | training starts from scratch. Defaults to None. 57 | """ 58 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 59 | 60 | print(20 * "=", " Preparing for training ", 20 * "=") 61 | 62 | if not os.path.exists(target_dir): 63 | os.makedirs(target_dir) 64 | 65 | # -------------------- Data loading ------------------- # 66 | print("\t* Loading training data...") 67 | with open(train_file, "rb") as pkl: 68 | train_data = NLIDataset(pickle.load(pkl)) 69 | 70 | train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size) 71 | 72 | print("\t* Loading validation data...") 73 | with open(valid_file, "rb") as pkl: 74 | valid_data = NLIDataset(pickle.load(pkl)) 75 | 76 | valid_loader = DataLoader(valid_data, shuffle=False, batch_size=batch_size) 77 | 78 | # -------------------- Model definition ------------------- # 79 | print("\t* Building model...") 80 | with open(embeddings_file, "rb") as pkl: 81 | embeddings = torch.tensor(pickle.load(pkl), dtype=torch.float)\ 82 | .to(device) 83 | 84 | model = ESIM(embeddings.shape[0], 85 | embeddings.shape[1], 86 | hidden_size, 87 | embeddings=embeddings, 88 | dropout=dropout, 89 | num_classes=num_classes, 90 | device=device).to(device) 91 | 92 | # -------------------- Preparation for training ------------------- # 93 | criterion = nn.CrossEntropyLoss() 94 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 95 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 96 | mode="max", 97 | factor=0.5, 98 | patience=0) 99 | 100 | best_score = 0.0 101 | start_epoch = 1 102 | 103 | # Data for loss curves plot. 104 | epochs_count = [] 105 | train_losses = [] 106 | valid_losses = [] 107 | 108 | # Continuing training from a checkpoint if one was given as argument. 109 | if checkpoint: 110 | checkpoint = torch.load(checkpoint) 111 | start_epoch = checkpoint["epoch"] + 1 112 | best_score = checkpoint["best_score"] 113 | 114 | print("\t* Training will continue on existing model from epoch {}..." 115 | .format(start_epoch)) 116 | 117 | model.load_state_dict(checkpoint["model"]) 118 | optimizer.load_state_dict(checkpoint["optimizer"]) 119 | epochs_count = checkpoint["epochs_count"] 120 | train_losses = checkpoint["train_losses"] 121 | valid_losses = checkpoint["valid_losses"] 122 | 123 | # Compute loss and accuracy before starting (or resuming) training. 
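# When training is resumed from a checkpoint, the accuracy reported by this
# initial pass should be close to the last validation accuracy recorded for
# that checkpoint, which is a cheap check that the weights were restored
# correctly.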
124 | _, valid_loss, valid_accuracy = validate(model, 125 | valid_loader, 126 | criterion) 127 | print("\t* Validation loss before training: {:.4f}, accuracy: {:.4f}%" 128 | .format(valid_loss, (valid_accuracy*100))) 129 | 130 | # -------------------- Training epochs ------------------- # 131 | print("\n", 132 | 20 * "=", 133 | "Training ESIM model on device: {}".format(device), 134 | 20 * "=") 135 | 136 | patience_counter = 0 137 | for epoch in range(start_epoch, epochs+1): 138 | epochs_count.append(epoch) 139 | 140 | print("* Training epoch {}:".format(epoch)) 141 | epoch_time, epoch_loss, epoch_accuracy = train(model, 142 | train_loader, 143 | optimizer, 144 | criterion, 145 | epoch, 146 | max_grad_norm) 147 | 148 | train_losses.append(epoch_loss) 149 | print("-> Training time: {:.4f}s, loss = {:.4f}, accuracy: {:.4f}%" 150 | .format(epoch_time, epoch_loss, (epoch_accuracy*100))) 151 | 152 | print("* Validation for epoch {}:".format(epoch)) 153 | epoch_time, epoch_loss, epoch_accuracy = validate(model, 154 | valid_loader, 155 | criterion) 156 | 157 | valid_losses.append(epoch_loss) 158 | print("-> Valid. time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%\n" 159 | .format(epoch_time, epoch_loss, (epoch_accuracy*100))) 160 | 161 | # Update the optimizer's learning rate with the scheduler. 162 | scheduler.step(epoch_accuracy) 163 | 164 | # Early stopping on validation accuracy. 165 | if epoch_accuracy < best_score: 166 | patience_counter += 1 167 | else: 168 | best_score = epoch_accuracy 169 | patience_counter = 0 170 | # Save the best model. The optimizer is not saved to avoid having 171 | # a checkpoint file that is too heavy to be shared. To resume 172 | # training from the best model, use the 'esim_*.pth.tar' 173 | # checkpoints instead. 174 | torch.save({"epoch": epoch, 175 | "model": model.state_dict(), 176 | "best_score": best_score, 177 | "epochs_count": epochs_count, 178 | "train_losses": train_losses, 179 | "valid_losses": valid_losses}, 180 | os.path.join(target_dir, "best.pth.tar")) 181 | 182 | # Save the model at each epoch. 183 | torch.save({"epoch": epoch, 184 | "model": model.state_dict(), 185 | "best_score": best_score, 186 | "optimizer": optimizer.state_dict(), 187 | "epochs_count": epochs_count, 188 | "train_losses": train_losses, 189 | "valid_losses": valid_losses}, 190 | os.path.join(target_dir, "esim_{}.pth.tar".format(epoch))) 191 | 192 | if patience_counter >= patience: 193 | print("-> Early stopping: patience limit reached, stopping...") 194 | break 195 | 196 | # Plotting of the loss curves for the train and validation sets. 
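# The curves can also be re-plotted offline from any saved checkpoint, since
# the loss histories are stored in it. A minimal sketch (assuming a file such
# as 'esim_5.pth.tar' has been saved in target_dir):
#   ckpt = torch.load(os.path.join(target_dir, "esim_5.pth.tar"),
#                     map_location="cpu")
#   plt.plot(ckpt["epochs_count"], ckpt["train_losses"], "-r")
#   plt.plot(ckpt["epochs_count"], ckpt["valid_losses"], "-b")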
197 | plt.figure() 198 | plt.plot(epochs_count, train_losses, "-r") 199 | plt.plot(epochs_count, valid_losses, "-b") 200 | plt.xlabel("epoch") 201 | plt.ylabel("loss") 202 | plt.legend(["Training loss", "Validation loss"]) 203 | plt.title("Cross entropy loss") 204 | plt.show() 205 | 206 | 207 | if __name__ == "__main__": 208 | default_config = "../../config/training/snli_training.json" 209 | 210 | parser = argparse.ArgumentParser(description="Train the ESIM model on SNLI") 211 | parser.add_argument("--config", 212 | default=default_config, 213 | help="Path to a json configuration file") 214 | parser.add_argument("--checkpoint", 215 | default=None, 216 | help="Path to a checkpoint file to resume training") 217 | args = parser.parse_args() 218 | 219 | script_dir = os.path.dirname(os.path.realpath(__file__)) 220 | 221 | if args.config == default_config: 222 | config_path = os.path.join(script_dir, args.config) 223 | else: 224 | config_path = args.config 225 | 226 | with open(os.path.normpath(config_path), 'r') as config_file: 227 | config = json.load(config_file) 228 | 229 | main(os.path.normpath(os.path.join(script_dir, config["train_data"])), 230 | os.path.normpath(os.path.join(script_dir, config["valid_data"])), 231 | os.path.normpath(os.path.join(script_dir, config["embeddings"])), 232 | os.path.normpath(os.path.join(script_dir, config["target_dir"])), 233 | config["hidden_size"], 234 | config["dropout"], 235 | config["num_classes"], 236 | config["epochs"], 237 | config["batch_size"], 238 | config["lr"], 239 | config["patience"], 240 | config["max_gradient_norm"], 241 | args.checkpoint) 242 | -------------------------------------------------------------------------------- /scripts/training/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for training and validating models. 3 | """ 4 | 5 | import time 6 | import torch 7 | 8 | import torch.nn as nn 9 | 10 | from tqdm import tqdm 11 | from esim.utils import correct_predictions 12 | 13 | 14 | def train(model, 15 | dataloader, 16 | optimizer, 17 | criterion, 18 | epoch_number, 19 | max_gradient_norm): 20 | """ 21 | Train a model for one epoch on some input data with a given optimizer and 22 | criterion. 23 | 24 | Args: 25 | model: A torch module that must be trained on some input data. 26 | dataloader: A DataLoader object to iterate over the training data. 27 | optimizer: A torch optimizer to use for training on the input model. 28 | criterion: A loss criterion to use for training. 29 | epoch_number: The number of the epoch for which training is performed. 30 | max_gradient_norm: Max. norm for gradient norm clipping. 31 | 32 | Returns: 33 | epoch_time: The total time necessary to train the epoch. 34 | epoch_loss: The training loss computed for the epoch. 35 | epoch_accuracy: The accuracy computed for the epoch. 36 | """ 37 | # Switch the model to train mode. 38 | model.train() 39 | device = model.device 40 | 41 | epoch_start = time.time() 42 | batch_time_avg = 0.0 43 | running_loss = 0.0 44 | correct_preds = 0 45 | 46 | tqdm_batch_iterator = tqdm(dataloader) 47 | for batch_index, batch in enumerate(tqdm_batch_iterator): 48 | batch_start = time.time() 49 | 50 | # Move input and output data to the GPU if it is used. 
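# The batch is a dict of CPU tensors produced by the DataLoader; .to(device)
# copies them to the GPU when one is available and is effectively a no-op on
# CPU. The labels must be integer class indices, since they are passed to the
# criterion (nn.CrossEntropyLoss in the training scripts) with the raw logits.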
51 | premises = batch["premise"].to(device) 52 | premises_lengths = batch["premise_length"].to(device) 53 | hypotheses = batch["hypothesis"].to(device) 54 | hypotheses_lengths = batch["hypothesis_length"].to(device) 55 | labels = batch["label"].to(device) 56 | 57 | optimizer.zero_grad() 58 | 59 | logits, probs = model(premises, 60 | premises_lengths, 61 | hypotheses, 62 | hypotheses_lengths) 63 | loss = criterion(logits, labels) 64 | loss.backward() 65 | 66 | nn.utils.clip_grad_norm_(model.parameters(), max_gradient_norm) 67 | optimizer.step() 68 | 69 | batch_time_avg += time.time() - batch_start 70 | running_loss += loss.item() 71 | correct_preds += correct_predictions(probs, labels) 72 | 73 | description = "Avg. batch proc. time: {:.4f}s, loss: {:.4f}"\ 74 | .format(batch_time_avg/(batch_index+1), 75 | running_loss/(batch_index+1)) 76 | tqdm_batch_iterator.set_description(description) 77 | 78 | epoch_time = time.time() - epoch_start 79 | epoch_loss = running_loss / len(dataloader) 80 | epoch_accuracy = correct_preds / len(dataloader.dataset) 81 | 82 | return epoch_time, epoch_loss, epoch_accuracy 83 | 84 | 85 | def validate(model, dataloader, criterion): 86 | """ 87 | Compute the loss and accuracy of a model on some validation dataset. 88 | 89 | Args: 90 | model: A torch module for which the loss and accuracy must be 91 | computed. 92 | dataloader: A DataLoader object to iterate over the validation data. 93 | criterion: A loss criterion to use for computing the loss. 94 | epoch: The number of the epoch for which validation is performed. 95 | device: The device on which the model is located. 96 | 97 | Returns: 98 | epoch_time: The total time to compute the loss and accuracy on the 99 | entire validation set. 100 | epoch_loss: The loss computed on the entire validation set. 101 | epoch_accuracy: The accuracy computed on the entire validation set. 102 | """ 103 | # Switch to evaluate mode. 104 | model.eval() 105 | device = model.device 106 | 107 | epoch_start = time.time() 108 | running_loss = 0.0 109 | running_accuracy = 0.0 110 | 111 | # Deactivate autograd for evaluation. 112 | with torch.no_grad(): 113 | for batch in dataloader: 114 | # Move input and output data to the GPU if one is used. 
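# Unlike train() above, this loop performs no backward pass: model.eval()
# disables dropout and torch.no_grad() turns off gradient tracking, so the
# loop only accumulates the running loss and the number of correct
# predictions.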
115 | premises = batch["premise"].to(device) 116 | premises_lengths = batch["premise_length"].to(device) 117 | hypotheses = batch["hypothesis"].to(device) 118 | hypotheses_lengths = batch["hypothesis_length"].to(device) 119 | labels = batch["label"].to(device) 120 | 121 | logits, probs = model(premises, 122 | premises_lengths, 123 | hypotheses, 124 | hypotheses_lengths) 125 | loss = criterion(logits, labels) 126 | 127 | running_loss += loss.item() 128 | running_accuracy += correct_predictions(probs, labels) 129 | 130 | epoch_time = time.time() - epoch_start 131 | epoch_loss = running_loss / len(dataloader) 132 | epoch_accuracy = running_accuracy / (len(dataloader.dataset)) 133 | 134 | return epoch_time, epoch_loss, epoch_accuracy 135 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | 4 | setup(name='ESIM', 5 | version='1.0.1', 6 | url='https://github.com/coetaur0/ESIM', 7 | license='Apache 2', 8 | author='Aurelien Coet', 9 | author_email='aurelien.coet19@gmail.com', 10 | description='Implementation of the ESIM model for NLI with PyTorch', 11 | packages=[ 12 | 'esim' 13 | ], 14 | install_requires=[ 15 | 'wget', 16 | 'numpy', 17 | 'nltk', 18 | 'matplotlib', 19 | 'tqdm', 20 | 'torch' 21 | ]) 22 | --------------------------------------------------------------------------------