├── .gitignore ├── LICENSE ├── README.md ├── example_to_feature.py ├── fetch_data.sh ├── finetune_lm.py ├── generated_explanations ├── snli_dev.csv.gz └── snli_test.csv.gz ├── images └── architecture.jpg ├── lm_utils.py ├── merge_esnli_train.py ├── nli_utils.py ├── prepare_train_test.py ├── requirements.txt ├── run_clf.sh ├── run_finetune_gpt2m.sh ├── run_generate_gpt2m.sh └── run_nli.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 
26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NILE 2 | Reference code for [ACL20](https://acl2020.org/) paper - [NILE : Natural Language Inference with Faithful Natural Language Explanations](https://www.aclweb.org/anthology/2020.acl-main.771/). 3 | 4 |

5 | ![NILE architecture](images/architecture.jpg) 6 |

7 | 8 | ## Dependencies 9 | The code was written with, or depends on: 10 | * Python 3.6 11 | * PyTorch 1.4.0 12 | 13 | ## Running the code 14 | 1. Create a virtualenv and install dependencies 15 | ```bash 16 | virtualenv -p python3.6 env 17 | source env/bin/activate 18 | pip install -r requirements.txt 19 | ``` 20 | 1. Fetch data and pre-process. This will add files to the ```dataset_snli``` and ```dataset_mnli``` folders. 21 | ```bash 22 | bash fetch_data.sh 23 | python prepare_train_test.py --dataset snli --create_data --filter_repetitions 24 | python prepare_train_test.py --dataset mnli --create_data --filter_repetitions 25 | ``` 26 | 1. Fine-tune language models using e-SNLI, for entailment, contradiction and neutral explanations. 'all' is trained to produce a comparable single-explanation ETPA baseline, and can be skipped in this and subsequent steps if only reproducing NILE. 27 | ```bash 28 | bash run_finetune_gpt2m.sh 0 entailment 2 29 | bash run_finetune_gpt2m.sh 0 contradiction 2 30 | bash run_finetune_gpt2m.sh 0 neutral 2 31 | bash run_finetune_gpt2m.sh 0 all 2 32 | ``` 33 | 1. Generate explanations using the fine-tuned language models, where ```<dataset>``` can be snli or mnli, and ```<split>``` is train/dev/test for SNLI and dev/dev_mm for MNLI. 34 | ```bash 35 | bash run_generate_gpt2m.sh 0 entailment all 36 | bash run_generate_gpt2m.sh 0 contradiction all 37 | bash run_generate_gpt2m.sh 0 neutral all 38 | bash run_generate_gpt2m.sh 0 all all 39 | ``` 40 | 1. Merge the generated explanations 41 | ```bash 42 | python prepare_train_test.py --dataset <dataset> --merge_data --split <split> --input_prefix gpt2_m_ 43 | ``` 44 | 45 | To merge for the single-explanation baseline, run 46 | ```bash 47 | python prepare_train_test.py --dataset <dataset> --merge_data --split <split> --input_prefix gpt2_m_ --merge_single 48 | ``` 49 | 1. Train classifiers on the generated explanations; models are saved at ```saved_clf```. 50 | 51 | NILE-PH 52 | ```bash 53 | bash run_clf.sh 0 snli independent independent gpt2_m_ train dev _ _ _ 54 | bash run_clf.sh 0 snli aggregate aggregate gpt2_m_ train dev _ _ _ 55 | bash run_clf.sh 0 snli append append gpt2_m_ train dev _ _ _ 56 | ``` 57 | NILE-NS 58 | ```bash 59 | bash run_clf.sh 0 snli instance_independent instance_independent gpt2_m_ train dev _ _ _ 60 | bash run_clf.sh 0 snli instance_aggregate instance_aggregate gpt2_m_ train dev _ _ _ 61 | bash run_clf.sh 0 snli instance_append instance_append gpt2_m_ train dev _ _ _ 62 | ``` 63 | NILE 64 | ```bash 65 | bash run_clf.sh 0 snli instance_independent instance_independent gpt2_m_ train dev sample _ _ 66 | bash run_clf.sh 0 snli instance_aggregate instance_aggregate gpt2_m_ train dev sample _ _ 67 | ``` 68 | NILE post-hoc 69 | ```bash 70 | bash run_clf.sh 0 snli instance instance gpt2_m_ train dev sample _ _ 71 | ``` 72 | Single-explanation baseline 73 | ```bash 74 | bash run_clf.sh 0 snli Explanation_1 Explanation_1 gpt2_m_ train dev sample _ _ 75 | ``` 76 | 1. Evaluate a trained classifier for label accuracy, where ```<saved_model_path>``` is the path of a model saved in the previous step, and ```<data_format>``` can be independent, aggregate, append, instance_independent, instance_aggregate, or instance for NILE variants, and all_explanation for the single-explanation baseline 77 | ```bash 78 | bash run_clf.sh 0 snli <data_format> <data_format> gpt2_m_ _ test _ _ <saved_model_path> 79 | ``` 80 | 81 | ## Example explanations 82 | Generated explanations on the e-SNLI dev and test sets are present at ```./generated_explanations/*.gz```. Please unzip before using (```gunzip <filename>```).
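For example, a minimal pandas sketch (illustrative only, not part of the repository) to load and inspect one of the unzipped files:

```python
import pandas as pd

# Load one of the unzipped explanation files
# (assumes the gunzip step above has been run on snli_dev.csv.gz).
df = pd.read_csv("generated_explanations/snli_dev.csv")

# Inspect the available columns and the first few rows.
print(df.columns.tolist())
print(df.head())
```
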
83 | The generated explanations are present in the ```entailment_explanation,contradiction_explanation,neutral_explanation``` columns in a csv format. 84 | 85 | ## Pre-trained models 86 | We are sharing pre-trained [label-specific generators](https://drive.google.com/file/d/1lZZYbAwZ8kphY8lp0bVSOQ841c683uUc/view?usp=sharing). 87 | We are also sharing pre-trained classifiers for [NILE-PH Append](https://drive.google.com/file/d/1DacGNNiPvUC6lYk9jzq44QlR5uNAOPss/view?usp=sharing) and [NILE Independent](https://drive.google.com/file/d/10xcnzWTyg1dgX8hldAsnGn52Oqocvqon/view?usp=sharing) architectures. 88 | 89 | ## Citation 90 | If you use this code, please consider citing: 91 | 92 | [1] Sawan Kumar and Partha Talukdar. 2020. NILE : Natural language inference with faithful natural language explanations. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, pages 8730–8742, Online. Association for Computational Linguistics. 93 | [[bibtex](https://www.aclweb.org/anthology/2020.acl-main.771.bib)] 94 | 95 | ## Contact 96 | For any clarification, comments, or suggestions please create an issue or contact sawankumar@iisc.ac.in 97 | -------------------------------------------------------------------------------- /example_to_feature.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # Modifications copyright (c) 2020 Sawan Kumar 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | #""" GLUE processors and helpers """ 18 | # 19 | # Modification history 20 | # 2020 Sawan Kumar: Modified from glue.py in HuggingFace's transformers, 21 | # to handle examples for NILE 22 | 23 | import logging 24 | import os 25 | import numpy as np 26 | 27 | from transformers.data.processors.utils import DataProcessor, InputExample, InputFeatures 28 | from transformers.file_utils import is_tf_available 29 | 30 | if is_tf_available(): 31 | import tensorflow as tf 32 | 33 | logger = logging.getLogger(__name__) 34 | 35 | def convert_examples_to_features(examples, tokenizer, 36 | max_length=512, 37 | label_list=None, 38 | pad_on_left=False, 39 | pad_token=0, 40 | pad_token_segment_id=0, 41 | mask_padding_with_zero=True, 42 | sample_negatives=False): 43 | """ 44 | Loads a data file into a list of ``InputFeatures`` 45 | 46 | Args: 47 | examples: List of ``InputExamples`` or ``tf.data.Dataset`` containing the examples. 48 | tokenizer: Instance of a tokenizer that will tokenize the examples 49 | max_length: Maximum example length 50 | task: GLUE task 51 | label_list: List of labels. 
Can be obtained from the processor using the ``processor.get_labels()`` method 52 | pad_on_left: If set to ``True``, the examples will be padded on the left rather than on the right (default) 53 | pad_token: Padding token 54 | pad_token_segment_id: The segment ID for the padding token (It is usually 0, but can vary such as for XLNet where it is 4) 55 | mask_padding_with_zero: If set to ``True``, the attention mask will be filled by ``1`` for actual values 56 | and by ``0`` for padded values. If set to ``False``, inverts it (``1`` for padded values, ``0`` for 57 | actual values) 58 | 59 | Returns: 60 | If the ``examples`` input is a ``tf.data.Dataset``, will return a ``tf.data.Dataset`` 61 | containing the task-specific features. If the input is a list of ``InputExamples``, will return 62 | a list of task-specific ``InputFeatures`` which can be fed to the model. 63 | 64 | """ 65 | is_tf_dataset = False 66 | if is_tf_available() and isinstance(examples, tf.data.Dataset): 67 | is_tf_dataset = True 68 | 69 | label_map = {label: i for i, label in enumerate(label_list)} 70 | 71 | features = [] 72 | if examples[0].text_b is not None: 73 | k = len(examples[0].text_b) 74 | if sample_negatives: 75 | neg_indices = [np.random.choice(len(examples), size=len(examples), replace=False) for i in range(k)] 76 | for (ex_index, example) in enumerate(examples): 77 | if ex_index % 10000 == 0: 78 | logger.info("Writing example %d" % (ex_index)) 79 | if is_tf_dataset: 80 | example = processor.get_example_from_tensor_dict(example) 81 | 82 | if type(example.text_a) is list: 83 | text_a = example.text_a 84 | text_b = [example.text_b]*len(text_a) 85 | elif type(example.text_b) is list: 86 | text_b = example.text_b 87 | if sample_negatives: 88 | label_idx = label_map[example.label] 89 | text_b_neg = [(examples[neg_indices[i][ex_index]]).text_b[label_idx] for i in range(k)] 90 | text_b_neg[label_idx] = text_b[label_idx] 91 | 92 | text_a = [example.text_a]*len(text_b) 93 | else: 94 | text_a = [example.text_a] 95 | text_b = [example.text_b] 96 | 97 | if 0: #sample_negatives: 98 | print ('Created negative samples') 99 | print ('Original example: label:{} text_a: {} text_b1: {}, 2: {}, 3:{}'.format(example.label, text_a[0], text_b[0], text_b[1], text_b[2])) 100 | print ('Converted example: text_a: {} text_b1: {}, 2: {}, 3:{}'.format(text_a[0], text_b_neg[0], text_b_neg[1], text_b_neg[2])) 101 | 102 | def get_indices(t1, t2): 103 | out = [] 104 | for a,b in zip(t1, t2): 105 | inputs = tokenizer.encode_plus( 106 | a, 107 | b, 108 | add_special_tokens=True, 109 | max_length=max_length, 110 | ) 111 | input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"] 112 | 113 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 114 | # tokens are attended to. 115 | attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) 116 | 117 | # Zero-pad up to the sequence length. 
118 | padding_length = max_length - len(input_ids) 119 | if pad_on_left: 120 | input_ids = ([pad_token] * padding_length) + input_ids 121 | attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask 122 | token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids 123 | else: 124 | input_ids = input_ids + ([pad_token] * padding_length) 125 | attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length) 126 | token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length) 127 | 128 | assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length) 129 | assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask), max_length) 130 | assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(len(token_type_ids), max_length) 131 | out.append((input_ids, attention_mask, token_type_ids)) 132 | 133 | if len(t1) == 1: 134 | input_ids, attention_mask, token_type_ids = out[0] 135 | else: 136 | input_ids, attention_mask, token_type_ids = zip(*out) 137 | return input_ids, attention_mask, token_type_ids 138 | 139 | input_ids, attention_mask, token_type_ids = get_indices(text_a, text_b) 140 | if sample_negatives: 141 | input_ids_n, attention_mask_n, token_type_ids_n = get_indices(text_a, text_b_neg) 142 | 143 | label = label_map[example.label] 144 | 145 | if ex_index < 5: 146 | logger.info("*** Example ***") 147 | logger.info("guid: %s" % (example.guid)) 148 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 149 | logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask])) 150 | logger.info("token_type_ids: %s" % " ".join([str(x) for x in token_type_ids])) 151 | logger.info("label: %s (id = %d)" % (example.label, label)) 152 | 153 | features.append( 154 | InputFeatures(input_ids=input_ids, 155 | attention_mask=attention_mask, 156 | token_type_ids=token_type_ids, 157 | label=label)) 158 | 159 | if sample_negatives: 160 | features.append( 161 | InputFeatures(input_ids=input_ids_n, 162 | attention_mask=attention_mask_n, 163 | token_type_ids=token_type_ids_n, 164 | label=label)) 165 | 166 | if is_tf_available() and is_tf_dataset: 167 | def gen(): 168 | for ex in features: 169 | yield ({'input_ids': ex.input_ids, 170 | 'attention_mask': ex.attention_mask, 171 | 'token_type_ids': ex.token_type_ids}, 172 | ex.label) 173 | 174 | return tf.data.Dataset.from_generator(gen, 175 | ({'input_ids': tf.int32, 176 | 'attention_mask': tf.int32, 177 | 'token_type_ids': tf.int32}, 178 | tf.int64), 179 | ({'input_ids': tf.TensorShape([None]), 180 | 'attention_mask': tf.TensorShape([None]), 181 | 'token_type_ids': tf.TensorShape([None])}, 182 | tf.TensorShape([]))) 183 | 184 | return features 185 | -------------------------------------------------------------------------------- /fetch_data.sh: -------------------------------------------------------------------------------- 1 | mkdir ./external 2 | 3 | #e-snli 4 | git clone https://github.com/OanaMariaCamburu/e-SNLI.git ./external/esnli 5 | python merge_esnli_train.py ./external/esnli 6 | 7 | #mnli, from glue 8 | wget -O ./external/MNLI.zip https://dl.fbaipublicfiles.com/glue/data/MNLI.zip 9 | unzip -d ./external/ ./external/MNLI.zip 10 | 11 | mkdir cache 12 | mkdir saved_lm 13 | mkdir saved_gen 14 | mkdir saved_clf 15 | -------------------------------------------------------------------------------- /finetune_lm.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # Modifications copyright (c) 2020 Sawan Kumar 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | # 18 | # Modification history 19 | # 2020 Sawan Kumar: Modified to finetune and generate explanations for NLI 20 | 21 | from __future__ import absolute_import, division, print_function 22 | 23 | import argparse 24 | import glob 25 | import logging 26 | import os 27 | import pickle 28 | import random 29 | import pandas as pd 30 | 31 | import numpy as np 32 | import torch 33 | import torch.nn.functional as F 34 | from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler 35 | from torch.utils.data.distributed import DistributedSampler 36 | from tensorboardX import SummaryWriter 37 | from tqdm import tqdm, trange 38 | 39 | from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup, 40 | GPT2Config, GPT2LMHeadModel, GPT2Tokenizer) 41 | 42 | from lm_utils import TSVDataset, EXP_TOKEN, EOS_TOKEN 43 | 44 | logger = logging.getLogger(__name__) 45 | 46 | MODEL_CLASSES = { 47 | 'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer), 48 | } 49 | 50 | cross_entropy_ignore_index = -1 51 | 52 | MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop 53 | 54 | def set_seed(args): 55 | random.seed(args.seed) 56 | np.random.seed(args.seed) 57 | torch.manual_seed(args.seed) 58 | if args.n_gpu > 0: 59 | torch.cuda.manual_seed_all(args.seed) 60 | 61 | def train(args, train_dataset, model, tokenizer): 62 | """ Train the model """ 63 | if args.local_rank in [-1, 0]: 64 | tb_writer = SummaryWriter() 65 | 66 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 67 | train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) 68 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 69 | 70 | if args.max_steps > 0: 71 | t_total = args.max_steps 72 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 73 | else: 74 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 75 | 76 | # Prepare optimizer and schedule (linear warmup and decay) 77 | no_decay = ['bias', 'LayerNorm.weight'] 78 | optimizer_grouped_parameters = [ 79 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, 80 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 81 | ] 82 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 83 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, 
num_training_steps=t_total) 84 | if args.fp16: 85 | try: 86 | from apex import amp 87 | except ImportError: 88 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 89 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) 90 | 91 | # multi-gpu training (should be after apex fp16 initialization) 92 | if args.n_gpu > 1: 93 | model = torch.nn.DataParallel(model) 94 | 95 | # Distributed training (should be after apex fp16 initialization) 96 | if args.local_rank != -1: 97 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 98 | output_device=args.local_rank, 99 | find_unused_parameters=True) 100 | 101 | # Train! 102 | logger.info("***** Running training *****") 103 | logger.info(" Num examples = %d", len(train_dataset)) 104 | logger.info(" Num Epochs = %d", args.num_train_epochs) 105 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) 106 | logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", 107 | args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)) 108 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 109 | logger.info(" Total optimization steps = %d", t_total) 110 | 111 | global_step = 0 112 | tr_loss, logging_loss = 0.0, 0.0 113 | model.zero_grad() 114 | train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) 115 | set_seed(args) # Added here for reproducibility (even between python 2 and 3) 116 | for _ in train_iterator: 117 | epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) 118 | for step, batch in enumerate(epoch_iterator): 119 | batch, prompt_lengths, total_lengths = batch 120 | max_length = torch.max(total_lengths).item() 121 | batch = batch[:, :max_length] 122 | inputs, labels = (batch, batch.clone().detach()) 123 | inputs = inputs.to(args.device) 124 | labels = labels.to(args.device) 125 | for idx in range(len(prompt_lengths)): 126 | labels[idx, :prompt_lengths[idx]] = cross_entropy_ignore_index 127 | model.train() 128 | outputs = model(inputs, labels=labels) 129 | loss = outputs[0] # model outputs are always tuple in transformers (see doc) 130 | 131 | if args.n_gpu > 1: 132 | loss = loss.mean() # mean() to average on multi-gpu parallel training 133 | if args.gradient_accumulation_steps > 1: 134 | loss = loss / args.gradient_accumulation_steps 135 | 136 | if args.fp16: 137 | with amp.scale_loss(loss, optimizer) as scaled_loss: 138 | scaled_loss.backward() 139 | else: 140 | loss.backward() 141 | 142 | tr_loss += loss.item() 143 | if (step + 1) % args.gradient_accumulation_steps == 0: 144 | if args.fp16: 145 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) 146 | else: 147 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 148 | optimizer.step() 149 | scheduler.step() # Update learning rate schedule 150 | model.zero_grad() 151 | global_step += 1 152 | 153 | if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: 154 | # Log metrics 155 | if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well 156 | results = evaluate(args, model, tokenizer) 157 | for key, value in results.items(): 158 | 
tb_writer.add_scalar('eval_{}'.format(key), value, global_step) 159 | tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step) 160 | tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step) 161 | logging_loss = tr_loss 162 | 163 | if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: 164 | # Save model checkpoint 165 | output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step)) 166 | if not os.path.exists(output_dir): 167 | os.makedirs(output_dir) 168 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training 169 | model_to_save.save_pretrained(output_dir) 170 | torch.save(args, os.path.join(output_dir, 'training_args.bin')) 171 | logger.info("Saving model checkpoint to %s", output_dir) 172 | 173 | if args.max_steps > 0 and global_step > args.max_steps: 174 | epoch_iterator.close() 175 | break 176 | if args.max_steps > 0 and global_step > args.max_steps: 177 | train_iterator.close() 178 | break 179 | 180 | if args.local_rank in [-1, 0]: 181 | tb_writer.close() 182 | 183 | return global_step, tr_loss / global_step 184 | 185 | def sample_sequence(model, length, context, device='cpu', eos_token_id=None): 186 | context = torch.tensor(context, dtype=torch.long, device=device) 187 | context = context.unsqueeze(0) 188 | generated = context 189 | past = None 190 | with torch.no_grad(): 191 | for _ in range(length): 192 | #inputs = {'input_ids': context} 193 | #output, past = model(**inputs, past=past) 194 | inputs = {'input_ids': generated} 195 | output, past = model(**inputs) 196 | next_token_logits = output[0, -1, :] 197 | next_token = torch.argmax(next_token_logits) 198 | generated = torch.cat((generated, next_token.view(1,1)), dim=1) 199 | if next_token.item() == eos_token_id: 200 | break 201 | context = next_token.view(1,1) 202 | return generated 203 | 204 | def generate(args, model, tokenizer, prefix=""): 205 | if args.length < 0 and model.config.max_position_embeddings > 0: 206 | args.length = model.config.max_position_embeddings 207 | elif 0 < model.config.max_position_embeddings < args.length: 208 | args.length = model.config.max_position_embeddings # No generation bigger than model size 209 | elif args.length < 0: 210 | args.length = MAX_LENGTH # avoid infinite loop 211 | 212 | eval_output_dir = args.output_dir 213 | eval_dataset = TSVDataset(tokenizer, args, file_path=args.eval_data_file, 214 | block_size=args.block_size, get_annotations=False) 215 | 216 | eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset) 217 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=1) 218 | 219 | # Eval! 
220 | logger.info("***** Running generation {} *****".format(prefix)) 221 | logger.info(" Num examples = %d", len(eval_dataset)) 222 | 223 | model.eval() 224 | 225 | for index, batch in enumerate(tqdm(eval_dataloader, desc="Evaluating")): 226 | batch, prompt_lengths, total_lengths = batch 227 | batch = batch.squeeze() 228 | out = sample_sequence( 229 | model=model, 230 | context=batch, 231 | length=args.length, 232 | device=args.device, 233 | eos_token_id=tokenizer.convert_tokens_to_ids(EOS_TOKEN), 234 | ) 235 | out = out[0, len(batch):].tolist() 236 | text = tokenizer.decode(out, clean_up_tokenization_spaces=True) 237 | text = text.split(EOS_TOKEN)[0].strip() 238 | eval_dataset.add_explanation(index, text) 239 | print (text) 240 | 241 | #save 242 | directory, filename = os.path.split(args.eval_data_file) 243 | model_directory, model_name = os.path.split(os.path.normpath(args.output_dir)) 244 | output_name = os.path.join(directory, '{}_{}'.format(model_name, filename)) 245 | eval_dataset.save(output_name) 246 | 247 | def evaluate(args, model, tokenizer, prefix=""): 248 | eval_output_dir = args.output_dir 249 | eval_dataset = TSVDataset(tokenizer, args, file_path=args.eval_data_file, 250 | block_size=args.block_size, get_annotations=True) 251 | 252 | if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]: 253 | os.makedirs(eval_output_dir) 254 | 255 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 256 | # Note that DistributedSampler samples randomly 257 | eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset) 258 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) 259 | 260 | # Eval! 261 | logger.info("***** Running evaluation {} *****".format(prefix)) 262 | logger.info(" Num examples = %d", len(eval_dataset)) 263 | logger.info(" Batch size = %d", args.eval_batch_size) 264 | eval_loss = 0.0 265 | nb_eval_steps = 0 266 | model.eval() 267 | 268 | for batch in tqdm(eval_dataloader, desc="Evaluating"): 269 | batch, prompt_lengths, total_lengths = batch 270 | batch = batch.to(args.device) 271 | 272 | with torch.no_grad(): 273 | outputs = model(batch, labels=batch) 274 | lm_loss = outputs[0] 275 | eval_loss += lm_loss.mean().item() 276 | nb_eval_steps += 1 277 | 278 | eval_loss = eval_loss / nb_eval_steps 279 | perplexity = torch.exp(torch.tensor(eval_loss)) 280 | 281 | result = { 282 | "perplexity": perplexity 283 | } 284 | 285 | output_eval_file = os.path.join(eval_output_dir, "eval_results.txt") 286 | with open(output_eval_file, "w") as writer: 287 | logger.info("***** Eval results {} *****".format(prefix)) 288 | for key in sorted(result.keys()): 289 | logger.info(" %s = %s", key, str(result[key])) 290 | writer.write("%s = %s\n" % (key, str(result[key]))) 291 | 292 | return result 293 | 294 | 295 | def main(): 296 | parser = argparse.ArgumentParser() 297 | 298 | ## Required parameters 299 | parser.add_argument("--train_data_file", default=None, type=str, required=True, 300 | help="The input training data file (a text file).") 301 | parser.add_argument("--output_dir", default=None, type=str, required=True, 302 | help="The output directory where the model predictions and checkpoints will be written.") 303 | 304 | ## Other parameters 305 | parser.add_argument("--eval_data_file", default=None, type=str, 306 | help="An optional input evaluation data file to evaluate the perplexity on (a text file).") 307 | 308 | parser.add_argument("--model_type", 
default="bert", type=str, 309 | help="The model architecture to be fine-tuned.") 310 | parser.add_argument("--model_name_or_path", default="bert-base-cased", type=str, 311 | help="The model checkpoint for weights initialization.") 312 | 313 | parser.add_argument("--config_name", default="", type=str, 314 | help="Optional pretrained config name or path if not the same as model_name_or_path") 315 | parser.add_argument("--tokenizer_name", default="", type=str, 316 | help="Optional pretrained tokenizer name or path if not the same as model_name_or_path") 317 | parser.add_argument("--cache_dir", default="", type=str, 318 | help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)") 319 | parser.add_argument("--block_size", default=-1, type=int, 320 | help="Optional input sequence length after tokenization." 321 | "The training dataset will be truncated in block of this size for training." 322 | "Default to the model max input length for single sentence inputs (take into account special tokens).") 323 | parser.add_argument("--do_train", action='store_true', 324 | help="Whether to run training.") 325 | parser.add_argument("--do_eval", action='store_true', 326 | help="Whether to run eval on the eval data file") 327 | parser.add_argument("--do_generate", action='store_true', 328 | help="Whether to generate text on the eval data file") 329 | parser.add_argument("--length", type=int, default=100, 330 | help="Length for generation") 331 | parser.add_argument("--evaluate_during_training", action='store_true', 332 | help="Run evaluation during training at each logging step.") 333 | parser.add_argument("--do_lower_case", action='store_true', 334 | help="Set this flag if you are using an uncased model.") 335 | 336 | parser.add_argument("--per_gpu_train_batch_size", default=4, type=int, 337 | help="Batch size per GPU/CPU for training.") 338 | parser.add_argument("--per_gpu_eval_batch_size", default=4, type=int, 339 | help="Batch size per GPU/CPU for evaluation.") 340 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 341 | help="Number of updates steps to accumulate before performing a backward/update pass.") 342 | parser.add_argument("--learning_rate", default=5e-5, type=float, 343 | help="The initial learning rate for Adam.") 344 | parser.add_argument("--weight_decay", default=0.0, type=float, 345 | help="Weight deay if we apply some.") 346 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 347 | help="Epsilon for Adam optimizer.") 348 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 349 | help="Max gradient norm.") 350 | parser.add_argument("--num_train_epochs", default=1.0, type=float, 351 | help="Total number of training epochs to perform.") 352 | parser.add_argument("--max_steps", default=-1, type=int, 353 | help="If > 0: set total number of training steps to perform. 
Override num_train_epochs.") 354 | parser.add_argument("--warmup_steps", default=0, type=int, 355 | help="Linear warmup over warmup_steps.") 356 | 357 | parser.add_argument("--data_type", default="tsv", type=str, 358 | help="Dataset type") 359 | 360 | parser.add_argument('--logging_steps', type=int, default=50, 361 | help="Log every X updates steps.") 362 | parser.add_argument('--save_steps', type=int, default=50, 363 | help="Save checkpoint every X updates steps.") 364 | parser.add_argument("--eval_all_checkpoints", action='store_true', 365 | help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number") 366 | parser.add_argument("--no_cuda", action='store_true', 367 | help="Avoid using CUDA when available") 368 | parser.add_argument('--overwrite_output_dir', action='store_true', 369 | help="Overwrite the content of the output directory") 370 | parser.add_argument('--overwrite_cache', action='store_true', 371 | help="Overwrite the cached training and evaluation sets") 372 | parser.add_argument('--seed', type=int, default=42, 373 | help="random seed for initialization") 374 | 375 | parser.add_argument('--fp16', action='store_true', 376 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit") 377 | parser.add_argument('--fp16_opt_level', type=str, default='O1', 378 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 379 | "See details at https://nvidia.github.io/apex/amp.html") 380 | parser.add_argument("--local_rank", type=int, default=-1, 381 | help="For distributed training: local_rank") 382 | parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.") 383 | parser.add_argument('--server_port', type=str, default='', help="For distant debugging.") 384 | args = parser.parse_args() 385 | 386 | if args.eval_data_file is None and args.do_eval: 387 | raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file " 388 | "or remove the --do_eval argument.") 389 | 390 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir: 391 | raise ValueError("Output directory ({}) already exists and is not empty. 
Use --overwrite_output_dir to overcome.".format(args.output_dir)) 392 | 393 | # Setup distant debugging if needed 394 | if args.server_ip and args.server_port: 395 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 396 | import ptvsd 397 | print("Waiting for debugger attach") 398 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 399 | ptvsd.wait_for_attach() 400 | 401 | # Setup CUDA, GPU & distributed training 402 | if args.local_rank == -1 or args.no_cuda: 403 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 404 | args.n_gpu = torch.cuda.device_count() 405 | else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 406 | torch.cuda.set_device(args.local_rank) 407 | device = torch.device("cuda", args.local_rank) 408 | torch.distributed.init_process_group(backend='nccl') 409 | args.n_gpu = 1 410 | args.device = device 411 | 412 | # Setup logging 413 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 414 | datefmt = '%m/%d/%Y %H:%M:%S', 415 | level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 416 | logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 417 | args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16) 418 | 419 | # Set seed 420 | set_seed(args) 421 | 422 | # Load pretrained model and tokenizer 423 | if args.local_rank not in [-1, 0]: 424 | torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab 425 | 426 | config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] 427 | config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, 428 | cache_dir=args.cache_dir if args.cache_dir else None) 429 | tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case, 430 | cache_dir=args.cache_dir if args.cache_dir else None) 431 | if args.block_size <= 0: 432 | args.block_size = tokenizer.max_len_single_sentence # Our input block size will be the max possible for the model 433 | args.block_size = min(args.block_size, tokenizer.max_len_single_sentence) 434 | 435 | if args.do_train: 436 | #Additional tokens 437 | print ('#tokens', len(tokenizer)) 438 | new_tokens = [EXP_TOKEN, EOS_TOKEN] 439 | tokenizer.add_tokens(new_tokens) 440 | print ('#extended tokens', len(tokenizer)) 441 | model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config, 442 | cache_dir=args.cache_dir if args.cache_dir else None) 443 | model.resize_token_embeddings(len(tokenizer)) 444 | model.to(args.device) 445 | 446 | if args.local_rank == 0: 447 | torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training download model & vocab 448 | 449 | logger.info("Training/evaluation parameters %s", args) 450 | 451 | # Training 452 | if args.do_train: 453 | if args.local_rank not in [-1, 0]: 454 | torch.distributed.barrier() # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache 455 | 456 | train_dataset = TSVDataset(tokenizer, args, file_path=args.train_data_file, 457 | block_size=args.block_size, get_annotations=True) 458 | 459 | if args.local_rank 
== 0: 460 | torch.distributed.barrier() 461 | 462 | global_step, tr_loss = train(args, train_dataset, model, tokenizer) 463 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) 464 | 465 | 466 | # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained() 467 | if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 468 | # Create output directory if needed 469 | if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: 470 | os.makedirs(args.output_dir) 471 | 472 | logger.info("Saving model checkpoint to %s", args.output_dir) 473 | # Save a trained model, configuration and tokenizer using `save_pretrained()`. 474 | # They can then be reloaded using `from_pretrained()` 475 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training 476 | model_to_save.save_pretrained(args.output_dir) 477 | tokenizer.save_pretrained(args.output_dir) 478 | 479 | # Good practice: save your training arguments together with the trained model 480 | torch.save(args, os.path.join(args.output_dir, 'training_args.bin')) 481 | 482 | # Load a trained model and vocabulary that you have fine-tuned 483 | model = model_class.from_pretrained(args.output_dir) 484 | tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) 485 | model.to(args.device) 486 | 487 | # Evaluation 488 | if args.do_eval and args.local_rank in [-1, 0]: 489 | model = model_class.from_pretrained(args.output_dir) 490 | model.to(args.device) 491 | tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) 492 | result = evaluate(args, model, tokenizer) 493 | print (result) 494 | 495 | #Generation 496 | if args.do_generate: 497 | model = model_class.from_pretrained(args.output_dir) 498 | model.to(args.device) 499 | tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) 500 | generate(args, model, tokenizer) 501 | 502 | if __name__ == "__main__": 503 | main() 504 | -------------------------------------------------------------------------------- /generated_explanations/snli_dev.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SawanKumar28/nile/ef7eb47dd49afff6855358901afca24de27f0eae/generated_explanations/snli_dev.csv.gz -------------------------------------------------------------------------------- /generated_explanations/snli_test.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SawanKumar28/nile/ef7eb47dd49afff6855358901afca24de27f0eae/generated_explanations/snli_test.csv.gz -------------------------------------------------------------------------------- /images/architecture.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SawanKumar28/nile/ef7eb47dd49afff6855358901afca24de27f0eae/images/architecture.jpg -------------------------------------------------------------------------------- /lm_utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | import pickle 4 | import torch 5 | 6 | from torch.utils.data import DataLoader, Dataset 7 | 8 | EXP_TOKEN = '[EXP]' 9 | EOS_TOKEN = '[EOS]' 10 | class TSVDataset(Dataset): 11 | def __init__(self, tokenizer, args, file_path='train', 
block_size=512, get_annotations=False): 12 | self.print_count = 5 13 | self.eos_token_id = tokenizer.convert_tokens_to_ids(EOS_TOKEN) 14 | 15 | cached_features_file, data = self.load_data(file_path, block_size) 16 | self.data = data 17 | 18 | if get_annotations: cached_features_file = cached_features_file + '_annotated' 19 | 20 | if os.path.exists(cached_features_file): 21 | print ('Loading features from', cached_features_file) 22 | with open(cached_features_file, 'rb') as handle: 23 | self.examples = pickle.load(handle) 24 | return 25 | 26 | print ('Saving features from ', file_path, ' into ', cached_features_file) 27 | 28 | def create_example(r): 29 | text1 = '{} {} '.format(r['input'], EXP_TOKEN) 30 | tokenized_text1 = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text1)) 31 | prompt_length = len(tokenized_text1) 32 | tokenized_text, total_length = tokenized_text1, len(tokenized_text1) 33 | if get_annotations: 34 | text2 = r['target'] 35 | tokenized_text2 = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text2)) 36 | tokenized_text = tokenized_text1 + tokenized_text2 37 | tokenized_text = tokenized_text + [self.eos_token_id] 38 | total_length = len(tokenized_text) 39 | if len(tokenized_text) > block_size: 40 | tokenized_text = tokenized_text[:block_size] 41 | if len(tokenized_text) < block_size: 42 | tokenized_text = tokenized_text + [self.eos_token_id] * (block_size-len(tokenized_text)) 43 | if self.print_count > 0: 44 | print ('example: ', text1 + text2 if get_annotations else text1) 45 | self.print_count = self.print_count - 1 46 | return (tokenized_text, prompt_length, total_length) 47 | 48 | self.examples = data.apply(create_example, axis=1).to_list() 49 | print ('Saving ', len(self.examples), ' examples') 50 | with open(cached_features_file, 'wb') as handle: 51 | pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL) 52 | 53 | def __len__(self): 54 | return len(self.examples) 55 | 56 | def __getitem__(self, item): 57 | return torch.tensor(self.examples[item][0]), self.examples[item][1], self.examples[item][2] 58 | 59 | def get_example_text(self, index): 60 | return self.data['prompt'][index] 61 | 62 | def add_explanation(self, index, explanation): 63 | explanation_name = 'Generated_Explanation' 64 | self.data.at[self.data.index[index], explanation_name] = explanation 65 | 66 | def load_data(self, file_path, block_size): 67 | assert os.path.isfile(file_path) 68 | data = pd.read_csv(file_path, sep='\t', index_col='pairID') 69 | print (data) 70 | directory, filename = os.path.split(file_path) 71 | cached_features_file = os.path.join(directory, 'cached_lm_{}_{}'.format(block_size, filename)) 72 | return cached_features_file, data 73 | 74 | def save(self, filename): 75 | self.data.to_csv(filename, sep='\t') 76 | -------------------------------------------------------------------------------- /merge_esnli_train.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | import sys 4 | 5 | esnli_path = sys.argv[1] 6 | 7 | #Merge train files 8 | merged_train_file_path = os.path.join(esnli_path, "dataset", "esnli_train.csv") 9 | train1_path = os.path.join(esnli_path, "dataset", "esnli_train_1.csv") 10 | train2_path = os.path.join(esnli_path, "dataset", "esnli_train_2.csv") 11 | d1 = pd.read_csv(train1_path, index_col="pairID") 12 | d2 = pd.read_csv(train2_path, index_col="pairID") 13 | 14 | d = pd.concat([d1, d2], 0, sort=False) 15 | d_nna = d.dropna() 16 | 17 | print ("Merging train files with lengths 
", len(d1), len(d2), ", output length", len(d_nna)) 18 | d_nna.to_csv(merged_train_file_path) 19 | -------------------------------------------------------------------------------- /nli_utils.py: -------------------------------------------------------------------------------- 1 | from transformers.data.processors.utils import DataProcessor, InputExample, InputFeatures 2 | import pandas as pd 3 | import numpy as np 4 | 5 | class ExpProcessor(DataProcessor): 6 | s1 = 'sentence1' 7 | s2 = 'sentence2' 8 | index_col = "pairID" 9 | labels = ["entailment", "contradiction", "neutral"] 10 | gold_label = "gold_label" 11 | def get_train_examples(self, filepath, data_format="instance", to_drop=[]): 12 | data = pd.read_csv(filepath, index_col=self.index_col) 13 | examples = self._create_examples(data, 'train', data_format=data_format, to_drop=to_drop) 14 | return examples 15 | 16 | def get_dev_examples(self, filepath, data_format="instance", to_drop=[]): 17 | data = pd.read_csv(filepath, index_col=self.index_col) 18 | examples = self._create_examples(data, 'dev', data_format=data_format, to_drop=to_drop) 19 | return examples 20 | 21 | def get_labels(self): 22 | return self.labels 23 | 24 | data_formats = ["instance", "independent", "append", "instance_independent", "instance_append", 25 | "all_explanation", "Explanation_1"] 26 | #aggregate uses the same format as independent 27 | def _create_examples(self, labeled_examples, set_type, data_format="instance", to_drop=[]): 28 | """Creates examples for the training and dev sets.""" 29 | if data_format not in self.data_formats: 30 | raise ValueError("Data format {} not supported".format(data_format)) 31 | 32 | if 'explanation' in to_drop: to_drop = self.labels 33 | 34 | keep_labels = [True if l not in to_drop else False for l in self.labels] 35 | exp_names = ["{}_explanation".format(l) for l in self.labels] 36 | 37 | examples = [] 38 | for (idx, le) in labeled_examples.iterrows(): 39 | guid = idx 40 | label = le[self.gold_label] 41 | 42 | if data_format in ["independent", "instance_independent"]: 43 | exp_text = [le[exp_name] if keep else "" 44 | for l, keep, exp_name in zip(self.labels, keep_labels, exp_names)] 45 | elif data_format in ["append", "instance_append"]: 46 | exp_text = " ".join(["{}: {}".format(l, le[exp_name]) if keep else "" 47 | for l, keep, exp_name in zip(self.labels, keep_labels, exp_names)]) 48 | 49 | if data_format == "instance": 50 | text_a, text_b = le[self.s1], le[self.s2] 51 | elif data_format in ["Explanation_1", "all_explanation"]: 52 | text_a, text_b = le[data_format], None 53 | elif data_format in ["independent", "append"]: 54 | text_a, text_b = exp_text, None 55 | elif data_format in ["instance_independent", "instance_append"]: 56 | instance_text = "Premise: {} Hypothesis: {}".format( 57 | le[self.s1], le[self.s2]) if "instance" not in to_drop else "Premise: Hypothesis:" 58 | text_a, text_b = instance_text, exp_text 59 | 60 | examples.append( 61 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 62 | return examples 63 | 64 | def simple_accuracy(preds, labels): 65 | return (preds == labels).mean() 66 | 67 | def exp_compute_metrics(preds, labels): 68 | assert len(preds) == len(labels) 69 | return {"acc": simple_accuracy(preds, labels)} 70 | -------------------------------------------------------------------------------- /prepare_train_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import numpy as np 5 | from tqdm 
import tqdm 6 | import pandas as pd 7 | 8 | if __name__ == "__main__": 9 | 10 | parser = argparse.ArgumentParser() 11 | 12 | parser.add_argument("--dataset", type=str, default="snli") #snli, mnli 13 | parser.add_argument("--create_data", action="store_true") 14 | parser.add_argument("--filter_repetitions", action="store_true") 15 | 16 | #For merging 17 | parser.add_argument("--merge_data", action="store_true") 18 | parser.add_argument("--merge_single", action="store_true") 19 | parser.add_argument("--split", type=str, default="train") 20 | parser.add_argument("--input_prefix", type=str, default="dummy") 21 | 22 | #for shuffled evaluation 23 | parser.add_argument("--shuffle", action="store_true") 24 | 25 | args = parser.parse_args() 26 | tqdm.pandas() 27 | 28 | args.dataset = args.dataset.lower() 29 | 30 | s1,s2 = 'sentence1', 'sentence2' 31 | index_col = 'pairID' 32 | gold_label = 'gold_label' 33 | data_labels = ['entailment', 'neutral', 'contradiction'] 34 | if args.dataset == 'snli': 35 | input_path = './external/esnli/dataset' 36 | filenames = { 37 | 'dev': 'esnli_dev.csv', 38 | 'train': 'esnli_train.csv', 39 | 'test': 'esnli_test.csv' 40 | } 41 | sep = ',' 42 | data_index_col = 'pairID' 43 | data_gold_label = 'gold_label' 44 | quotechar = '"' 45 | quoting = 0 46 | data_s1,data_s2 = 'Sentence1', 'Sentence2' 47 | label_map = None 48 | skip_segregation = False 49 | explanation_available = True 50 | e1 = 'Explanation_1' 51 | output_root = './dataset_snli' 52 | elif args.dataset == 'mnli': 53 | input_path = './external/MNLI' 54 | filenames = { 55 | 'train': 'train.tsv', 56 | 'dev': 'dev_matched.tsv', 57 | 'dev_mm': 'dev_mismatched.tsv' 58 | } 59 | sep = '\t' 60 | data_index_col = 'pairID' 61 | data_gold_label = 'gold_label' 62 | quotechar = None 63 | quoting = 3 64 | data_s1,data_s2 = 'sentence1', 'sentence2' 65 | label_map = None 66 | skip_segregation = True 67 | explanation_available = False 68 | output_root = './dataset_mnli' 69 | else: 70 | raise ValueError("dataset not supported") 71 | 72 | if args.create_data: 73 | data = {} 74 | for split in filenames: 75 | data[split] = pd.read_csv(os.path.join(input_path, filenames[split]), 76 | index_col=data_index_col, sep=sep, quotechar=quotechar, quoting=quoting) 77 | data[split] = data[split].rename(columns={data_s1:s1, data_s2:s2, data_gold_label:gold_label}) 78 | data[split].index.name = index_col 79 | if label_map: data[split][gold_label] = data[split][gold_label].apply(label_map) 80 | 81 | print ('\n Split {} Len {}'.format(split, len(data[split]))) 82 | print (data[split][gold_label].value_counts()) 83 | 84 | if args.filter_repetitions and split == "train" and args.dataset == 'snli': 85 | print ("Filtering repetitions") 86 | def has_repetition(r): 87 | exp = r[e1].lower() 88 | p = r[s1].lower() 89 | h = r[s2].lower() 90 | return True if p in exp or h in exp else False 91 | cond = data[split].apply(has_repetition, axis=1) 92 | print ("#cases with repetitions:", cond.sum()) 93 | data[split] = data[split][cond==False] 94 | print ('Updated Split {} Len {}'.format(split, len(data[split]))) 95 | print (data[split][gold_label].value_counts()) 96 | cond = data[split].apply(has_repetition, axis=1) 97 | print ("#cases with repetitions:", cond.sum()) 98 | 99 | label = "all" 100 | examples = data[split] 101 | dpath = os.path.join(output_root, label) 102 | os.makedirs(dpath) if not os.path.exists(dpath) else None 103 | fname = os.path.join(dpath, '{}_data.csv'.format(split)) 104 | examples.to_csv(fname) 105 | 106 | if not skip_segregation: 107 | for 
label in data_labels: 108 | for split in filenames: 109 | dpath = os.path.join(output_root, label) 110 | os.makedirs(dpath) if not os.path.exists(dpath) else None 111 | fname = os.path.join(dpath, '{}_data.csv'.format(split)) 112 | examples = data[split][data[split][gold_label] == label] 113 | print ('Saving {} | {} | {} examples'.format(label, split, len(examples))) 114 | examples.to_csv(fname) 115 | 116 | def generate_prompt(r): 117 | inp = 'Premise: {} Hypothesis: {}'.format(r[s1], r[s2]) 118 | return inp 119 | 120 | if args.create_data: 121 | labels = ["all"] 122 | if not skip_segregation: labels.extend(data_labels) 123 | 124 | for label in labels: 125 | dpath = os.path.join(output_root, label) 126 | for split in filenames: 127 | fname = os.path.join(dpath, '{}_data.csv'.format(split)) 128 | examples = pd.read_csv(fname, index_col=index_col) 129 | print ('Processing {} | {} | {} examples'.format(label, split, len(examples))) 130 | print (examples[gold_label].value_counts()) 131 | examples['input'] = examples[[s1, s2]].progress_apply(generate_prompt, axis=1) 132 | columns_to_write = ['input'] 133 | if explanation_available: 134 | examples['target'] = examples[e1] 135 | columns_to_write.append('target') 136 | print ('Writing') 137 | fname = os.path.join(dpath, '{}.tsv'.format(split)) 138 | examples[columns_to_write].to_csv(fname, sep='\t') 139 | 140 | if args.merge_data: 141 | if args.merge_single: 142 | suffixes = ["all"] 143 | output_suffix = "_all" 144 | else: 145 | suffixes = ["entailment", "contradiction", "neutral"] 146 | output_suffix = "" 147 | split = args.split 148 | fname_csv = os.path.join(output_root, 'all', '{}_data.csv'.format(split)) 149 | d_csv = pd.read_csv(fname_csv, index_col=index_col) 150 | 151 | for s in suffixes: 152 | fname_tsv = os.path.join(output_root, 'all', '{}{}_{}.tsv'.format(args.input_prefix, s, split)) 153 | d_tsv = pd.read_csv(fname_tsv, index_col=index_col, sep='\t') 154 | 155 | d_csv['{}_explanation'.format(s)] = d_tsv['Generated_Explanation'] 156 | 157 | 158 | print (d_csv.head(5)) 159 | fname = os.path.join(output_root, 'all', '{}{}_{}{}.csv'.format(args.input_prefix, split, "merged", output_suffix)) 160 | d_csv.to_csv(fname) 161 | 162 | if args.shuffle: 163 | split = args.split 164 | fname = os.path.join(output_root, 'all', '{}{}_{}.csv'.format(args.input_prefix, split, "merged")) 165 | d_csv = pd.read_csv(fname, index_col=index_col) 166 | 167 | d_csv_shuffled = d_csv.copy() 168 | for l in ["entailment", "contradiction", "neutral"]: 169 | d_csv_shuffled['{}_explanation'.format(l)] = np.random.choice( 170 | d_csv_shuffled['{}_explanation'.format(l)].values, 171 | len(d_csv_shuffled), replace=False) 172 | fname_csv_out = os.path.join(output_root, 'all', '{}shuffled{}_{}.csv'.format(args.input_prefix, split, "merged")) 173 | d_csv_shuffled.to_csv(fname_csv_out) 174 | 175 | 176 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas 2 | tqdm 3 | torch 4 | transformers==2.3.0 5 | tensorboardX 6 | -------------------------------------------------------------------------------- /run_clf.sh: -------------------------------------------------------------------------------- 1 | GPUDEV=$1 2 | SEED=$2 3 | DATASET=$3 4 | EXPMODEL=$4 5 | DATAFORMAT=$5 6 | INPREFIX=$6 7 | TRAIN=$7 8 | EVAL=$8 9 | SAMPLENEGS=$9 10 | TODROP="${10}" 11 | MODELPATH="${11}" 12 | 13 | MODELTYPE=roberta 14 | MODELNAME=roberta-base 15 | if [ 
"$DATAFORMAT" == "instance" ] || [ "$DATAFORMAT" == "Explanation_1" ] 16 | then 17 | SEQLEN=100 18 | BSZ=32 19 | INPREFIX="" 20 | INSUFFIX="_data" 21 | elif [ "$DATAFORMAT" == "all_explanation" ] 22 | then 23 | SEQLEN=100 24 | BSZ=32 25 | INSUFFIX="_merged_all" 26 | elif [ "$DATAFORMAT" == "independent" ] || [ "$DATAFORMAT" == "aggregate" ] 27 | then 28 | SEQLEN=50 29 | BSZ=32 30 | INSUFFIX="_merged" 31 | elif [ "$DATAFORMAT" == "append" ] 32 | then 33 | SEQLEN=100 34 | BSZ=32 35 | INSUFFIX="_merged" 36 | elif [ "$DATAFORMAT" == "instance_independent" ] || [ "$DATAFORMAT" == "instance_aggregate" ] 37 | then 38 | SEQLEN=100 39 | BSZ=16 40 | INSUFFIX="_merged" 41 | elif [ "$DATAFORMAT" == "instance_append" ] 42 | then 43 | SEQLEN=200 44 | BSZ=16 45 | INSUFFIX="_merged" 46 | fi 47 | NEPOCHS=3 48 | 49 | TRAINFILE=./dataset_"$DATASET"/all/"$INPREFIX""$TRAIN""$INSUFFIX".csv 50 | if [ "$TRAIN" == "_" ] 51 | then 52 | TRAINCMD="" 53 | else 54 | TRAINCMD="--do_train" 55 | fi 56 | EVALFILE=./dataset_"$DATASET"/all/"$INPREFIX""$EVAL""$INSUFFIX".csv 57 | 58 | if [ "$SAMPLENEGS" == "sample" ] 59 | then 60 | SAMPLECMD="--sample_negs" 61 | SAMPLESTR="_negs" 62 | else 63 | SAMPLECMD="" 64 | SAMPLESTR="" 65 | fi 66 | 67 | if [ "$TODROP" == "_" ] 68 | then 69 | TODROPCMD="" 70 | else 71 | TODROPCMD="--to_drop "$TODROP"" 72 | fi 73 | 74 | if [ "$MODELPATH" == "_" ] 75 | then 76 | OUTPUTDIR="./saved_clf/seed"$SEED"_"$DATASET"_"$EXPMODEL"_"$BSZ"_"$SEQLEN"_"$NEPOCHS"_"$INPREFIX""$SAMPLESTR"" 77 | else 78 | OUTPUTDIR="$MODELPATH" 79 | fi 80 | 81 | 82 | cmd="CUDA_VISIBLE_DEVICES=$GPUDEV python run_nli.py "$SAMPLECMD" "$TODROPCMD" \ 83 | --cache_dir ../nile_release/cache \ 84 | --seed "$SEED" \ 85 | --model_type "$MODELTYPE" \ 86 | --model_name_or_path "$MODELNAME" \ 87 | --exp_model "$EXPMODEL" \ 88 | --data_format "$DATAFORMAT" \ 89 | "$TRAINCMD" --save_steps 1523000000000000 \ 90 | --do_eval --eval_all_checkpoints \ 91 | --do_lower_case \ 92 | --train_file "$TRAINFILE" --eval_file "$EVALFILE" \ 93 | --max_seq_length "$SEQLEN" \ 94 | --per_gpu_eval_batch_size="$BSZ" \ 95 | --per_gpu_train_batch_size="$BSZ" \ 96 | --learning_rate 2e-5 \ 97 | --num_train_epochs "$NEPOCHS" \ 98 | --logging_steps 5000 --evaluate_during_training \ 99 | --output_dir "$OUTPUTDIR"" 100 | echo $cmd 101 | eval $cmd 102 | -------------------------------------------------------------------------------- /run_finetune_gpt2m.sh: -------------------------------------------------------------------------------- 1 | GPUDEV=$1 2 | DATAROOT=$2 3 | BSZ=$3 4 | cmd="CUDA_VISIBLE_DEVICES="$GPUDEV" python finetune_lm.py \ 5 | --cache_dir ./cache \ 6 | --output_dir=./saved_lm/gpt2_m_"$DATAROOT" \ 7 | --per_gpu_train_batch_size $BSZ 8 | --per_gpu_eval_batch_size $BSZ \ 9 | --model_type=gpt2 \ 10 | --model_name_or_path=gpt2-medium \ 11 | --do_train \ 12 | --block_size 128 \ 13 | --save_steps 6866800 \ 14 | --num_train_epochs 3 \ 15 | --train_data_file=./dataset_snli/"$DATAROOT"/train.tsv \ 16 | --do_eval \ 17 | --eval_data_file=./dataset_snli/"$DATAROOT"/dev.tsv" 18 | echo $cmd 19 | eval $cmd 20 | -------------------------------------------------------------------------------- /run_generate_gpt2m.sh: -------------------------------------------------------------------------------- 1 | GPUDEV=$1 2 | MODELSPLIT=$2 3 | DATASET=$3 4 | DATAROOT=$4 5 | DATASPLIT=$5 6 | cmd="CUDA_VISIBLE_DEVICES="$GPUDEV" python finetune_lm.py \ 7 | --do_generate \ 8 | --cache_dir ./cache \ 9 | --output_dir=./saved_lm/gpt2_m_"$MODELSPLIT" \ 10 | --model_type=gpt2 \ 11 | 
--model_name_or_path=gpt2-medium \ 12 | --block_size 128 \ 13 | --save_steps 6866800 \ 14 | --num_train_epochs 3 \ 15 | --train_data_file=./dataset_"$DATASET"/"$DATAROOT"/train.tsv \ 16 | --eval_data_file=./dataset_"$DATASET"/"$DATAROOT"/"$DATASPLIT".tsv" 17 | echo $cmd 18 | eval $cmd 19 | -------------------------------------------------------------------------------- /run_nli.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # Modifications copyright (c) 2020 Sawan Kumar 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | #""" Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa).""" 18 | # 19 | # Modification history 20 | # 2020 Sawan Kumar: Adapted run_glue.py to learn classifiers for NILE 21 | 22 | 23 | from __future__ import absolute_import, division, print_function 24 | 25 | import argparse 26 | import glob 27 | import logging 28 | import os 29 | import random 30 | 31 | import numpy as np 32 | import torch 33 | import torch.nn as nn 34 | import torch.nn.functional as F 35 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 36 | TensorDataset) 37 | from torch.utils.data.distributed import DistributedSampler 38 | from tensorboardX import SummaryWriter 39 | from tqdm import tqdm, trange 40 | 41 | from transformers import (WEIGHTS_NAME, 42 | RobertaConfig, 43 | RobertaForSequenceClassification, 44 | RobertaTokenizer) 45 | 46 | from transformers import AdamW, get_linear_schedule_with_warmup 47 | 48 | from example_to_feature import convert_examples_to_features as convert_examples_to_features 49 | 50 | from nli_utils import ExpProcessor 51 | from nli_utils import exp_compute_metrics as compute_metrics 52 | 53 | #from lm_utils import NLIDataset, CoQADataset 54 | 55 | logger = logging.getLogger(__name__) 56 | 57 | ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (RobertaConfig,)), ()) 58 | 59 | MODEL_CLASSES = { 60 | 'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer), 61 | } 62 | 63 | 64 | def set_seed(args): 65 | random.seed(args.seed) 66 | np.random.seed(args.seed) 67 | torch.manual_seed(args.seed) 68 | if args.n_gpu > 0: 69 | torch.cuda.manual_seed_all(args.seed) 70 | 71 | def get_logits(batch, model_output, exp_model): 72 | if exp_model in ["instance", "append", "instance_append", "all_explanation", "Explanation_1"]: 73 | return model_output 74 | model_output = model_output.view(batch[0].size(0), batch[0].size(1), model_output.size(1)) 75 | e,c,n = 0,1,2 76 | v1,v2 = 0,1 77 | if exp_model in ["independent", "instance_independent"]: 78 | evidence_e = [model_output[:, e, v1]] 79 | evidence_c = [model_output[:, c, v1]] 80 | evidence_n = [model_output[:, n, v1]] 81 | elif exp_model in ["aggregate", "instance_aggregate"]: 
82 | evidence_e = [model_output[:, e, v1], model_output[:, c, v2]] 83 | evidence_c = [model_output[:, e, v2], model_output[:, c, v1]] 84 | evidence_n = [model_output[:, n, v1]] 85 | logits_e, logits_c, logits_n = [torch.cat([evd.unsqueeze(-1) for evd in item], 1) 86 | for item in [evidence_e, evidence_c, evidence_n]] 87 | logits_e, logits_c, logits_n = [torch.logsumexp(item, 1, keepdim=True) for item in [logits_e, logits_c, logits_n]] 88 | logits = torch.cat([logits_e, logits_c, logits_n], 1) 89 | return logits 90 | 91 | def train(args, train_dataset, model, tokenizer): 92 | """ Train the model """ 93 | if args.local_rank in [-1, 0]: 94 | tb_writer = SummaryWriter() 95 | 96 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 97 | train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) 98 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 99 | 100 | if args.max_steps > 0: 101 | t_total = args.max_steps 102 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 103 | else: 104 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 105 | 106 | # Prepare optimizer and schedule (linear warmup and decay) 107 | no_decay = ['bias', 'LayerNorm.weight'] 108 | optimizer_grouped_parameters = [ 109 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, 110 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 111 | ] 112 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 113 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) 114 | if args.fp16: 115 | try: 116 | from apex import amp 117 | except ImportError: 118 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 119 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) 120 | 121 | # multi-gpu training (should be after apex fp16 initialization) 122 | if args.n_gpu > 1: 123 | model = torch.nn.DataParallel(model) 124 | 125 | # Distributed training (should be after apex fp16 initialization) 126 | if args.local_rank != -1: 127 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 128 | output_device=args.local_rank, 129 | find_unused_parameters=True) 130 | 131 | # Train! 132 | logger.info("***** Running training *****") 133 | logger.info(" Num examples = %d", len(train_dataset)) 134 | logger.info(" Num Epochs = %d", args.num_train_epochs) 135 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) 136 | logger.info(" Total train batch size (w. 
parallel, distributed & accumulation) = %d", 137 | args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)) 138 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 139 | logger.info(" Total optimization steps = %d", t_total) 140 | 141 | global_step = 0 142 | tr_loss, logging_loss = 0.0, 0.0 143 | model.zero_grad() 144 | train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) 145 | set_seed(args) # Added here for reproductibility (even between python 2 and 3) 146 | for _ in train_iterator: 147 | epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) 148 | for step, batch in enumerate(epoch_iterator): 149 | model.train() 150 | batch = tuple(t.to(args.device) for t in batch) 151 | 152 | inputs = {'input_ids': batch[0].view(-1, batch[0].size(-1)), 153 | 'attention_mask': batch[1].view(-1, batch[0].size(-1))} 154 | inputs['token_type_ids'] = None 155 | outputs = model(**inputs) 156 | model_output = outputs[0] 157 | 158 | loss_fct = nn.CrossEntropyLoss() 159 | logits = get_logits(batch, model_output, args.exp_model) 160 | loss = loss_fct(logits, batch[3]) 161 | 162 | if args.n_gpu > 1: 163 | loss = loss.mean() # mean() to average on multi-gpu parallel training 164 | if args.gradient_accumulation_steps > 1: 165 | loss = loss / args.gradient_accumulation_steps 166 | 167 | if args.fp16: 168 | with amp.scale_loss(loss, optimizer) as scaled_loss: 169 | scaled_loss.backward() 170 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) 171 | else: 172 | loss.backward() 173 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 174 | 175 | tr_loss += loss.item() 176 | if (step + 1) % args.gradient_accumulation_steps == 0: 177 | optimizer.step() 178 | scheduler.step() # Update learning rate schedule 179 | model.zero_grad() 180 | global_step += 1 181 | 182 | if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: 183 | # Log metrics 184 | if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well 185 | results = evaluate(args, model, tokenizer) 186 | for key, value in results.items(): 187 | tb_writer.add_scalar('eval_{}'.format(key), value, global_step) 188 | tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step) 189 | tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step) 190 | logging_loss = tr_loss 191 | 192 | if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: 193 | # Save model checkpoint 194 | output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step)) 195 | if not os.path.exists(output_dir): 196 | os.makedirs(output_dir) 197 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training 198 | model_to_save.save_pretrained(output_dir) 199 | torch.save(args, os.path.join(output_dir, 'training_args.bin')) 200 | logger.info("Saving model checkpoint to %s", output_dir) 201 | 202 | if args.max_steps > 0 and global_step > args.max_steps: 203 | epoch_iterator.close() 204 | break 205 | if args.max_steps > 0 and global_step > args.max_steps: 206 | train_iterator.close() 207 | break 208 | 209 | if args.local_rank in [-1, 0]: 210 | tb_writer.close() 211 | 212 | return global_step, tr_loss / global_step 213 
| 214 | 215 | def evaluate(args, model, tokenizer, prefix="", analyze_attentions=False, eval_on_train=False): 216 | processor = ExpProcessor() 217 | eval_output_dir = args.output_dir 218 | 219 | results = {} 220 | eval_dataset, indices = load_and_cache_examples(args, tokenizer, evaluate=True) 221 | if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]: 222 | os.makedirs(eval_output_dir) 223 | 224 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 225 | # Note that DistributedSampler samples randomly 226 | eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset) 227 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) 228 | 229 | # Eval! 230 | logger.info("***** Running evaluation {} *****".format(prefix)) 231 | logger.info(" Num examples = %d", len(eval_dataset)) 232 | logger.info(" Batch size = %d", args.eval_batch_size) 233 | eval_loss = 0.0 234 | nb_eval_steps = 0 235 | preds = None 236 | out_label_ids = None 237 | attentions = [] 238 | for batch in tqdm(eval_dataloader, desc="Evaluating"): 239 | model.eval() 240 | batch = tuple(t.to(args.device) for t in batch) 241 | 242 | with torch.no_grad(): 243 | inputs = {'input_ids': batch[0].view(-1, batch[0].size(-1)), 244 | 'attention_mask': batch[1].view(-1, batch[0].size(-1))} 245 | inputs['token_type_ids'] = None 246 | outputs = model(**inputs) 247 | model_output = outputs[0] 248 | loss_fct = nn.CrossEntropyLoss() 249 | logits = get_logits(batch, model_output, args.exp_model) 250 | tmp_eval_loss = loss_fct(logits, batch[3]) 251 | eval_loss += tmp_eval_loss.mean().item() 252 | 253 | nb_eval_steps += 1 254 | inputs['labels'] = batch[3] 255 | if preds is None: 256 | preds = logits.detach().cpu().numpy() 257 | out_label_ids = inputs['labels'].detach().cpu().numpy() 258 | else: 259 | preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) 260 | out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0) 261 | 262 | eval_loss = eval_loss / nb_eval_steps 263 | preds_ = preds 264 | preds = np.argmax(preds, axis=1) 265 | result = compute_metrics(preds, out_label_ids) 266 | results.update(result) 267 | 268 | if not eval_on_train: 269 | eval_dataset = os.path.basename(os.path.dirname(os.path.dirname(args.eval_file))).split('_')[1] 270 | eval_file_base = os.path.splitext(os.path.basename(args.eval_file))[0] 271 | to_drop_list = args.to_drop.split(',') if evaluate else [] 272 | to_drop_str = '_drop'+''.join(to_drop_list) if args.to_drop else '' 273 | prediction_file = os.path.join(args.output_dir, 'predictions_{}_{}_{}_{}{}.npz'.format( 274 | eval_dataset, 275 | eval_file_base, 276 | str(args.max_seq_length), 277 | str(args.data_format), 278 | to_drop_str)) 279 | print ("Writing predictions to ", prediction_file) 280 | np.savez_compressed(prediction_file, preds=preds, indices=indices, labels=out_label_ids) 281 | 282 | output_eval_file = os.path.join(eval_output_dir, "eval_results.txt") 283 | with open(output_eval_file, "w") as writer: 284 | logger.info("***** Eval results {} *****".format(prefix)) 285 | for key in sorted(result.keys()): 286 | logger.info(" %s = %s", key, str(result[key])) 287 | writer.write("%s = %s\n" % (key, str(result[key]))) 288 | 289 | return results 290 | 291 | 292 | def load_and_cache_examples(args, tokenizer, evaluate=False): 293 | if args.local_rank not in [-1, 0] and not evaluate: 294 | torch.distributed.barrier() # Make sure only the first process in 
distributed training process the dataset, and the others will use the cache 295 | 296 | processor = ExpProcessor() 297 | # Load data features from cache or dataset file 298 | 299 | filename = args.train_file if not evaluate else args.eval_file 300 | data_dir, filename_base = os.path.split(filename) 301 | 302 | if args.data_format == "aggregate": data_storage_format = "independent" 303 | elif args.data_format == "instance_aggregate": data_storage_format = "instance_independent" 304 | else: data_storage_format = args.data_format 305 | 306 | to_drop_list = args.to_drop.split(',') if evaluate else [] 307 | 308 | cached_features_file = os.path.join(data_dir, 'cached_seq{}_{}_{}_{}_{}'.format( 309 | str(args.max_seq_length), 310 | filename_base, 311 | data_storage_format, 312 | 'drop'+args.to_drop if args.to_drop and evaluate else '', 313 | '_negs' if args.sample_negs and not evaluate else '' 314 | )) 315 | 316 | if os.path.exists(cached_features_file): 317 | logger.info("Loading features from cached file %s", cached_features_file) 318 | examples, features = torch.load(cached_features_file) 319 | else: 320 | logger.info("Creating features from dataset file at %s", data_dir) 321 | label_list = processor.get_labels() 322 | fn = processor.get_dev_examples if evaluate else processor.get_train_examples 323 | examples = fn(filename, data_format=data_storage_format, to_drop=to_drop_list) 324 | 325 | features = convert_examples_to_features(examples, 326 | tokenizer, 327 | label_list=label_list, 328 | max_length=args.max_seq_length, 329 | pad_on_left=False, 330 | pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0], 331 | pad_token_segment_id=0, 332 | sample_negatives=args.sample_negs if not evaluate else False, 333 | ) 334 | if args.local_rank in [-1, 0]: 335 | logger.info("Saving features into cached file %s", cached_features_file) 336 | torch.save((examples, features), cached_features_file) 337 | 338 | if args.local_rank == 0 and not evaluate: 339 | torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache 340 | 341 | # Convert to Tensors and build dataset 342 | features = features 343 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 344 | all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long) 345 | all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long) 346 | all_labels = torch.tensor([f.label for f in features], dtype=torch.long) 347 | 348 | dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels) 349 | 350 | indices = [example.guid for example in examples] 351 | return dataset, indices 352 | 353 | 354 | def main(): 355 | parser = argparse.ArgumentParser() 356 | 357 | ## Required parameters 358 | parser.add_argument("--train_file", default=None, type=str, required=True, 359 | help="The input train file. Should contain the .tsv files (or other data files) for the task.") 360 | parser.add_argument("--eval_file", default=None, type=str, required=True, 361 | help="The input eval file. 
Should contain the .tsv files (or other data files) for the task.") 362 |     parser.add_argument("--model_type", default=None, type=str, required=True, 363 |                         help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) 364 |     parser.add_argument("--model_name_or_path", default=None, type=str, required=True, 365 |                         help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS)) 366 |     parser.add_argument("--output_dir", default=None, type=str, required=True, 367 |                         help="The output directory where the model predictions and checkpoints will be written.") 368 | 369 |     ## Other parameters 370 |     parser.add_argument("--exp_model", default="instance", type=str) 371 |     parser.add_argument("--data_format", default="instance", type=str) 372 |     parser.add_argument("--to_drop", default="", type=str) #comma-sep list 373 | 374 |     parser.add_argument("--config_name", default="", type=str, 375 |                         help="Pretrained config name or path if not the same as model_name") 376 |     parser.add_argument("--tokenizer_name", default="", type=str, 377 |                         help="Pretrained tokenizer name or path if not the same as model_name") 378 |     parser.add_argument("--cache_dir", default="", type=str, 379 |                         help="Where do you want to store the pre-trained models downloaded from s3") 380 |     parser.add_argument("--max_seq_length", default=128, type=int, 381 |                         help="The maximum total input sequence length after tokenization. Sequences longer " 382 |                              "than this will be truncated, sequences shorter will be padded.") 383 |     parser.add_argument("--do_train", action='store_true', 384 |                         help="Whether to run training.") 385 |     parser.add_argument("--do_eval", action='store_true', 386 |                         help="Whether to run eval on the dev set.") 387 |     parser.add_argument("--evaluate_during_training", action='store_true', 388 |                         help="Run evaluation during training at each logging step.") 389 |     parser.add_argument("--do_lower_case", action='store_true', 390 |                         help="Set this flag if you are using an uncased model.") 391 | 392 |     parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, 393 |                         help="Batch size per GPU/CPU for training.") 394 |     parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, 395 |                         help="Batch size per GPU/CPU for evaluation.") 396 |     parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 397 |                         help="Number of update steps to accumulate before performing a backward/update pass.") 398 |     parser.add_argument("--learning_rate", default=5e-5, type=float, 399 |                         help="The initial learning rate for Adam.") 400 |     parser.add_argument("--weight_decay", default=0.0, type=float, 401 |                         help="Weight decay if we apply some.") 402 |     parser.add_argument("--adam_epsilon", default=1e-8, type=float, 403 |                         help="Epsilon for Adam optimizer.") 404 |     parser.add_argument("--max_grad_norm", default=1.0, type=float, 405 |                         help="Max gradient norm.") 406 |     parser.add_argument("--num_train_epochs", default=3.0, type=float, 407 |                         help="Total number of training epochs to perform.") 408 |     parser.add_argument("--max_steps", default=-1, type=int, 409 |                         help="If > 0: set total number of training steps to perform. 
Override num_train_epochs.") 410 |     parser.add_argument("--warmup_steps", default=0, type=int, 411 |                         help="Linear warmup over warmup_steps.") 412 | 413 |     parser.add_argument("--prompt_type", default="none", type=str, 414 |                         help="Prompt given before explanation") 415 |     parser.add_argument("--use_annotations", action='store_true', 416 |                         help="Whether to use annotations instead of generated explanations") 417 | 418 |     parser.add_argument('--logging_steps', type=int, default=50, 419 |                         help="Log every X update steps.") 420 |     parser.add_argument('--save_steps', type=int, default=50, 421 |                         help="Save checkpoint every X update steps.") 422 |     parser.add_argument("--eval_all_checkpoints", action='store_true', 423 |                         help="Evaluate all checkpoints starting with the same prefix as model_name and ending with step number") 424 |     parser.add_argument("--no_cuda", action='store_true', 425 |                         help="Avoid using CUDA when available") 426 |     parser.add_argument('--overwrite_output_dir', action='store_true', 427 |                         help="Overwrite the content of the output directory") 428 |     parser.add_argument('--overwrite_cache', action='store_true', 429 |                         help="Overwrite the cached training and evaluation sets") 430 |     parser.add_argument('--seed', type=int, default=42, 431 |                         help="random seed for initialization") 432 | 433 |     parser.add_argument('--sample_negs', action='store_true', 434 |                         help='sample negative conjectures') 435 | 436 |     parser.add_argument('--fp16', action='store_true', 437 |                         help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit") 438 |     parser.add_argument('--fp16_opt_level', type=str, default='O1', 439 |                         help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 440 |                              "See details at https://nvidia.github.io/apex/amp.html") 441 |     parser.add_argument("--local_rank", type=int, default=-1, 442 |                         help="For distributed training: local_rank") 443 |     parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.") 444 |     parser.add_argument('--server_port', type=str, default='', help="For distant debugging.") 445 |     args = parser.parse_args() 446 | 447 |     if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir: 448 |         raise ValueError("Output directory ({}) already exists and is not empty. 
Use --overwrite_output_dir to overcome.".format(args.output_dir)) 449 | 450 | # Setup distant debugging if needed 451 | if args.server_ip and args.server_port: 452 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 453 | import ptvsd 454 | print("Waiting for debugger attach") 455 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 456 | ptvsd.wait_for_attach() 457 | 458 | # Setup CUDA, GPU & distributed training 459 | if args.local_rank == -1 or args.no_cuda: 460 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 461 | args.n_gpu = torch.cuda.device_count() 462 | else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 463 | torch.cuda.set_device(args.local_rank) 464 | device = torch.device("cuda", args.local_rank) 465 | torch.distributed.init_process_group(backend='nccl') 466 | args.n_gpu = 1 467 | args.device = device 468 | 469 | # Setup logging 470 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 471 | datefmt = '%m/%d/%Y %H:%M:%S', 472 | level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 473 | logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 474 | args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16) 475 | 476 | # Set seed 477 | set_seed(args) 478 | 479 | processor = ExpProcessor() 480 | label_list = processor.get_labels() 481 | num_labels = len(label_list) 482 | 483 | if args.exp_model in ["independent", "instance_independent"]: args.model_num_outputs = 1 484 | elif args.exp_model in ["aggregate", "instance_aggregate"]: args.model_num_outputs = 2 485 | else: args.model_num_outputs = 3 486 | 487 | # Load pretrained model and tokenizer 488 | if args.local_rank not in [-1, 0]: 489 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab 490 | 491 | args.model_type = args.model_type.lower() 492 | 493 | config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] 494 | config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, num_labels=args.model_num_outputs, 495 | cache_dir=args.cache_dir if args.cache_dir else None) 496 | tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case, 497 | cache_dir=args.cache_dir if args.cache_dir else None) 498 | 499 | if args.do_train: 500 | model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config, 501 | cache_dir=args.cache_dir if args.cache_dir else None) 502 | 503 | if args.local_rank == 0: 504 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab 505 | 506 | 507 | logger.info("Training/evaluation parameters %s", args) 508 | 509 | 510 | # Training 511 | if args.do_train: 512 | model.to(args.device) 513 | train_dataset, _ = load_and_cache_examples(args, tokenizer, evaluate=False) 514 | global_step, tr_loss = train(args, train_dataset, model, tokenizer) 515 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) 516 | 517 | 518 | # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained() 519 | if args.do_train and (args.local_rank == -1 or 
torch.distributed.get_rank() == 0): 520 | # Create output directory if needed 521 | if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: 522 | os.makedirs(args.output_dir) 523 | 524 | logger.info("Saving model checkpoint to %s", args.output_dir) 525 | # Save a trained model, configuration and tokenizer using `save_pretrained()`. 526 | # They can then be reloaded using `from_pretrained()` 527 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training 528 | model_to_save.save_pretrained(args.output_dir) 529 | tokenizer.save_pretrained(args.output_dir) 530 | 531 | # Good practice: save your training arguments together with the trained model 532 | torch.save(args, os.path.join(args.output_dir, 'training_args.bin')) 533 | 534 | # Load a trained model and vocabulary that you have fine-tuned 535 | model = model_class.from_pretrained(args.output_dir) 536 | tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) 537 | model.to(args.device) 538 | 539 | 540 | # Evaluation 541 | results = {} 542 | if args.do_eval and args.local_rank in [-1, 0]: 543 | tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) 544 | checkpoints = [args.output_dir] 545 | if args.eval_all_checkpoints: 546 | checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) 547 | logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging 548 | logger.info("Evaluate the following checkpoints: %s", checkpoints) 549 | for checkpoint in checkpoints: 550 | global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else "" 551 | model = model_class.from_pretrained(checkpoint) 552 | model.output_hidden_states = True 553 | model.output_attentions = True 554 | model.to(args.device) 555 | result = evaluate(args, model, tokenizer, prefix=global_step, analyze_attentions=True) 556 | result = dict((k + '_{}'.format(global_step), v) for k, v in result.items()) 557 | results.update(result) 558 | 559 | print (results) 560 | return results 561 | 562 | 563 | if __name__ == "__main__": 564 | main() 565 | --------------------------------------------------------------------------------
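
For readers tracing run_nli.py above: in the "independent"/"aggregate" exp_model settings, the classifier scores each label-specific explanation separately, and get_logits reshapes those scores to (batch size, 3 explanations, number of outputs) before pooling them into three NLI logits with logsumexp. The snippet below is a standalone sketch of that pooling, not part of the repository; the batch size, random tensors, and dictionary layout are illustrative assumptions mirroring the "aggregate" branch.

import torch

# Standalone sketch of the evidence pooling in get_logits (aggregate branch).
# Shapes and values are illustrative; label order follows ExpProcessor.labels.
batch_size, num_labels, num_outputs = 4, 3, 2
model_output = torch.randn(batch_size * num_labels, num_outputs)   # one row per (instance, explanation)
scores = model_output.view(batch_size, num_labels, num_outputs)    # same reshape as in get_logits

e, c, n = 0, 1, 2   # entailment, contradiction, neutral explanation indices
v1, v2 = 0, 1       # the two outputs produced per explanation when model_num_outputs == 2

evidence = {
    "entailment":    [scores[:, e, v1], scores[:, c, v2]],
    "contradiction": [scores[:, e, v2], scores[:, c, v1]],
    "neutral":       [scores[:, n, v1]],
}
logits = torch.cat(
    [torch.logsumexp(torch.stack(ev, dim=1), dim=1, keepdim=True) for ev in evidence.values()],
    dim=1)
print(logits.shape)   # torch.Size([4, 3]); these logits feed nn.CrossEntropyLoss in train()/evaluate()

With the "independent" setting, each label keeps only its own v1 score, so the logsumexp reduces to that single term; the "aggregate" setting additionally lets the entailment and contradiction explanations contribute evidence to each other's labels via their second output.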