├── .gitignore
├── LICENSE
├── README.md
├── example_to_feature.py
├── fetch_data.sh
├── finetune_lm.py
├── generated_explanations
│   ├── snli_dev.csv.gz
│   └── snli_test.csv.gz
├── images
│   └── architecture.jpg
├── lm_utils.py
├── merge_esnli_train.py
├── nli_utils.py
├── prepare_train_test.py
├── requirements.txt
├── run_clf.sh
├── run_finetune_gpt2m.sh
├── run_generate_gpt2m.sh
└── run_nli.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NILE
2 | Reference code for the [ACL 2020](https://acl2020.org/) paper - [NILE : Natural Language Inference with Faithful Natural Language Explanations](https://www.aclweb.org/anthology/2020.acl-main.771/).
3 |
4 |
5 |
6 |
7 |
8 | ## Dependencies
9 | The code was written with, or depends on:
10 | * Python 3.6
11 | * PyTorch 1.4.0
12 |
13 | ## Running the code
14 | 1. Create a virtualenv and install dependencies
15 | ```bash
16 | virtualenv -p python3.6 env
17 | source env/bin/activate
18 | pip install -r requirements.txt
19 | ```
20 | 1. Fetch the data and pre-process it. This will create files in the ```dataset_snli``` and ```dataset_mnli``` folders.
21 | ```bash
22 | bash fetch_data.sh
23 | python prepare_train_test.py --dataset snli --create_data --filter_repetitions
24 | python prepare_train_test.py --dataset mnli --create_data --filter_repetitions
25 | ```
26 | 1. Fine-tune language models using e-SNLI, for entailment, contradiction and neutral explanations. 'all' is trained to produce a comparable single-explanation ETPA baseline, and can be skipped in this and subsequent steps if only reproducing NILE.
27 | ```bash
28 | bash run_finetune_gpt2m.sh 0 entailment 2
29 | bash run_finetune_gpt2m.sh 0 contradiction 2
30 | bash run_finetune_gpt2m.sh 0 neutral 2
31 | bash run_finetune_gpt2m.sh 0 all 2
32 | ```
33 | 1. Generate explanations using the fine-tuned language models, where the dataset can be snli or mnli, and the split is train/dev/test for SNLI and dev/dev_mm for MNLI.
34 | ```bash
35 | bash run_generate_gpt2m.sh 0 entailment all
36 | bash run_generate_gpt2m.sh 0 contradiction all
37 | bash run_generate_gpt2m.sh 0 neutral all
38 | bash run_generate_gpt2m.sh 0 all all
39 | ```
40 | 1. Merge the generated explanations, filling in the dataset and split as above (a concrete example is given below).
41 | ```bash
42 | python prepare_train_test.py --dataset <dataset> --merge_data --split <split> --input_prefix gpt2_m_
43 | ```
44 |
45 | To merge for the single-explanation baseline, run
46 | ```bash
47 | python prepare_train_test.py --dataset <dataset> --merge_data --split <split> --input_prefix gpt2_m_ --merge_single
48 | ```
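For example, a concrete instance of the first merge command above, assuming explanations for the SNLI dev split were generated in the previous step with the ```gpt2_m_``` prefix:
```bash
python prepare_train_test.py --dataset snli --merge_data --split dev --input_prefix gpt2_m_
```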
49 | 1. Train classifiers on the generated explanations; models are saved in ```saved_clf```.
50 |
51 | NILE-PH
52 | ```bash
53 | bash run_clf.sh 0 snli independent independent gpt2_m_ train dev _ _ _
54 | bash run_clf.sh 0 snli aggregate aggregate gpt2_m_ train dev _ _ _
55 | bash run_clf.sh 0 snli append append gpt2_m_ train dev _ _ _
56 | ```
57 | NILE-NS
58 | ```bash
59 | bash run_clf.sh 0 snli instance_independent instance_independent gpt2_m_ train dev _ _ _
60 | bash run_clf.sh 0 snli instance_aggregate instance_aggregate gpt2_m_ train dev _ _ _
61 | bash run_clf.sh 0 snli instance_append instance_append gpt2_m_ train dev _ _ _
62 | ```
63 | NILE
64 | ```bash
65 | bash run_clf.sh 0 snli instance_independent instance_independent gpt2_m_ train dev sample _ _
66 | bash run_clf.sh 0 snli instance_aggregate instance_aggregate gpt2_m_ train dev sample _ _
67 | ```
68 | NILE post-hoc
69 | ```bash
70 | bash run_clf.sh 0 snli instance instance gpt2_m_ train dev sample _ _
71 | ```
72 | Single-explanation baseline
73 | ```bash
74 | bash run_clf.sh 0 snli Explanation_1 Explanation_1 gpt2_m_ train dev sample _ _
75 | ```
76 | 1. Evaluate a trained classifier for label accuracy, where the model path is the path of a model saved in the previous step, and the data format can be independent, aggregate, append, instance_independent, instance_aggregate, or instance for NILE variants, and all_explanation for the single-explanation baseline.
77 | ```bash
78 | bash run_clf.sh 0 snli <data_format> <data_format> gpt2_m_ _ test _ _ <model_path>
79 | ```
80 |
81 | ## Example explanations
82 | Generated explanations on the e-SNLI dev and test sets are available at ```./generated_explanations/*.gz```. Please unzip them before use (with ```gunzip```).
83 | The generated explanations are in the ```entailment_explanation```, ```contradiction_explanation``` and ```neutral_explanation``` columns, in CSV format.
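
A minimal sketch for inspecting them with pandas (the unzipped file name below follows from the ```.gz``` name above):
```python
import pandas as pd

# Load the unzipped dev-set explanations.
df = pd.read_csv("generated_explanations/snli_dev.csv")

# One generated explanation per candidate label, as listed above.
cols = ["entailment_explanation", "contradiction_explanation", "neutral_explanation"]
print(df[cols].head())
```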
84 |
85 | ## Pre-trained models
86 | We are sharing pre-trained [label-specific generators](https://drive.google.com/file/d/1lZZYbAwZ8kphY8lp0bVSOQ841c683uUc/view?usp=sharing).
87 | We are also sharing pre-trained classifiers for [NILE-PH Append](https://drive.google.com/file/d/1DacGNNiPvUC6lYk9jzq44QlR5uNAOPss/view?usp=sharing) and [NILE Independent](https://drive.google.com/file/d/10xcnzWTyg1dgX8hldAsnGn52Oqocvqon/view?usp=sharing) architectures.
88 |
89 | ## Citation
90 | If you use this code, please consider citing:
91 |
92 | [1] Sawan Kumar and Partha Talukdar. 2020. NILE : Natural language inference with faithful natural language explanations. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, pages 8730–8742, Online. Association for Computational Linguistics.
93 | [[bibtex](https://www.aclweb.org/anthology/2020.acl-main.771.bib)]
94 |
95 | ## Contact
96 | For any clarification, comments, or suggestions, please create an issue or contact sawankumar@iisc.ac.in
97 |
--------------------------------------------------------------------------------
/example_to_feature.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | # Modifications copyright (c) 2020 Sawan Kumar
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 | #""" GLUE processors and helpers """
18 | #
19 | # Modification history
20 | # 2020 Sawan Kumar: Modified from glue.py in HuggingFace's transformers,
21 | # to handle examples for NILE
22 |
23 | import logging
24 | import os
25 | import numpy as np
26 |
27 | from transformers.data.processors.utils import DataProcessor, InputExample, InputFeatures
28 | from transformers.file_utils import is_tf_available
29 |
30 | if is_tf_available():
31 | import tensorflow as tf
32 |
33 | logger = logging.getLogger(__name__)
34 |
35 | def convert_examples_to_features(examples, tokenizer,
36 | max_length=512,
37 | label_list=None,
38 | pad_on_left=False,
39 | pad_token=0,
40 | pad_token_segment_id=0,
41 | mask_padding_with_zero=True,
42 | sample_negatives=False):
43 | """
44 | Loads a data file into a list of ``InputFeatures``
45 |
46 | Args:
47 | examples: List of ``InputExamples`` or ``tf.data.Dataset`` containing the examples.
48 | tokenizer: Instance of a tokenizer that will tokenize the examples
49 | max_length: Maximum example length
50 | sample_negatives: If set to ``True``, additionally create features in which the explanations for the non-gold labels are sampled from other examples
51 | label_list: List of labels. Can be obtained from the processor using the ``processor.get_labels()`` method
52 | pad_on_left: If set to ``True``, the examples will be padded on the left rather than on the right (default)
53 | pad_token: Padding token
54 | pad_token_segment_id: The segment ID for the padding token (It is usually 0, but can vary such as for XLNet where it is 4)
55 | mask_padding_with_zero: If set to ``True``, the attention mask will be filled by ``1`` for actual values
56 | and by ``0`` for padded values. If set to ``False``, inverts it (``1`` for padded values, ``0`` for
57 | actual values)
58 |
59 | Returns:
60 | If the ``examples`` input is a ``tf.data.Dataset``, will return a ``tf.data.Dataset``
61 | containing the task-specific features. If the input is a list of ``InputExamples``, will return
62 | a list of task-specific ``InputFeatures`` which can be fed to the model.
63 |
64 | """
65 | is_tf_dataset = False
66 | if is_tf_available() and isinstance(examples, tf.data.Dataset):
67 | is_tf_dataset = True
68 |
69 | label_map = {label: i for i, label in enumerate(label_list)}
70 |
71 | features = []
72 | if examples[0].text_b is not None:
73 | k = len(examples[0].text_b)
74 | if sample_negatives:
75 | neg_indices = [np.random.choice(len(examples), size=len(examples), replace=False) for i in range(k)]
76 | for (ex_index, example) in enumerate(examples):
77 | if ex_index % 10000 == 0:
78 | logger.info("Writing example %d" % (ex_index))
79 | if is_tf_dataset:
80 | example = processor.get_example_from_tensor_dict(example)
81 |
82 | if type(example.text_a) is list:
83 | text_a = example.text_a
84 | text_b = [example.text_b]*len(text_a)
85 | elif type(example.text_b) is list:
86 | text_b = example.text_b
87 | if sample_negatives:
88 | label_idx = label_map[example.label]
89 | text_b_neg = [(examples[neg_indices[i][ex_index]]).text_b[label_idx] for i in range(k)]
90 | text_b_neg[label_idx] = text_b[label_idx]
91 |
92 | text_a = [example.text_a]*len(text_b)
93 | else:
94 | text_a = [example.text_a]
95 | text_b = [example.text_b]
96 |
97 | if 0: #sample_negatives:
98 | print ('Created negative samples')
99 | print ('Original example: label:{} text_a: {} text_b1: {}, 2: {}, 3:{}'.format(example.label, text_a[0], text_b[0], text_b[1], text_b[2]))
100 | print ('Converted example: text_a: {} text_b1: {}, 2: {}, 3:{}'.format(text_a[0], text_b_neg[0], text_b_neg[1], text_b_neg[2]))
101 |
102 | def get_indices(t1, t2):
103 | out = []
104 | for a,b in zip(t1, t2):
105 | inputs = tokenizer.encode_plus(
106 | a,
107 | b,
108 | add_special_tokens=True,
109 | max_length=max_length,
110 | )
111 | input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
112 |
113 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
114 | # tokens are attended to.
115 | attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
116 |
117 | # Zero-pad up to the sequence length.
118 | padding_length = max_length - len(input_ids)
119 | if pad_on_left:
120 | input_ids = ([pad_token] * padding_length) + input_ids
121 | attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
122 | token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids
123 | else:
124 | input_ids = input_ids + ([pad_token] * padding_length)
125 | attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
126 | token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
127 |
128 | assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length)
129 | assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask), max_length)
130 | assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(len(token_type_ids), max_length)
131 | out.append((input_ids, attention_mask, token_type_ids))
132 |
133 | if len(t1) == 1:
134 | input_ids, attention_mask, token_type_ids = out[0]
135 | else:
136 | input_ids, attention_mask, token_type_ids = zip(*out)
137 | return input_ids, attention_mask, token_type_ids
138 |
139 | input_ids, attention_mask, token_type_ids = get_indices(text_a, text_b)
140 | if sample_negatives:
141 | input_ids_n, attention_mask_n, token_type_ids_n = get_indices(text_a, text_b_neg)
142 |
143 | label = label_map[example.label]
144 |
145 | if ex_index < 5:
146 | logger.info("*** Example ***")
147 | logger.info("guid: %s" % (example.guid))
148 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
149 | logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
150 | logger.info("token_type_ids: %s" % " ".join([str(x) for x in token_type_ids]))
151 | logger.info("label: %s (id = %d)" % (example.label, label))
152 |
153 | features.append(
154 | InputFeatures(input_ids=input_ids,
155 | attention_mask=attention_mask,
156 | token_type_ids=token_type_ids,
157 | label=label))
158 |
159 | if sample_negatives:
160 | features.append(
161 | InputFeatures(input_ids=input_ids_n,
162 | attention_mask=attention_mask_n,
163 | token_type_ids=token_type_ids_n,
164 | label=label))
165 |
166 | if is_tf_available() and is_tf_dataset:
167 | def gen():
168 | for ex in features:
169 | yield ({'input_ids': ex.input_ids,
170 | 'attention_mask': ex.attention_mask,
171 | 'token_type_ids': ex.token_type_ids},
172 | ex.label)
173 |
174 | return tf.data.Dataset.from_generator(gen,
175 | ({'input_ids': tf.int32,
176 | 'attention_mask': tf.int32,
177 | 'token_type_ids': tf.int32},
178 | tf.int64),
179 | ({'input_ids': tf.TensorShape([None]),
180 | 'attention_mask': tf.TensorShape([None]),
181 | 'token_type_ids': tf.TensorShape([None])},
182 | tf.TensorShape([])))
183 |
184 | return features
185 |
--------------------------------------------------------------------------------
/fetch_data.sh:
--------------------------------------------------------------------------------
1 | mkdir ./external
2 |
3 | #e-snli
4 | git clone https://github.com/OanaMariaCamburu/e-SNLI.git ./external/esnli
5 | python merge_esnli_train.py ./external/esnli
6 |
7 | #mnli, from glue
8 | wget -O ./external/MNLI.zip https://dl.fbaipublicfiles.com/glue/data/MNLI.zip
9 | unzip -d ./external/ ./external/MNLI.zip
10 |
11 | mkdir cache
12 | mkdir saved_lm
13 | mkdir saved_gen
14 | mkdir saved_clf
15 |
--------------------------------------------------------------------------------
/finetune_lm.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | # Modifications copyright (c) 2020 Sawan Kumar
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 | #
18 | # Modification history
19 | # 2020 Sawan Kumar: Modified to finetune and generate explanations for NLI
20 |
21 | from __future__ import absolute_import, division, print_function
22 |
23 | import argparse
24 | import glob
25 | import logging
26 | import os
27 | import pickle
28 | import random
29 | import pandas as pd
30 |
31 | import numpy as np
32 | import torch
33 | import torch.nn.functional as F
34 | from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler
35 | from torch.utils.data.distributed import DistributedSampler
36 | from tensorboardX import SummaryWriter
37 | from tqdm import tqdm, trange
38 |
39 | from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
40 | GPT2Config, GPT2LMHeadModel, GPT2Tokenizer)
41 |
42 | from lm_utils import TSVDataset, EXP_TOKEN, EOS_TOKEN
43 |
44 | logger = logging.getLogger(__name__)
45 |
46 | MODEL_CLASSES = {
47 | 'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
48 | }
49 |
50 | cross_entropy_ignore_index = -1
51 |
52 | MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
53 |
54 | def set_seed(args):
55 | random.seed(args.seed)
56 | np.random.seed(args.seed)
57 | torch.manual_seed(args.seed)
58 | if args.n_gpu > 0:
59 | torch.cuda.manual_seed_all(args.seed)
60 |
61 | def train(args, train_dataset, model, tokenizer):
62 | """ Train the model """
63 | if args.local_rank in [-1, 0]:
64 | tb_writer = SummaryWriter()
65 |
66 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
67 | train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
68 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
69 |
70 | if args.max_steps > 0:
71 | t_total = args.max_steps
72 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
73 | else:
74 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
75 |
76 | # Prepare optimizer and schedule (linear warmup and decay)
77 | no_decay = ['bias', 'LayerNorm.weight']
78 | optimizer_grouped_parameters = [
79 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
80 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
81 | ]
82 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
83 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
84 | if args.fp16:
85 | try:
86 | from apex import amp
87 | except ImportError:
88 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
89 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
90 |
91 | # multi-gpu training (should be after apex fp16 initialization)
92 | if args.n_gpu > 1:
93 | model = torch.nn.DataParallel(model)
94 |
95 | # Distributed training (should be after apex fp16 initialization)
96 | if args.local_rank != -1:
97 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
98 | output_device=args.local_rank,
99 | find_unused_parameters=True)
100 |
101 | # Train!
102 | logger.info("***** Running training *****")
103 | logger.info(" Num examples = %d", len(train_dataset))
104 | logger.info(" Num Epochs = %d", args.num_train_epochs)
105 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
106 | logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
107 | args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
108 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
109 | logger.info(" Total optimization steps = %d", t_total)
110 |
111 | global_step = 0
112 | tr_loss, logging_loss = 0.0, 0.0
113 | model.zero_grad()
114 | train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
115 | set_seed(args) # Added here for reproducibility (even between python 2 and 3)
116 | for _ in train_iterator:
117 | epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
118 | for step, batch in enumerate(epoch_iterator):
119 | batch, prompt_lengths, total_lengths = batch
120 | max_length = torch.max(total_lengths).item()
121 | batch = batch[:, :max_length]
122 | inputs, labels = (batch, batch.clone().detach())
123 | inputs = inputs.to(args.device)
124 | labels = labels.to(args.device)
125 | for idx in range(len(prompt_lengths)):
126 | labels[idx, :prompt_lengths[idx]] = cross_entropy_ignore_index
127 | model.train()
128 | outputs = model(inputs, labels=labels)
129 | loss = outputs[0] # model outputs are always tuple in transformers (see doc)
130 |
131 | if args.n_gpu > 1:
132 | loss = loss.mean() # mean() to average on multi-gpu parallel training
133 | if args.gradient_accumulation_steps > 1:
134 | loss = loss / args.gradient_accumulation_steps
135 |
136 | if args.fp16:
137 | with amp.scale_loss(loss, optimizer) as scaled_loss:
138 | scaled_loss.backward()
139 | else:
140 | loss.backward()
141 |
142 | tr_loss += loss.item()
143 | if (step + 1) % args.gradient_accumulation_steps == 0:
144 | if args.fp16:
145 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
146 | else:
147 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
148 | optimizer.step()
149 | scheduler.step() # Update learning rate schedule
150 | model.zero_grad()
151 | global_step += 1
152 |
153 | if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
154 | # Log metrics
155 | if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
156 | results = evaluate(args, model, tokenizer)
157 | for key, value in results.items():
158 | tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
159 | tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
160 | tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
161 | logging_loss = tr_loss
162 |
163 | if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
164 | # Save model checkpoint
165 | output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
166 | if not os.path.exists(output_dir):
167 | os.makedirs(output_dir)
168 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
169 | model_to_save.save_pretrained(output_dir)
170 | torch.save(args, os.path.join(output_dir, 'training_args.bin'))
171 | logger.info("Saving model checkpoint to %s", output_dir)
172 |
173 | if args.max_steps > 0 and global_step > args.max_steps:
174 | epoch_iterator.close()
175 | break
176 | if args.max_steps > 0 and global_step > args.max_steps:
177 | train_iterator.close()
178 | break
179 |
180 | if args.local_rank in [-1, 0]:
181 | tb_writer.close()
182 |
183 | return global_step, tr_loss / global_step
184 |
185 | def sample_sequence(model, length, context, device='cpu', eos_token_id=None):
186 | context = torch.tensor(context, dtype=torch.long, device=device)
187 | context = context.unsqueeze(0)
188 | generated = context
189 | past = None
190 | with torch.no_grad():
191 | for _ in range(length):
192 | #inputs = {'input_ids': context}
193 | #output, past = model(**inputs, past=past)
194 | inputs = {'input_ids': generated}
195 | output, past = model(**inputs)
196 | next_token_logits = output[0, -1, :]
197 | next_token = torch.argmax(next_token_logits)
198 | generated = torch.cat((generated, next_token.view(1,1)), dim=1)
199 | if next_token.item() == eos_token_id:
200 | break
201 | context = next_token.view(1,1)
202 | return generated
203 |
204 | def generate(args, model, tokenizer, prefix=""):
205 | if args.length < 0 and model.config.max_position_embeddings > 0:
206 | args.length = model.config.max_position_embeddings
207 | elif 0 < model.config.max_position_embeddings < args.length:
208 | args.length = model.config.max_position_embeddings # No generation bigger than model size
209 | elif args.length < 0:
210 | args.length = MAX_LENGTH # avoid infinite loop
211 |
212 | eval_output_dir = args.output_dir
213 | eval_dataset = TSVDataset(tokenizer, args, file_path=args.eval_data_file,
214 | block_size=args.block_size, get_annotations=False)
215 |
216 | eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
217 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=1)
218 |
219 | # Eval!
220 | logger.info("***** Running generation {} *****".format(prefix))
221 | logger.info(" Num examples = %d", len(eval_dataset))
222 |
223 | model.eval()
224 |
225 | for index, batch in enumerate(tqdm(eval_dataloader, desc="Evaluating")):
226 | batch, prompt_lengths, total_lengths = batch
227 | batch = batch.squeeze()
228 | out = sample_sequence(
229 | model=model,
230 | context=batch,
231 | length=args.length,
232 | device=args.device,
233 | eos_token_id=tokenizer.convert_tokens_to_ids(EOS_TOKEN),
234 | )
235 | out = out[0, len(batch):].tolist()
236 | text = tokenizer.decode(out, clean_up_tokenization_spaces=True)
237 | text = text.split(EOS_TOKEN)[0].strip()
238 | eval_dataset.add_explanation(index, text)
239 | print (text)
240 |
241 | #save
242 | directory, filename = os.path.split(args.eval_data_file)
243 | model_directory, model_name = os.path.split(os.path.normpath(args.output_dir))
244 | output_name = os.path.join(directory, '{}_{}'.format(model_name, filename))
245 | eval_dataset.save(output_name)
246 |
247 | def evaluate(args, model, tokenizer, prefix=""):
248 | eval_output_dir = args.output_dir
249 | eval_dataset = TSVDataset(tokenizer, args, file_path=args.eval_data_file,
250 | block_size=args.block_size, get_annotations=True)
251 |
252 | if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
253 | os.makedirs(eval_output_dir)
254 |
255 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
256 | # Note that DistributedSampler samples randomly
257 | eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
258 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
259 |
260 | # Eval!
261 | logger.info("***** Running evaluation {} *****".format(prefix))
262 | logger.info(" Num examples = %d", len(eval_dataset))
263 | logger.info(" Batch size = %d", args.eval_batch_size)
264 | eval_loss = 0.0
265 | nb_eval_steps = 0
266 | model.eval()
267 |
268 | for batch in tqdm(eval_dataloader, desc="Evaluating"):
269 | batch, prompt_lengths, total_lengths = batch
270 | batch = batch.to(args.device)
271 |
272 | with torch.no_grad():
273 | outputs = model(batch, labels=batch)
274 | lm_loss = outputs[0]
275 | eval_loss += lm_loss.mean().item()
276 | nb_eval_steps += 1
277 |
278 | eval_loss = eval_loss / nb_eval_steps
279 | perplexity = torch.exp(torch.tensor(eval_loss))
280 |
281 | result = {
282 | "perplexity": perplexity
283 | }
284 |
285 | output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
286 | with open(output_eval_file, "w") as writer:
287 | logger.info("***** Eval results {} *****".format(prefix))
288 | for key in sorted(result.keys()):
289 | logger.info(" %s = %s", key, str(result[key]))
290 | writer.write("%s = %s\n" % (key, str(result[key])))
291 |
292 | return result
293 |
294 |
295 | def main():
296 | parser = argparse.ArgumentParser()
297 |
298 | ## Required parameters
299 | parser.add_argument("--train_data_file", default=None, type=str, required=True,
300 | help="The input training data file (a text file).")
301 | parser.add_argument("--output_dir", default=None, type=str, required=True,
302 | help="The output directory where the model predictions and checkpoints will be written.")
303 |
304 | ## Other parameters
305 | parser.add_argument("--eval_data_file", default=None, type=str,
306 | help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
307 |
308 | parser.add_argument("--model_type", default="bert", type=str,
309 | help="The model architecture to be fine-tuned.")
310 | parser.add_argument("--model_name_or_path", default="bert-base-cased", type=str,
311 | help="The model checkpoint for weights initialization.")
312 |
313 | parser.add_argument("--config_name", default="", type=str,
314 | help="Optional pretrained config name or path if not the same as model_name_or_path")
315 | parser.add_argument("--tokenizer_name", default="", type=str,
316 | help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
317 | parser.add_argument("--cache_dir", default="", type=str,
318 | help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)")
319 | parser.add_argument("--block_size", default=-1, type=int,
320 | help="Optional input sequence length after tokenization."
321 | "The training dataset will be truncated in block of this size for training."
322 | "Default to the model max input length for single sentence inputs (take into account special tokens).")
323 | parser.add_argument("--do_train", action='store_true',
324 | help="Whether to run training.")
325 | parser.add_argument("--do_eval", action='store_true',
326 | help="Whether to run eval on the eval data file")
327 | parser.add_argument("--do_generate", action='store_true',
328 | help="Whether to generate text on the eval data file")
329 | parser.add_argument("--length", type=int, default=100,
330 | help="Length for generation")
331 | parser.add_argument("--evaluate_during_training", action='store_true',
332 | help="Run evaluation during training at each logging step.")
333 | parser.add_argument("--do_lower_case", action='store_true',
334 | help="Set this flag if you are using an uncased model.")
335 |
336 | parser.add_argument("--per_gpu_train_batch_size", default=4, type=int,
337 | help="Batch size per GPU/CPU for training.")
338 | parser.add_argument("--per_gpu_eval_batch_size", default=4, type=int,
339 | help="Batch size per GPU/CPU for evaluation.")
340 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
341 | help="Number of updates steps to accumulate before performing a backward/update pass.")
342 | parser.add_argument("--learning_rate", default=5e-5, type=float,
343 | help="The initial learning rate for Adam.")
344 | parser.add_argument("--weight_decay", default=0.0, type=float,
345 | help="Weight deay if we apply some.")
346 | parser.add_argument("--adam_epsilon", default=1e-8, type=float,
347 | help="Epsilon for Adam optimizer.")
348 | parser.add_argument("--max_grad_norm", default=1.0, type=float,
349 | help="Max gradient norm.")
350 | parser.add_argument("--num_train_epochs", default=1.0, type=float,
351 | help="Total number of training epochs to perform.")
352 | parser.add_argument("--max_steps", default=-1, type=int,
353 | help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
354 | parser.add_argument("--warmup_steps", default=0, type=int,
355 | help="Linear warmup over warmup_steps.")
356 |
357 | parser.add_argument("--data_type", default="tsv", type=str,
358 | help="Dataset type")
359 |
360 | parser.add_argument('--logging_steps', type=int, default=50,
361 | help="Log every X updates steps.")
362 | parser.add_argument('--save_steps', type=int, default=50,
363 | help="Save checkpoint every X updates steps.")
364 | parser.add_argument("--eval_all_checkpoints", action='store_true',
365 | help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
366 | parser.add_argument("--no_cuda", action='store_true',
367 | help="Avoid using CUDA when available")
368 | parser.add_argument('--overwrite_output_dir', action='store_true',
369 | help="Overwrite the content of the output directory")
370 | parser.add_argument('--overwrite_cache', action='store_true',
371 | help="Overwrite the cached training and evaluation sets")
372 | parser.add_argument('--seed', type=int, default=42,
373 | help="random seed for initialization")
374 |
375 | parser.add_argument('--fp16', action='store_true',
376 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
377 | parser.add_argument('--fp16_opt_level', type=str, default='O1',
378 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
379 | "See details at https://nvidia.github.io/apex/amp.html")
380 | parser.add_argument("--local_rank", type=int, default=-1,
381 | help="For distributed training: local_rank")
382 | parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
383 | parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
384 | args = parser.parse_args()
385 |
386 | if args.eval_data_file is None and args.do_eval:
387 | raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
388 | "or remove the --do_eval argument.")
389 |
390 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
391 | raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
392 |
393 | # Setup distant debugging if needed
394 | if args.server_ip and args.server_port:
395 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
396 | import ptvsd
397 | print("Waiting for debugger attach")
398 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
399 | ptvsd.wait_for_attach()
400 |
401 | # Setup CUDA, GPU & distributed training
402 | if args.local_rank == -1 or args.no_cuda:
403 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
404 | args.n_gpu = torch.cuda.device_count()
405 | else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
406 | torch.cuda.set_device(args.local_rank)
407 | device = torch.device("cuda", args.local_rank)
408 | torch.distributed.init_process_group(backend='nccl')
409 | args.n_gpu = 1
410 | args.device = device
411 |
412 | # Setup logging
413 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
414 | datefmt = '%m/%d/%Y %H:%M:%S',
415 | level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
416 | logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
417 | args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
418 |
419 | # Set seed
420 | set_seed(args)
421 |
422 | # Load pretrained model and tokenizer
423 | if args.local_rank not in [-1, 0]:
424 | torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
425 |
426 | config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
427 | config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
428 | cache_dir=args.cache_dir if args.cache_dir else None)
429 | tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case,
430 | cache_dir=args.cache_dir if args.cache_dir else None)
431 | if args.block_size <= 0:
432 | args.block_size = tokenizer.max_len_single_sentence # Our input block size will be the max possible for the model
433 | args.block_size = min(args.block_size, tokenizer.max_len_single_sentence)
434 |
435 | if args.do_train:
436 | #Additional tokens
437 | print ('#tokens', len(tokenizer))
438 | new_tokens = [EXP_TOKEN, EOS_TOKEN]
439 | tokenizer.add_tokens(new_tokens)
440 | print ('#extended tokens', len(tokenizer))
441 | model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config,
442 | cache_dir=args.cache_dir if args.cache_dir else None)
443 | model.resize_token_embeddings(len(tokenizer))
444 | model.to(args.device)
445 |
446 | if args.local_rank == 0:
447 | torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training download model & vocab
448 |
449 | logger.info("Training/evaluation parameters %s", args)
450 |
451 | # Training
452 | if args.do_train:
453 | if args.local_rank not in [-1, 0]:
454 | torch.distributed.barrier() # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache
455 |
456 | train_dataset = TSVDataset(tokenizer, args, file_path=args.train_data_file,
457 | block_size=args.block_size, get_annotations=True)
458 |
459 | if args.local_rank == 0:
460 | torch.distributed.barrier()
461 |
462 | global_step, tr_loss = train(args, train_dataset, model, tokenizer)
463 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
464 |
465 |
466 | # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
467 | if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
468 | # Create output directory if needed
469 | if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
470 | os.makedirs(args.output_dir)
471 |
472 | logger.info("Saving model checkpoint to %s", args.output_dir)
473 | # Save a trained model, configuration and tokenizer using `save_pretrained()`.
474 | # They can then be reloaded using `from_pretrained()`
475 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
476 | model_to_save.save_pretrained(args.output_dir)
477 | tokenizer.save_pretrained(args.output_dir)
478 |
479 | # Good practice: save your training arguments together with the trained model
480 | torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
481 |
482 | # Load a trained model and vocabulary that you have fine-tuned
483 | model = model_class.from_pretrained(args.output_dir)
484 | tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
485 | model.to(args.device)
486 |
487 | # Evaluation
488 | if args.do_eval and args.local_rank in [-1, 0]:
489 | model = model_class.from_pretrained(args.output_dir)
490 | model.to(args.device)
491 | tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
492 | result = evaluate(args, model, tokenizer)
493 | print (result)
494 |
495 | #Generation
496 | if args.do_generate:
497 | model = model_class.from_pretrained(args.output_dir)
498 | model.to(args.device)
499 | tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
500 | generate(args, model, tokenizer)
501 |
502 | if __name__ == "__main__":
503 | main()
504 |
--------------------------------------------------------------------------------
/generated_explanations/snli_dev.csv.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SawanKumar28/nile/ef7eb47dd49afff6855358901afca24de27f0eae/generated_explanations/snli_dev.csv.gz
--------------------------------------------------------------------------------
/generated_explanations/snli_test.csv.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SawanKumar28/nile/ef7eb47dd49afff6855358901afca24de27f0eae/generated_explanations/snli_test.csv.gz
--------------------------------------------------------------------------------
/images/architecture.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SawanKumar28/nile/ef7eb47dd49afff6855358901afca24de27f0eae/images/architecture.jpg
--------------------------------------------------------------------------------
/lm_utils.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import os
3 | import pickle
4 | import torch
5 |
6 | from torch.utils.data import DataLoader, Dataset
7 |
8 | EXP_TOKEN = '[EXP]'
9 | EOS_TOKEN = '[EOS]'
10 | class TSVDataset(Dataset):
11 | def __init__(self, tokenizer, args, file_path='train', block_size=512, get_annotations=False):
12 | self.print_count = 5
13 | self.eos_token_id = tokenizer.convert_tokens_to_ids(EOS_TOKEN)
14 |
15 | cached_features_file, data = self.load_data(file_path, block_size)
16 | self.data = data
17 |
18 | if get_annotations: cached_features_file = cached_features_file + '_annotated'
19 |
20 | if os.path.exists(cached_features_file):
21 | print ('Loading features from', cached_features_file)
22 | with open(cached_features_file, 'rb') as handle:
23 | self.examples = pickle.load(handle)
24 | return
25 |
26 | print ('Saving features from ', file_path, ' into ', cached_features_file)
27 |
28 | def create_example(r):
29 | text1 = '{} {} '.format(r['input'], EXP_TOKEN)
30 | tokenized_text1 = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text1))
31 | prompt_length = len(tokenized_text1)
32 | tokenized_text, total_length = tokenized_text1, len(tokenized_text1)
33 | if get_annotations:
34 | text2 = r['target']
35 | tokenized_text2 = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text2))
36 | tokenized_text = tokenized_text1 + tokenized_text2
37 | tokenized_text = tokenized_text + [self.eos_token_id]
38 | total_length = len(tokenized_text)
39 | if len(tokenized_text) > block_size:
40 | tokenized_text = tokenized_text[:block_size]
41 | if len(tokenized_text) < block_size:
42 | tokenized_text = tokenized_text + [self.eos_token_id] * (block_size-len(tokenized_text))
43 | if self.print_count > 0:
44 | print ('example: ', text1 + text2 if get_annotations else text1)
45 | self.print_count = self.print_count - 1
46 | return (tokenized_text, prompt_length, total_length)
47 |
48 | self.examples = data.apply(create_example, axis=1).to_list()
49 | print ('Saving ', len(self.examples), ' examples')
50 | with open(cached_features_file, 'wb') as handle:
51 | pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
52 |
53 | def __len__(self):
54 | return len(self.examples)
55 |
56 | def __getitem__(self, item):
57 | return torch.tensor(self.examples[item][0]), self.examples[item][1], self.examples[item][2]
58 |
59 | def get_example_text(self, index):
60 | return self.data['prompt'][index]
61 |
62 | def add_explanation(self, index, explanation):
63 | explanation_name = 'Generated_Explanation'
64 | self.data.at[self.data.index[index], explanation_name] = explanation
65 |
66 | def load_data(self, file_path, block_size):
67 | assert os.path.isfile(file_path)
68 | data = pd.read_csv(file_path, sep='\t', index_col='pairID')
69 | print (data)
70 | directory, filename = os.path.split(file_path)
71 | cached_features_file = os.path.join(directory, 'cached_lm_{}_{}'.format(block_size, filename))
72 | return cached_features_file, data
73 |
74 | def save(self, filename):
75 | self.data.to_csv(filename, sep='\t')
76 |
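# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# Assumes a GPT-2 tokenizer extended with EXP_TOKEN/EOS_TOKEN and a TSV file
# with a 'pairID' index, an 'input' column and (for get_annotations=True) a
# 'target' column, matching what load_data and create_example above expect.
# The file name 'train.tsv' is hypothetical.
#
#   from transformers import GPT2Tokenizer
#   tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
#   tokenizer.add_tokens([EXP_TOKEN, EOS_TOKEN])
#   dataset = TSVDataset(tokenizer, args=None, file_path='train.tsv',
#                        block_size=128, get_annotations=True)
#   token_ids, prompt_length, total_length = dataset[0]
# ---------------------------------------------------------------------------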
--------------------------------------------------------------------------------
/merge_esnli_train.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import os
3 | import sys
4 |
5 | esnli_path = sys.argv[1]
6 |
7 | #Merge train files
8 | merged_train_file_path = os.path.join(esnli_path, "dataset", "esnli_train.csv")
9 | train1_path = os.path.join(esnli_path, "dataset", "esnli_train_1.csv")
10 | train2_path = os.path.join(esnli_path, "dataset", "esnli_train_2.csv")
11 | d1 = pd.read_csv(train1_path, index_col="pairID")
12 | d2 = pd.read_csv(train2_path, index_col="pairID")
13 |
14 | d = pd.concat([d1, d2], axis=0, sort=False)
15 | d_nna = d.dropna()
16 |
17 | print ("Merging train files with lengths ", len(d1), len(d2), ", output length", len(d_nna))
18 | d_nna.to_csv(merged_train_file_path)
19 |
--------------------------------------------------------------------------------
/nli_utils.py:
--------------------------------------------------------------------------------
1 | from transformers.data.processors.utils import DataProcessor, InputExample, InputFeatures
2 | import pandas as pd
3 | import numpy as np
4 |
5 | class ExpProcessor(DataProcessor):
6 | s1 = 'sentence1'
7 | s2 = 'sentence2'
8 | index_col = "pairID"
9 | labels = ["entailment", "contradiction", "neutral"]
10 | gold_label = "gold_label"
11 | def get_train_examples(self, filepath, data_format="instance", to_drop=[]):
12 | data = pd.read_csv(filepath, index_col=self.index_col)
13 | examples = self._create_examples(data, 'train', data_format=data_format, to_drop=to_drop)
14 | return examples
15 |
16 | def get_dev_examples(self, filepath, data_format="instance", to_drop=[]):
17 | data = pd.read_csv(filepath, index_col=self.index_col)
18 | examples = self._create_examples(data, 'dev', data_format=data_format, to_drop=to_drop)
19 | return examples
20 |
21 | def get_labels(self):
22 | return self.labels
23 |
24 | data_formats = ["instance", "independent", "append", "instance_independent", "instance_append",
25 | "all_explanation", "Explanation_1"]
26 | #aggregate uses the same format as independent
27 | def _create_examples(self, labeled_examples, set_type, data_format="instance", to_drop=[]):
28 | """Creates examples for the training and dev sets."""
29 | if data_format not in self.data_formats:
30 | raise ValueError("Data format {} not supported".format(data_format))
31 |
32 | if 'explanation' in to_drop: to_drop = self.labels  # 'explanation' drops all per-label explanations
33 |
34 | keep_labels = [l not in to_drop for l in self.labels]
35 | exp_names = ["{}_explanation".format(l) for l in self.labels]
36 |
37 | examples = []
38 | for (idx, le) in labeled_examples.iterrows():
39 | guid = idx
40 | label = le[self.gold_label]
41 |
42 | if data_format in ["independent", "instance_independent"]:
43 | exp_text = [le[exp_name] if keep else ""
44 | for l, keep, exp_name in zip(self.labels, keep_labels, exp_names)]
45 | elif data_format in ["append", "instance_append"]:
46 | exp_text = " ".join(["{}: {}".format(l, le[exp_name]) if keep else ""
47 | for l, keep, exp_name in zip(self.labels, keep_labels, exp_names)])
48 |
49 | if data_format == "instance":
50 | text_a, text_b = le[self.s1], le[self.s2]
51 | elif data_format in ["Explanation_1", "all_explanation"]:
52 | text_a, text_b = le[data_format], None
53 | elif data_format in ["independent", "append"]:
54 | text_a, text_b = exp_text, None
55 | elif data_format in ["instance_independent", "instance_append"]:
56 | instance_text = "Premise: {} Hypothesis: {}".format(
57 | le[self.s1], le[self.s2]) if "instance" not in to_drop else "Premise: Hypothesis:"
58 | text_a, text_b = instance_text, exp_text
59 |
60 | examples.append(
61 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
62 | return examples
63 |
64 | def simple_accuracy(preds, labels):
65 | return (preds == labels).mean()
66 |
67 | def exp_compute_metrics(preds, labels):
68 | assert len(preds) == len(labels)
69 | return {"acc": simple_accuracy(preds, labels)}
70 |
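71 | # Minimal usage sketch (the csv path is illustrative): load dev examples in the
72 | # "instance_append" format while dropping the entailment explanation.
73 | #
74 | #   processor = ExpProcessor()
75 | #   examples = processor.get_dev_examples("./dataset_snli/all/dev_merged.csv",
76 | #                                         data_format="instance_append",
77 | #                                         to_drop=["entailment"])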
--------------------------------------------------------------------------------
/prepare_train_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import numpy as np
5 | from tqdm import tqdm
6 | import pandas as pd
7 |
8 | if __name__ == "__main__":
9 |
10 | parser = argparse.ArgumentParser()
11 |
12 | parser.add_argument("--dataset", type=str, default="snli") #snli, mnli
13 | parser.add_argument("--create_data", action="store_true")
14 | parser.add_argument("--filter_repetitions", action="store_true")
15 |
16 | #For merging
17 | parser.add_argument("--merge_data", action="store_true")
18 | parser.add_argument("--merge_single", action="store_true")
19 | parser.add_argument("--split", type=str, default="train")
20 | parser.add_argument("--input_prefix", type=str, default="dummy")
21 |
22 | #for shuffled evaluation
23 | parser.add_argument("--shuffle", action="store_true")
24 |
25 | args = parser.parse_args()
26 | tqdm.pandas()
27 |
28 | args.dataset = args.dataset.lower()
29 |
30 | s1,s2 = 'sentence1', 'sentence2'
31 | index_col = 'pairID'
32 | gold_label = 'gold_label'
33 | data_labels = ['entailment', 'neutral', 'contradiction']
34 | if args.dataset == 'snli':
35 | input_path = './external/esnli/dataset'
36 | filenames = {
37 | 'dev': 'esnli_dev.csv',
38 | 'train': 'esnli_train.csv',
39 | 'test': 'esnli_test.csv'
40 | }
41 | sep = ','
42 | data_index_col = 'pairID'
43 | data_gold_label = 'gold_label'
44 | quotechar = '"'
45 | quoting = 0
46 | data_s1,data_s2 = 'Sentence1', 'Sentence2'
47 | label_map = None
48 | skip_segregation = False
49 | explanation_available = True
50 | e1 = 'Explanation_1'
51 | output_root = './dataset_snli'
52 | elif args.dataset == 'mnli':
53 | input_path = './external/MNLI'
54 | filenames = {
55 | 'train': 'train.tsv',
56 | 'dev': 'dev_matched.tsv',
57 | 'dev_mm': 'dev_mismatched.tsv'
58 | }
59 | sep = '\t'
60 | data_index_col = 'pairID'
61 | data_gold_label = 'gold_label'
62 | quotechar = None
63 | quoting = 3
64 | data_s1,data_s2 = 'sentence1', 'sentence2'
65 | label_map = None
66 | skip_segregation = True
67 | explanation_available = False
68 | output_root = './dataset_mnli'
69 | else:
70 | raise ValueError("dataset not supported")
71 |
72 | if args.create_data:
73 | data = {}
74 | for split in filenames:
75 | data[split] = pd.read_csv(os.path.join(input_path, filenames[split]),
76 | index_col=data_index_col, sep=sep, quotechar=quotechar, quoting=quoting)
77 | data[split] = data[split].rename(columns={data_s1:s1, data_s2:s2, data_gold_label:gold_label})
78 | data[split].index.name = index_col
79 | if label_map: data[split][gold_label] = data[split][gold_label].apply(label_map)
80 |
81 | print ('\n Split {} Len {}'.format(split, len(data[split])))
82 | print (data[split][gold_label].value_counts())
83 |
84 | if args.filter_repetitions and split == "train" and args.dataset == 'snli':
85 | print ("Filtering repetitions")
86 | def has_repetition(r):
87 | exp = r[e1].lower()
88 | p = r[s1].lower()
89 | h = r[s2].lower()
90 | return p in exp or h in exp
91 | cond = data[split].apply(has_repetition, axis=1)
92 | print ("#cases with repetitions:", cond.sum())
93 | data[split] = data[split][~cond]
94 | print ('Updated Split {} Len {}'.format(split, len(data[split])))
95 | print (data[split][gold_label].value_counts())
96 | cond = data[split].apply(has_repetition, axis=1)
97 | print ("#cases with repetitions:", cond.sum())
98 |
99 | label = "all"
100 | examples = data[split]
101 | dpath = os.path.join(output_root, label)
102 | os.makedirs(dpath, exist_ok=True)
103 | fname = os.path.join(dpath, '{}_data.csv'.format(split))
104 | examples.to_csv(fname)
105 |
106 | if not skip_segregation:
107 | for label in data_labels:
108 | for split in filenames:
109 | dpath = os.path.join(output_root, label)
110 | os.makedirs(dpath, exist_ok=True)
111 | fname = os.path.join(dpath, '{}_data.csv'.format(split))
112 | examples = data[split][data[split][gold_label] == label]
113 | print ('Saving {} | {} | {} examples'.format(label, split, len(examples)))
114 | examples.to_csv(fname)
115 |
116 | def generate_prompt(r):
117 | inp = 'Premise: {} Hypothesis: {}'.format(r[s1], r[s2])
118 | return inp
119 |
120 | if args.create_data:
121 | labels = ["all"]
122 | if not skip_segregation: labels.extend(data_labels)
123 |
124 | for label in labels:
125 | dpath = os.path.join(output_root, label)
126 | for split in filenames:
127 | fname = os.path.join(dpath, '{}_data.csv'.format(split))
128 | examples = pd.read_csv(fname, index_col=index_col)
129 | print ('Processing {} | {} | {} examples'.format(label, split, len(examples)))
130 | print (examples[gold_label].value_counts())
131 | examples['input'] = examples[[s1, s2]].progress_apply(generate_prompt, axis=1)
132 | columns_to_write = ['input']
133 | if explanation_available:
134 | examples['target'] = examples[e1]
135 | columns_to_write.append('target')
136 | print ('Writing')
137 | fname = os.path.join(dpath, '{}.tsv'.format(split))
138 | examples[columns_to_write].to_csv(fname, sep='\t')
139 |
140 | if args.merge_data:
141 | if args.merge_single:
142 | suffixes = ["all"]
143 | output_suffix = "_all"
144 | else:
145 | suffixes = ["entailment", "contradiction", "neutral"]
146 | output_suffix = ""
147 | split = args.split
148 | fname_csv = os.path.join(output_root, 'all', '{}_data.csv'.format(split))
149 | d_csv = pd.read_csv(fname_csv, index_col=index_col)
150 |
151 | for s in suffixes:
152 | fname_tsv = os.path.join(output_root, 'all', '{}{}_{}.tsv'.format(args.input_prefix, s, split))
153 | d_tsv = pd.read_csv(fname_tsv, index_col=index_col, sep='\t')
154 |
155 | d_csv['{}_explanation'.format(s)] = d_tsv['Generated_Explanation']
156 |
157 |
158 | print (d_csv.head(5))
159 | fname = os.path.join(output_root, 'all', '{}{}_{}{}.csv'.format(args.input_prefix, split, "merged", output_suffix))
160 | d_csv.to_csv(fname)
161 |
162 | if args.shuffle:
163 | split = args.split
164 | fname = os.path.join(output_root, 'all', '{}{}_{}.csv'.format(args.input_prefix, split, "merged"))
165 | d_csv = pd.read_csv(fname, index_col=index_col)
166 |
167 | d_csv_shuffled = d_csv.copy()
168 | for l in ["entailment", "contradiction", "neutral"]:
169 | d_csv_shuffled['{}_explanation'.format(l)] = np.random.choice(
170 | d_csv_shuffled['{}_explanation'.format(l)].values,
171 | len(d_csv_shuffled), replace=False)
172 | fname_csv_out = os.path.join(output_root, 'all', '{}shuffled{}_{}.csv'.format(args.input_prefix, split, "merged"))
173 | d_csv_shuffled.to_csv(fname_csv_out)
174 |
175 |
176 |
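177 | # Illustrative invocations (flags as defined above; <prefix> is a placeholder):
178 | #   python prepare_train_test.py --dataset snli --create_data --filter_repetitions
179 | #   python prepare_train_test.py --dataset snli --merge_data --split dev --input_prefix <prefix>
180 | #   python prepare_train_test.py --dataset snli --shuffle --split dev --input_prefix <prefix>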
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pandas
2 | tqdm
3 | torch
4 | transformers==2.3.0
5 | tensorboardX
6 |
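7 | # transformers is pinned to 2.3.0: the scripts rely on 2.x-era APIs
8 | # (e.g. pretrained_config_archive_map) that later releases removed.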
--------------------------------------------------------------------------------
/run_clf.sh:
--------------------------------------------------------------------------------
1 | GPUDEV=$1
2 | SEED=$2
3 | DATASET=$3
4 | EXPMODEL=$4
5 | DATAFORMAT=$5
6 | INPREFIX=$6
7 | TRAIN=$7
8 | EVAL=$8
9 | SAMPLENEGS=$9
10 | TODROP="${10}"
11 | MODELPATH="${11}"
12 |
13 | MODELTYPE=roberta
14 | MODELNAME=roberta-base
15 | if [ "$DATAFORMAT" == "instance" ] || [ "$DATAFORMAT" == "Explanation_1" ]
16 | then
17 | SEQLEN=100
18 | BSZ=32
19 | INPREFIX=""
20 | INSUFFIX="_data"
21 | elif [ "$DATAFORMAT" == "all_explanation" ]
22 | then
23 | SEQLEN=100
24 | BSZ=32
25 | INSUFFIX="_merged_all"
26 | elif [ "$DATAFORMAT" == "independent" ] || [ "$DATAFORMAT" == "aggregate" ]
27 | then
28 | SEQLEN=50
29 | BSZ=32
30 | INSUFFIX="_merged"
31 | elif [ "$DATAFORMAT" == "append" ]
32 | then
33 | SEQLEN=100
34 | BSZ=32
35 | INSUFFIX="_merged"
36 | elif [ "$DATAFORMAT" == "instance_independent" ] || [ "$DATAFORMAT" == "instance_aggregate" ]
37 | then
38 | SEQLEN=100
39 | BSZ=16
40 | INSUFFIX="_merged"
41 | elif [ "$DATAFORMAT" == "instance_append" ]
42 | then
43 | SEQLEN=200
44 | BSZ=16
45 | INSUFFIX="_merged"
46 | fi
47 | NEPOCHS=3
48 |
49 | TRAINFILE=./dataset_"$DATASET"/all/"$INPREFIX""$TRAIN""$INSUFFIX".csv
50 | if [ "$TRAIN" == "_" ]
51 | then
52 | TRAINCMD=""
53 | else
54 | TRAINCMD="--do_train"
55 | fi
56 | EVALFILE=./dataset_"$DATASET"/all/"$INPREFIX""$EVAL""$INSUFFIX".csv
57 |
58 | if [ "$SAMPLENEGS" == "sample" ]
59 | then
60 | SAMPLECMD="--sample_negs"
61 | SAMPLESTR="_negs"
62 | else
63 | SAMPLECMD=""
64 | SAMPLESTR=""
65 | fi
66 |
67 | if [ "$TODROP" == "_" ]
68 | then
69 | TODROPCMD=""
70 | else
71 | TODROPCMD="--to_drop "$TODROP""
72 | fi
73 |
74 | if [ "$MODELPATH" == "_" ]
75 | then
76 | OUTPUTDIR="./saved_clf/seed"$SEED"_"$DATASET"_"$EXPMODEL"_"$BSZ"_"$SEQLEN"_"$NEPOCHS"_"$INPREFIX""$SAMPLESTR""
77 | else
78 | OUTPUTDIR="$MODELPATH"
79 | fi
80 |
81 |
82 | cmd="CUDA_VISIBLE_DEVICES=$GPUDEV python run_nli.py "$SAMPLECMD" "$TODROPCMD" \
83 | --cache_dir ../nile_release/cache \
84 | --seed "$SEED" \
85 | --model_type "$MODELTYPE" \
86 | --model_name_or_path "$MODELNAME" \
87 | --exp_model "$EXPMODEL" \
88 | --data_format "$DATAFORMAT" \
89 | "$TRAINCMD" --save_steps 1523000000000000 \
90 | --do_eval --eval_all_checkpoints \
91 | --do_lower_case \
92 | --train_file "$TRAINFILE" --eval_file "$EVALFILE" \
93 | --max_seq_length "$SEQLEN" \
94 | --per_gpu_eval_batch_size="$BSZ" \
95 | --per_gpu_train_batch_size="$BSZ" \
96 | --learning_rate 2e-5 \
97 | --num_train_epochs "$NEPOCHS" \
98 | --logging_steps 5000 --evaluate_during_training \
99 | --output_dir "$OUTPUTDIR""
100 | echo $cmd
101 | eval $cmd
102 |
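103 | # Illustrative invocation (arguments are examples, not prescriptions): GPU 0,
104 | # seed 42, SNLI, the "independent" explanation model and data format, training
105 | # on "train" and evaluating on "dev"; "sample" turns on --sample_negs, and the
106 | # trailing "_ _" leaves TODROP and MODELPATH unset, as handled above.
107 | #   bash run_clf.sh 0 42 snli independent independent <prefix> train dev sample _ _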
--------------------------------------------------------------------------------
/run_finetune_gpt2m.sh:
--------------------------------------------------------------------------------
1 | GPUDEV=$1
2 | DATAROOT=$2
3 | BSZ=$3
4 | cmd="CUDA_VISIBLE_DEVICES="$GPUDEV" python finetune_lm.py \
5 | --cache_dir ./cache \
6 | --output_dir=./saved_lm/gpt2_m_"$DATAROOT" \
7 | --per_gpu_train_batch_size $BSZ \
8 | --per_gpu_eval_batch_size $BSZ \
9 | --model_type=gpt2 \
10 | --model_name_or_path=gpt2-medium \
11 | --do_train \
12 | --block_size 128 \
13 | --save_steps 6866800 \
14 | --num_train_epochs 3 \
15 | --train_data_file=./dataset_snli/"$DATAROOT"/train.tsv \
16 | --do_eval \
17 | --eval_data_file=./dataset_snli/"$DATAROOT"/dev.tsv"
18 | echo $cmd
19 | eval $cmd
20 |
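21 | # Illustrative invocation (GPU 0, the label-specific "entailment" split written
22 | # by prepare_train_test.py, per-GPU batch size 2):
23 | #   bash run_finetune_gpt2m.sh 0 entailment 2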
--------------------------------------------------------------------------------
/run_generate_gpt2m.sh:
--------------------------------------------------------------------------------
1 | GPUDEV=$1
2 | MODELSPLIT=$2
3 | DATASET=$3
4 | DATAROOT=$4
5 | DATASPLIT=$5
6 | cmd="CUDA_VISIBLE_DEVICES="$GPUDEV" python finetune_lm.py \
7 | --do_generate \
8 | --cache_dir ./cache \
9 | --output_dir=./saved_lm/gpt2_m_"$MODELSPLIT" \
10 | --model_type=gpt2 \
11 | --model_name_or_path=gpt2-medium \
12 | --block_size 128 \
13 | --save_steps 6866800 \
14 | --num_train_epochs 3 \
15 | --train_data_file=./dataset_"$DATASET"/"$DATAROOT"/train.tsv \
16 | --eval_data_file=./dataset_"$DATASET"/"$DATAROOT"/"$DATASPLIT".tsv"
17 | echo $cmd
18 | eval $cmd
19 |
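20 | # Illustrative invocation: use the LM fine-tuned on the "entailment" split to
21 | # generate explanations for the SNLI "all" dev set:
22 | #   bash run_generate_gpt2m.sh 0 entailment snli all dev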
--------------------------------------------------------------------------------
/run_nli.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | # Modifications copyright (c) 2020 Sawan Kumar
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 | #""" Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa)."""
18 | #
19 | # Modification history
20 | # 2020 Sawan Kumar: Adapted run_glue.py to learn classifiers for NILE
21 |
22 |
23 | from __future__ import absolute_import, division, print_function
24 |
25 | import argparse
26 | import glob
27 | import logging
28 | import os
29 | import random
30 |
31 | import numpy as np
32 | import torch
33 | import torch.nn as nn
34 | import torch.nn.functional as F
35 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
36 | TensorDataset)
37 | from torch.utils.data.distributed import DistributedSampler
38 | from tensorboardX import SummaryWriter
39 | from tqdm import tqdm, trange
40 |
41 | from transformers import (WEIGHTS_NAME,
42 | RobertaConfig,
43 | RobertaForSequenceClassification,
44 | RobertaTokenizer)
45 |
46 | from transformers import AdamW, get_linear_schedule_with_warmup
47 |
48 | from example_to_feature import convert_examples_to_features as convert_examples_to_features
49 |
50 | from nli_utils import ExpProcessor
51 | from nli_utils import exp_compute_metrics as compute_metrics
52 |
53 | #from lm_utils import NLIDataset, CoQADataset
54 |
55 | logger = logging.getLogger(__name__)
56 |
57 | ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (RobertaConfig,)), ())
58 |
59 | MODEL_CLASSES = {
60 | 'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
61 | }
62 |
63 |
64 | def set_seed(args):
65 | random.seed(args.seed)
66 | np.random.seed(args.seed)
67 | torch.manual_seed(args.seed)
68 | if args.n_gpu > 0:
69 | torch.cuda.manual_seed_all(args.seed)
70 |
71 | def get_logits(batch, model_output, exp_model):
72 | if exp_model in ["instance", "append", "instance_append", "all_explanation", "Explanation_1"]:
73 | return model_output
74 | model_output = model_output.view(batch[0].size(0), batch[0].size(1), model_output.size(1))  # (batch, num_explanations, num_outputs)
75 | e,c,n = 0,1,2  # explanation slots, in ExpProcessor.labels order: entailment, contradiction, neutral
76 | v1,v2 = 0,1  # per-explanation output logits (aggregate models emit two, independent models one)
77 | if exp_model in ["independent", "instance_independent"]:
78 | evidence_e = [model_output[:, e, v1]]
79 | evidence_c = [model_output[:, c, v1]]
80 | evidence_n = [model_output[:, n, v1]]
81 | elif exp_model in ["aggregate", "instance_aggregate"]:
82 | evidence_e = [model_output[:, e, v1], model_output[:, c, v2]]
83 | evidence_c = [model_output[:, e, v2], model_output[:, c, v1]]
84 | evidence_n = [model_output[:, n, v1]]
85 | logits_e, logits_c, logits_n = [torch.cat([evd.unsqueeze(-1) for evd in item], 1)
86 | for item in [evidence_e, evidence_c, evidence_n]]
87 | logits_e, logits_c, logits_n = [torch.logsumexp(item, 1, keepdim=True) for item in [logits_e, logits_c, logits_n]]
88 | logits = torch.cat([logits_e, logits_c, logits_n], 1)
89 | return logits
90 |
91 | def train(args, train_dataset, model, tokenizer):
92 | """ Train the model """
93 | if args.local_rank in [-1, 0]:
94 | tb_writer = SummaryWriter()
95 |
96 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
97 | train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
98 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
99 |
100 | if args.max_steps > 0:
101 | t_total = args.max_steps
102 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
103 | else:
104 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
105 |
106 | # Prepare optimizer and schedule (linear warmup and decay)
107 | no_decay = ['bias', 'LayerNorm.weight']
108 | optimizer_grouped_parameters = [
109 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
110 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
111 | ]
112 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
113 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
114 | if args.fp16:
115 | try:
116 | from apex import amp
117 | except ImportError:
118 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
119 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
120 |
121 | # multi-gpu training (should be after apex fp16 initialization)
122 | if args.n_gpu > 1:
123 | model = torch.nn.DataParallel(model)
124 |
125 | # Distributed training (should be after apex fp16 initialization)
126 | if args.local_rank != -1:
127 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
128 | output_device=args.local_rank,
129 | find_unused_parameters=True)
130 |
131 | # Train!
132 | logger.info("***** Running training *****")
133 | logger.info(" Num examples = %d", len(train_dataset))
134 | logger.info(" Num Epochs = %d", args.num_train_epochs)
135 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
136 | logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
137 | args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
138 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
139 | logger.info(" Total optimization steps = %d", t_total)
140 |
141 | global_step = 0
142 | tr_loss, logging_loss = 0.0, 0.0
143 | model.zero_grad()
144 | train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
145 | set_seed(args) # Added here for reproducibility (even between python 2 and 3)
146 | for _ in train_iterator:
147 | epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
148 | for step, batch in enumerate(epoch_iterator):
149 | model.train()
150 | batch = tuple(t.to(args.device) for t in batch)
151 |
152 | inputs = {'input_ids': batch[0].view(-1, batch[0].size(-1)),
153 | 'attention_mask': batch[1].view(-1, batch[0].size(-1))}
154 | inputs['token_type_ids'] = None
155 | outputs = model(**inputs)
156 | model_output = outputs[0]
157 |
158 | loss_fct = nn.CrossEntropyLoss()
159 | logits = get_logits(batch, model_output, args.exp_model)
160 | loss = loss_fct(logits, batch[3])
161 |
162 | if args.n_gpu > 1:
163 | loss = loss.mean() # mean() to average on multi-gpu parallel training
164 | if args.gradient_accumulation_steps > 1:
165 | loss = loss / args.gradient_accumulation_steps
166 |
167 | if args.fp16:
168 | with amp.scale_loss(loss, optimizer) as scaled_loss:
169 | scaled_loss.backward()
170 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
171 | else:
172 | loss.backward()
173 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
174 |
175 | tr_loss += loss.item()
176 | if (step + 1) % args.gradient_accumulation_steps == 0:
177 | optimizer.step()
178 | scheduler.step() # Update learning rate schedule
179 | model.zero_grad()
180 | global_step += 1
181 |
182 | if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
183 | # Log metrics
184 | if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
185 | results = evaluate(args, model, tokenizer)
186 | for key, value in results.items():
187 | tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
188 | tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
189 | tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
190 | logging_loss = tr_loss
191 |
192 | if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
193 | # Save model checkpoint
194 | output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
195 | if not os.path.exists(output_dir):
196 | os.makedirs(output_dir)
197 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
198 | model_to_save.save_pretrained(output_dir)
199 | torch.save(args, os.path.join(output_dir, 'training_args.bin'))
200 | logger.info("Saving model checkpoint to %s", output_dir)
201 |
202 | if args.max_steps > 0 and global_step > args.max_steps:
203 | epoch_iterator.close()
204 | break
205 | if args.max_steps > 0 and global_step > args.max_steps:
206 | train_iterator.close()
207 | break
208 |
209 | if args.local_rank in [-1, 0]:
210 | tb_writer.close()
211 |
212 | return global_step, tr_loss / global_step
213 |
214 |
215 | def evaluate(args, model, tokenizer, prefix="", analyze_attentions=False, eval_on_train=False):
216 | processor = ExpProcessor()
217 | eval_output_dir = args.output_dir
218 |
219 | results = {}
220 | eval_dataset, indices = load_and_cache_examples(args, tokenizer, evaluate=True)
221 | if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
222 | os.makedirs(eval_output_dir)
223 |
224 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
225 | # Note that DistributedSampler samples randomly
226 | eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
227 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
228 |
229 | # Eval!
230 | logger.info("***** Running evaluation {} *****".format(prefix))
231 | logger.info(" Num examples = %d", len(eval_dataset))
232 | logger.info(" Batch size = %d", args.eval_batch_size)
233 | eval_loss = 0.0
234 | nb_eval_steps = 0
235 | preds = None
236 | out_label_ids = None
237 | attentions = []
238 | for batch in tqdm(eval_dataloader, desc="Evaluating"):
239 | model.eval()
240 | batch = tuple(t.to(args.device) for t in batch)
241 |
242 | with torch.no_grad():
243 | inputs = {'input_ids': batch[0].view(-1, batch[0].size(-1)),
244 | 'attention_mask': batch[1].view(-1, batch[0].size(-1))}
245 | inputs['token_type_ids'] = None
246 | outputs = model(**inputs)
247 | model_output = outputs[0]
248 | loss_fct = nn.CrossEntropyLoss()
249 | logits = get_logits(batch, model_output, args.exp_model)
250 | tmp_eval_loss = loss_fct(logits, batch[3])
251 | eval_loss += tmp_eval_loss.mean().item()
252 |
253 | nb_eval_steps += 1
254 | inputs['labels'] = batch[3]
255 | if preds is None:
256 | preds = logits.detach().cpu().numpy()
257 | out_label_ids = inputs['labels'].detach().cpu().numpy()
258 | else:
259 | preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
260 | out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)
261 |
262 | eval_loss = eval_loss / nb_eval_steps
263 | preds_ = preds
264 | preds = np.argmax(preds, axis=1)
265 | result = compute_metrics(preds, out_label_ids)
266 | results.update(result)
267 |
268 | if not eval_on_train:
269 | eval_dataset_name = os.path.basename(os.path.dirname(os.path.dirname(args.eval_file))).split('_')[1]
270 | eval_file_base = os.path.splitext(os.path.basename(args.eval_file))[0]
271 | to_drop_list = args.to_drop.split(',') if args.to_drop else []
272 | to_drop_str = '_drop'+''.join(to_drop_list) if args.to_drop else ''
273 | prediction_file = os.path.join(args.output_dir, 'predictions_{}_{}_{}_{}{}.npz'.format(
274 |             eval_dataset_name,
275 | eval_file_base,
276 | str(args.max_seq_length),
277 | str(args.data_format),
278 | to_drop_str))
279 | print ("Writing predictions to ", prediction_file)
280 | np.savez_compressed(prediction_file, preds=preds, indices=indices, labels=out_label_ids)
281 |
282 | output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
283 | with open(output_eval_file, "w") as writer:
284 | logger.info("***** Eval results {} *****".format(prefix))
285 | for key in sorted(result.keys()):
286 | logger.info(" %s = %s", key, str(result[key]))
287 | writer.write("%s = %s\n" % (key, str(result[key])))
288 |
289 | return results
290 |
291 |
292 | def load_and_cache_examples(args, tokenizer, evaluate=False):
293 | if args.local_rank not in [-1, 0] and not evaluate:
294 | torch.distributed.barrier()  # Make sure only the first process in distributed training processes the dataset, and the others will use the cache
295 |
296 | processor = ExpProcessor()
297 | # Load data features from cache or dataset file
298 |
299 | filename = args.train_file if not evaluate else args.eval_file
300 | data_dir, filename_base = os.path.split(filename)
301 |
302 | if args.data_format == "aggregate": data_storage_format = "independent"
303 | elif args.data_format == "instance_aggregate": data_storage_format = "instance_independent"
304 | else: data_storage_format = args.data_format
305 |
306 | to_drop_list = args.to_drop.split(',') if evaluate else []
307 |
308 | cached_features_file = os.path.join(data_dir, 'cached_seq{}_{}_{}_{}_{}'.format(
309 | str(args.max_seq_length),
310 | filename_base,
311 | data_storage_format,
312 | 'drop'+args.to_drop if args.to_drop and evaluate else '',
313 | '_negs' if args.sample_negs and not evaluate else ''
314 | ))
315 |
316 | if os.path.exists(cached_features_file):
317 | logger.info("Loading features from cached file %s", cached_features_file)
318 | examples, features = torch.load(cached_features_file)
319 | else:
320 | logger.info("Creating features from dataset file at %s", data_dir)
321 | label_list = processor.get_labels()
322 | fn = processor.get_dev_examples if evaluate else processor.get_train_examples
323 | examples = fn(filename, data_format=data_storage_format, to_drop=to_drop_list)
324 |
325 | features = convert_examples_to_features(examples,
326 | tokenizer,
327 | label_list=label_list,
328 | max_length=args.max_seq_length,
329 | pad_on_left=False,
330 | pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
331 | pad_token_segment_id=0,
332 | sample_negatives=args.sample_negs if not evaluate else False,
333 | )
334 | if args.local_rank in [-1, 0]:
335 | logger.info("Saving features into cached file %s", cached_features_file)
336 | torch.save((examples, features), cached_features_file)
337 |
338 | if args.local_rank == 0 and not evaluate:
339 | torch.distributed.barrier()  # Make sure only the first process in distributed training processes the dataset, and the others will use the cache
340 |
341 | # Convert to Tensors and build dataset
342 |
343 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
344 | all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
345 | all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
346 | all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
347 |
348 | dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
349 |
350 | indices = [example.guid for example in examples]
351 | return dataset, indices
352 |
353 |
354 | def main():
355 | parser = argparse.ArgumentParser()
356 |
357 | ## Required parameters
358 | parser.add_argument("--train_file", default=None, type=str, required=True,
359 | help="The input train file. Should contain the .tsv files (or other data files) for the task.")
360 | parser.add_argument("--eval_file", default=None, type=str, required=True,
361 | help="The input eval file. Should contain the .tsv files (or other data files) for the task.")
362 | parser.add_argument("--model_type", default=None, type=str, required=True,
363 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
364 | parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
365 | help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
366 | parser.add_argument("--output_dir", default=None, type=str, required=True,
367 | help="The output directory where the model predictions and checkpoints will be written.")
368 |
369 | ## Other parameters
370 | parser.add_argument("--exp_model", default="instance", type=str)
371 | parser.add_argument("--data_format", default="instance", type=str)
372 | parser.add_argument("--to_drop", default="", type=str) # comma-separated list of inputs to drop at evaluation: label names, 'explanation' (all explanations), or 'instance'
373 |
374 | parser.add_argument("--config_name", default="", type=str,
375 | help="Pretrained config name or path if not the same as model_name")
376 | parser.add_argument("--tokenizer_name", default="", type=str,
377 | help="Pretrained tokenizer name or path if not the same as model_name")
378 | parser.add_argument("--cache_dir", default="", type=str,
379 | help="Where do you want to store the pre-trained models downloaded from s3")
380 | parser.add_argument("--max_seq_length", default=128, type=int,
381 | help="The maximum total input sequence length after tokenization. Sequences longer "
382 | "than this will be truncated, sequences shorter will be padded.")
383 | parser.add_argument("--do_train", action='store_true',
384 | help="Whether to run training.")
385 | parser.add_argument("--do_eval", action='store_true',
386 | help="Whether to run eval on the dev set.")
387 | parser.add_argument("--evaluate_during_training", action='store_true',
388 | help="Run evaluation during training at each logging step.")
389 | parser.add_argument("--do_lower_case", action='store_true',
390 | help="Set this flag if you are using an uncased model.")
391 |
392 | parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
393 | help="Batch size per GPU/CPU for training.")
394 | parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
395 | help="Batch size per GPU/CPU for evaluation.")
396 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
397 | help="Number of updates steps to accumulate before performing a backward/update pass.")
398 | parser.add_argument("--learning_rate", default=5e-5, type=float,
399 | help="The initial learning rate for Adam.")
400 | parser.add_argument("--weight_decay", default=0.0, type=float,
401 | help="Weight decay to apply, if any.")
402 | parser.add_argument("--adam_epsilon", default=1e-8, type=float,
403 | help="Epsilon for Adam optimizer.")
404 | parser.add_argument("--max_grad_norm", default=1.0, type=float,
405 | help="Max gradient norm.")
406 | parser.add_argument("--num_train_epochs", default=3.0, type=float,
407 | help="Total number of training epochs to perform.")
408 | parser.add_argument("--max_steps", default=-1, type=int,
409 | help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
410 | parser.add_argument("--warmup_steps", default=0, type=int,
411 | help="Linear warmup over warmup_steps.")
412 |
413 | parser.add_argument("--prompt_type", default="none", type=str,
414 | help="Prompt given before explanation")
415 | parser.add_argument("--use_annotations", action='store_true',
416 | help="Whether to use annotations instead of generated explanations")
417 |
418 | parser.add_argument('--logging_steps', type=int, default=50,
419 | help="Log every X updates steps.")
420 | parser.add_argument('--save_steps', type=int, default=50,
421 | help="Save checkpoint every X updates steps.")
422 | parser.add_argument("--eval_all_checkpoints", action='store_true',
423 | help="Evaluate all checkpoints starting with the same prefix as model_name and ending with the step number")
424 | parser.add_argument("--no_cuda", action='store_true',
425 | help="Avoid using CUDA when available")
426 | parser.add_argument('--overwrite_output_dir', action='store_true',
427 | help="Overwrite the content of the output directory")
428 | parser.add_argument('--overwrite_cache', action='store_true',
429 | help="Overwrite the cached training and evaluation sets")
430 | parser.add_argument('--seed', type=int, default=42,
431 | help="random seed for initialization")
432 |
433 | parser.add_argument('--sample_negs', action='store_true',
434 | help='sample negative conjectures')
435 |
436 | parser.add_argument('--fp16', action='store_true',
437 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
438 | parser.add_argument('--fp16_opt_level', type=str, default='O1',
439 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
440 | "See details at https://nvidia.github.io/apex/amp.html")
441 | parser.add_argument("--local_rank", type=int, default=-1,
442 | help="For distributed training: local_rank")
443 | parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
444 | parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
445 | args = parser.parse_args()
446 |
447 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
448 | raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
449 |
450 | # Setup distant debugging if needed
451 | if args.server_ip and args.server_port:
452 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
453 | import ptvsd
454 | print("Waiting for debugger attach")
455 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
456 | ptvsd.wait_for_attach()
457 |
458 | # Setup CUDA, GPU & distributed training
459 | if args.local_rank == -1 or args.no_cuda:
460 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
461 | args.n_gpu = torch.cuda.device_count()
462 | else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
463 | torch.cuda.set_device(args.local_rank)
464 | device = torch.device("cuda", args.local_rank)
465 | torch.distributed.init_process_group(backend='nccl')
466 | args.n_gpu = 1
467 | args.device = device
468 |
469 | # Setup logging
470 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
471 | datefmt = '%m/%d/%Y %H:%M:%S',
472 | level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
473 | logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
474 | args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
475 |
476 | # Set seed
477 | set_seed(args)
478 |
479 | processor = ExpProcessor()
480 | label_list = processor.get_labels()
481 | num_labels = len(label_list)
482 |
483 | if args.exp_model in ["independent", "instance_independent"]: args.model_num_outputs = 1
484 | elif args.exp_model in ["aggregate", "instance_aggregate"]: args.model_num_outputs = 2
485 | else: args.model_num_outputs = 3
486 |
487 | # Load pretrained model and tokenizer
488 | if args.local_rank not in [-1, 0]:
489 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
490 |
491 | args.model_type = args.model_type.lower()
492 |
493 | config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
494 | config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, num_labels=args.model_num_outputs,
495 | cache_dir=args.cache_dir if args.cache_dir else None)
496 | tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case,
497 | cache_dir=args.cache_dir if args.cache_dir else None)
498 |
499 | if args.do_train:
500 | model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config,
501 | cache_dir=args.cache_dir if args.cache_dir else None)
502 |
503 | if args.local_rank == 0:
504 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
505 |
506 |
507 | logger.info("Training/evaluation parameters %s", args)
508 |
509 |
510 | # Training
511 | if args.do_train:
512 | model.to(args.device)
513 | train_dataset, _ = load_and_cache_examples(args, tokenizer, evaluate=False)
514 | global_step, tr_loss = train(args, train_dataset, model, tokenizer)
515 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
516 |
517 |
518 | # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
519 | if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
520 | # Create output directory if needed
521 | if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
522 | os.makedirs(args.output_dir)
523 |
524 | logger.info("Saving model checkpoint to %s", args.output_dir)
525 | # Save a trained model, configuration and tokenizer using `save_pretrained()`.
526 | # They can then be reloaded using `from_pretrained()`
527 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
528 | model_to_save.save_pretrained(args.output_dir)
529 | tokenizer.save_pretrained(args.output_dir)
530 |
531 | # Good practice: save your training arguments together with the trained model
532 | torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
533 |
534 | # Load a trained model and vocabulary that you have fine-tuned
535 | model = model_class.from_pretrained(args.output_dir)
536 | tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
537 | model.to(args.device)
538 |
539 |
540 | # Evaluation
541 | results = {}
542 | if args.do_eval and args.local_rank in [-1, 0]:
543 | tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
544 | checkpoints = [args.output_dir]
545 | if args.eval_all_checkpoints:
546 | checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
547 | logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
548 | logger.info("Evaluate the following checkpoints: %s", checkpoints)
549 | for checkpoint in checkpoints:
550 | global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
551 | model = model_class.from_pretrained(checkpoint)
552 | model.output_hidden_states = True
553 | model.output_attentions = True
554 | model.to(args.device)
555 | result = evaluate(args, model, tokenizer, prefix=global_step, analyze_attentions=True)
556 | result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
557 | results.update(result)
558 |
559 | print (results)
560 | return results
561 |
562 |
563 | if __name__ == "__main__":
564 | main()
565 |
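566 | # run_clf.sh in the repository root builds a complete command line for this
567 | # script, including the per-format sequence lengths and batch sizes.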
--------------------------------------------------------------------------------