├── LICENSE ├── README.md ├── mixout.py ├── model_utils.py ├── options.py ├── prior_wd_optim.py ├── repo_illustration.png ├── requirements.txt ├── run_glue.py └── sample_commands ├── debiased_adam_baseline.sh ├── debiased_adam_longer.sh ├── llrd.sh ├── mixout.sh ├── pretrained_wd.sh └── reinit.sh /LICENSE: -------------------------------------------------------------------------------- 1 | This software repository is a fork of the following Hugging Face 2 | Transformers checkpoint: 3 | https://github.com/huggingface/transformers/commit/11c3257a18c4b5e1a3c1746eefd96f180358397b 4 | 5 | The Hugging Face Transformers checkpoint remains under the Apache 6 | License Version 2.0 that is copied below. Changes by ASAPP to the 7 | Hugging Face Transformers checkpoint are released under the MIT 8 | License that is also copied below. 9 | 10 | ASAPP is distributing its modifications of the Hugging Face 11 | Transformers checkpoint under Section 4 of the Apache License but is 12 | not contributing its modifications back to Hugging Face Transformers 13 | and accordingly Section 3 of the Apache License does not apply. 14 | 15 | ================================================== 16 | 17 | Apache License 18 | Version 2.0, January 2004 19 | http://www.apache.org/licenses/ 20 | 21 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 22 | 23 | 1. Definitions. 24 | 25 | "License" shall mean the terms and conditions for use, reproduction, 26 | and distribution as defined by Sections 1 through 9 of this document. 27 | 28 | "Licensor" shall mean the copyright owner or entity authorized by 29 | the copyright owner that is granting the License. 30 | 31 | "Legal Entity" shall mean the union of the acting entity and all 32 | other entities that control, are controlled by, or are under common 33 | control with that entity. For the purposes of this definition, 34 | "control" means (i) the power, direct or indirect, to cause the 35 | direction or management of such entity, whether by contract or 36 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 37 | outstanding shares, or (iii) beneficial ownership of such entity. 38 | 39 | "You" (or "Your") shall mean an individual or Legal Entity 40 | exercising permissions granted by this License. 41 | 42 | "Source" form shall mean the preferred form for making modifications, 43 | including but not limited to software source code, documentation 44 | source, and configuration files. 45 | 46 | "Object" form shall mean any form resulting from mechanical 47 | transformation or translation of a Source form, including but 48 | not limited to compiled object code, generated documentation, 49 | and conversions to other media types. 50 | 51 | "Work" shall mean the work of authorship, whether in Source or 52 | Object form, made available under the License, as indicated by a 53 | copyright notice that is included in or attached to the work 54 | (an example is provided in the Appendix below). 55 | 56 | "Derivative Works" shall mean any work, whether in Source or Object 57 | form, that is based on (or derived from) the Work and for which the 58 | editorial revisions, annotations, elaborations, or other modifications 59 | represent, as a whole, an original work of authorship. For the purposes 60 | of this License, Derivative Works shall not include works that remain 61 | separable from, or merely link (or bind by name) to the interfaces of, 62 | the Work and Derivative Works thereof. 
63 | 64 | "Contribution" shall mean any work of authorship, including 65 | the original version of the Work and any modifications or additions 66 | to that Work or Derivative Works thereof, that is intentionally 67 | submitted to Licensor for inclusion in the Work by the copyright owner 68 | or by an individual or Legal Entity authorized to submit on behalf of 69 | the copyright owner. For the purposes of this definition, "submitted" 70 | means any form of electronic, verbal, or written communication sent 71 | to the Licensor or its representatives, including but not limited to 72 | communication on electronic mailing lists, source code control systems, 73 | and issue tracking systems that are managed by, or on behalf of, the 74 | Licensor for the purpose of discussing and improving the Work, but 75 | excluding communication that is conspicuously marked or otherwise 76 | designated in writing by the copyright owner as "Not a Contribution." 77 | 78 | "Contributor" shall mean Licensor and any individual or Legal Entity 79 | on behalf of whom a Contribution has been received by Licensor and 80 | subsequently incorporated within the Work. 81 | 82 | 2. Grant of Copyright License. Subject to the terms and conditions of 83 | this License, each Contributor hereby grants to You a perpetual, 84 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 85 | copyright license to reproduce, prepare Derivative Works of, 86 | publicly display, publicly perform, sublicense, and distribute the 87 | Work and such Derivative Works in Source or Object form. 88 | 89 | 3. Grant of Patent License. Subject to the terms and conditions of 90 | this License, each Contributor hereby grants to You a perpetual, 91 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 92 | (except as stated in this section) patent license to make, have made, 93 | use, offer to sell, sell, import, and otherwise transfer the Work, 94 | where such license applies only to those patent claims licensable 95 | by such Contributor that are necessarily infringed by their 96 | Contribution(s) alone or by combination of their Contribution(s) 97 | with the Work to which such Contribution(s) was submitted. If You 98 | institute patent litigation against any entity (including a 99 | cross-claim or counterclaim in a lawsuit) alleging that the Work 100 | or a Contribution incorporated within the Work constitutes direct 101 | or contributory patent infringement, then any patent licenses 102 | granted to You under this License for that Work shall terminate 103 | as of the date such litigation is filed. 104 | 105 | 4. Redistribution. 
You may reproduce and distribute copies of the 106 | Work or Derivative Works thereof in any medium, with or without 107 | modifications, and in Source or Object form, provided that You 108 | meet the following conditions: 109 | 110 | (a) You must give any other recipients of the Work or 111 | Derivative Works a copy of this License; and 112 | 113 | (b) You must cause any modified files to carry prominent notices 114 | stating that You changed the files; and 115 | 116 | (c) You must retain, in the Source form of any Derivative Works 117 | that You distribute, all copyright, patent, trademark, and 118 | attribution notices from the Source form of the Work, 119 | excluding those notices that do not pertain to any part of 120 | the Derivative Works; and 121 | 122 | (d) If the Work includes a "NOTICE" text file as part of its 123 | distribution, then any Derivative Works that You distribute must 124 | include a readable copy of the attribution notices contained 125 | within such NOTICE file, excluding those notices that do not 126 | pertain to any part of the Derivative Works, in at least one 127 | of the following places: within a NOTICE text file distributed 128 | as part of the Derivative Works; within the Source form or 129 | documentation, if provided along with the Derivative Works; or, 130 | within a display generated by the Derivative Works, if and 131 | wherever such third-party notices normally appear. The contents 132 | of the NOTICE file are for informational purposes only and 133 | do not modify the License. You may add Your own attribution 134 | notices within Derivative Works that You distribute, alongside 135 | or as an addendum to the NOTICE text from the Work, provided 136 | that such additional attribution notices cannot be construed 137 | as modifying the License. 138 | 139 | You may add Your own copyright statement to Your modifications and 140 | may provide additional or different license terms and conditions 141 | for use, reproduction, or distribution of Your modifications, or 142 | for any such Derivative Works as a whole, provided Your use, 143 | reproduction, and distribution of the Work otherwise complies with 144 | the conditions stated in this License. 145 | 146 | 5. Submission of Contributions. Unless You explicitly state otherwise, 147 | any Contribution intentionally submitted for inclusion in the Work 148 | by You to the Licensor shall be under the terms and conditions of 149 | this License, without any additional terms or conditions. 150 | Notwithstanding the above, nothing herein shall supersede or modify 151 | the terms of any separate license agreement you may have executed 152 | with Licensor regarding such Contributions. 153 | 154 | 6. Trademarks. This License does not grant permission to use the trade 155 | names, trademarks, service marks, or product names of the Licensor, 156 | except as required for reasonable and customary use in describing the 157 | origin of the Work and reproducing the content of the NOTICE file. 158 | 159 | 7. Disclaimer of Warranty. Unless required by applicable law or 160 | agreed to in writing, Licensor provides the Work (and each 161 | Contributor provides its Contributions) on an "AS IS" BASIS, 162 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 163 | implied, including, without limitation, any warranties or conditions 164 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 165 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 166 | appropriateness of using or redistributing the Work and assume any 167 | risks associated with Your exercise of permissions under this License. 168 | 169 | 8. Limitation of Liability. In no event and under no legal theory, 170 | whether in tort (including negligence), contract, or otherwise, 171 | unless required by applicable law (such as deliberate and grossly 172 | negligent acts) or agreed to in writing, shall any Contributor be 173 | liable to You for damages, including any direct, indirect, special, 174 | incidental, or consequential damages of any character arising as a 175 | result of this License or out of the use or inability to use the 176 | Work (including but not limited to damages for loss of goodwill, 177 | work stoppage, computer failure or malfunction, or any and all 178 | other commercial damages or losses), even if such Contributor 179 | has been advised of the possibility of such damages. 180 | 181 | 9. Accepting Warranty or Additional Liability. While redistributing 182 | the Work or Derivative Works thereof, You may choose to offer, 183 | and charge a fee for, acceptance of support, warranty, indemnity, 184 | or other liability obligations and/or rights consistent with this 185 | License. However, in accepting such obligations, You may act only 186 | on Your own behalf and on Your sole responsibility, not on behalf 187 | of any other Contributor, and only if You agree to indemnify, 188 | defend, and hold each Contributor harmless for any liability 189 | incurred by, or claims asserted against, such Contributor by reason 190 | of your accepting any such warranty or additional liability. 191 | 192 | ================================================== 193 | 194 | Copyright (c) 2020 ASAPP Inc. 195 | 196 | MIT License 197 | 198 | Permission is hereby granted, free of charge, to any person obtaining a copy 199 | of this software and associated documentation files (the "Software"), to deal 200 | in the Software without restriction, including without limitation the rights 201 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 202 | copies of the Software, and to permit persons to whom the Software is 203 | furnished to do so, subject to the following conditions: 204 | 205 | The above copyright notice and this permission notice shall be included in all 206 | copies or substantial portions of the Software. 207 | 208 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 209 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 210 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 211 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 212 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 213 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 214 | SOFTWARE. 
215 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Revisiting Few-sample BERT Fine-tuning 2 | [![made-with-python](https://img.shields.io/badge/Made%20with-Python-red.svg)](#python) 3 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 4 | 5 | [Paper Link](https://arxiv.org/abs/2006.05987) 6 | 7 | #### Authors: 8 | * [Tianyi Zhang](https://tiiiger.github.io)* 9 | * [Felix Wu](https://sites.google.com/view/felixwu/home)* 10 | * [Arzoo Katiyar](https://sites.google.com/site/arzook99/home) 11 | * [Kilian Q. Weinberger](http://kilian.cs.cornell.edu/index.html) 12 | * [Yoav Artzi](https://yoavartzi.com/) 13 | 14 | *: Equal Contribution 15 | 16 | ## Overview 17 | 18 | ![](./repo_illustration.png "repo_illustration") 19 | 20 | In this paper, we study the problem of few-sample BERT fine-tuning and identify three sub-optimal practices. 21 | First, we observe that the omission of gradient bias correction in BERTAdam makes fine-tuning unstable. 22 | We also find that the top layers of BERT provide a detrimental initialization and that simply re-initializing these layers improves convergence and performance. 23 | Finally, we observe that commonly used recipes often do not allocate sufficient time for training. 24 | 25 | If you find this repo useful, please cite: 26 | ``` 27 | @article{revisit-bert-finetuning, 28 | title={Revisiting Few-sample BERT Fine-tuning}, 29 | author={Zhang, Tianyi and Wu, Felix and Katiyar, Arzoo and Weinberger, Kilian Q. and Artzi, Yoav}, 30 | journal={arXiv preprint arXiv:2006.05987}, 31 | year={2020} 32 | } 33 | ``` 34 | 35 | ## Requirements 36 | ``` 37 | torch==1.4.0 38 | transformers==2.8.0 39 | apex==0.1 40 | tqdm 41 | tensorboardX 42 | ``` 43 | Please install apex following the instructions at [https://github.com/NVIDIA/apex](https://github.com/NVIDIA/apex). 44 | 45 | ## Usage 46 | We provide the following sample scripts. When using these scripts, please change `--data_dir`, `--output_dir`, and `--cache_dir` to your paths to the data folder, the output folder, and the `transformers` cache directory. 47 | 48 | 1. To train the BERT baseline (with debiased Adam): 49 | ```sh 50 | bash sample_commands/debiased_adam_baseline.sh 51 | ``` 52 | 2. To use Re-init: 53 | ```sh 54 | bash sample_commands/reinit.sh 55 | ``` 56 | 3. To train the model with more iterations: 57 | ```sh 58 | bash sample_commands/debiased_adam_longer.sh 59 | ``` 60 | 4. To use mixout: 61 | ```sh 62 | bash sample_commands/mixout.sh 63 | ``` 64 | 5. To use layer-wise learning rate decay: 65 | ```sh 66 | bash sample_commands/llrd.sh 67 | ``` 68 | 6. To use pretrained weight decay: 69 | ```sh 70 | bash sample_commands/pretrained_wd.sh 71 | ``` 72 | 73 | ### Input 74 | You need to download the [GLUE](https://gluebenchmark.com/) dataset with this [script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e). 75 | Feed the path to your data through `--data_dir`. 76 | 77 | ### Commands 78 | We provide example commands to replicate our experiments in `sample_commands`. 79 | 80 | `run_glue.py` contains the main program to fine-tune and evaluate models. 81 | `python run_glue.py --help` shows all available options.
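For example, a hypothetical invocation that combines several of these options directly (the paths, task, and hyperparameter values below are placeholders rather than the settings from the paper; see `sample_commands/reinit.sh` for the exact configuration we used) might look like:
```sh
# Hypothetical example -- adjust paths, task, and hyperparameters to your setup.
python run_glue.py \
  --data_dir /path/to/glue/RTE \
  --task_name rte \
  --model_type bert \
  --model_name_or_path bert-large-uncased \
  --do_lower_case \
  --do_train \
  --use_torch_adamw \
  --reinit_pooler \
  --reinit_layers 5 \
  --max_seq_length 128 \
  --per_gpu_train_batch_size 8 \
  --learning_rate 2e-5 \
  --num_train_epochs 3 \
  --logging_steps 50 \
  --seed 42 \
  --output_dir /path/to/output \
  --cache_dir /path/to/transformers_cache
```
All of these flags are defined in `options.py`.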
82 | 83 | Some key options are: 84 | 85 | ``` 86 | # These two replicate our experiments on bias correction 87 | --use_bertadam No bias correction # this replicates the behavior of BERTAdam 88 | --use_torch_adamw Use pytorch adamw # this replicates the behavior of debiased Adam 89 | # These two replicate our experiments on Re-init 90 | --reinit_pooler reinitialize the pooler 91 | --reinit_layers re-initialize the last N Transformer blocks. reinit_pooler must be turned on. 92 | ``` 93 | 94 | ### Output 95 | 96 | A standard output folder generated by `run_glue.py` will look like: 97 | ``` 98 | ├── raw_log.txt 99 | ├── test_best_log.txt 100 | ├── test_last_log.txt 101 | └── training_args.bin 102 | ``` 103 | `*_log.txt` are csv files that record relevant training and evaluation results. 104 | `test_best_log.txt` records the test performance with the best model checkpoint during training. 105 | `test_last_log.txt` records the same with the last model checkpoint. 106 | `training_args.bin` contains all arguments used to run a job. 107 | -------------------------------------------------------------------------------- /mixout.py: -------------------------------------------------------------------------------- 1 | ##++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Cheolhyoung Lee 3 | ## Department of Mathematical Sciences, KAIST 4 | ## Email: cheolhyoung.lee@kaist.ac.kr 5 | ## Implementation of mixout from https://arxiv.org/abs/1909.11299 6 | ## "Mixout: Effective Regularization to Finetune Large-scale Pretrained Language Models" 7 | ##++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 8 | import math 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.init as init 12 | import torch.nn.functional as F 13 | 14 | from torch.nn import Parameter 15 | from torch.autograd.function import InplaceFunction 16 | 17 | 18 | class Mixout(InplaceFunction): 19 | # target: a weight tensor that is mixed with the input tensor 20 | # The forward method returns 21 | # [(1 - Bernoulli(1 - p) mask) * target + (Bernoulli(1 - p) mask) * input - p * target]/(1 - p) 22 | # where p is the mix probability of mixout. 23 | # The backward method returns the gradient of the forward method. 24 | # Dropout is equivalent to the case of target=None. 25 | # I modified the code of dropout in PyTorch.
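    # For example, with p = 0.5: an element whose Bernoulli(1 - p) draw is 0 becomes
    # (target - 0.5 * target) / 0.5 = target, i.e. it is reset to the pretrained weight,
    # while an element whose draw is 1 becomes (input - 0.5 * target) / 0.5 = 2 * input - target.
    # Since the draw is 1 with probability 1 - p, the expected output equals the input,
    # mirroring the rescaling used in inverted dropout.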
26 | @staticmethod 27 | def _make_noise(input): 28 | return input.new().resize_as_(input) 29 | 30 | @classmethod 31 | def forward(cls, ctx, input, target=None, p=0.0, training=False, inplace=False): 32 | if p < 0 or p > 1: 33 | raise ValueError("A mix probability of mixout has to be between 0 and 1," " but got {}".format(p)) 34 | if target is not None and input.size() != target.size(): 35 | raise ValueError( 36 | "A target tensor size must match with a input tensor size {}," 37 | " but got {}".format(input.size(), target.size()) 38 | ) 39 | ctx.p = p 40 | ctx.training = training 41 | 42 | if ctx.p == 0 or not ctx.training: 43 | return input 44 | 45 | if target is None: 46 | target = cls._make_noise(input) 47 | target.fill_(0) 48 | target = target.to(input.device) 49 | 50 | if inplace: 51 | ctx.mark_dirty(input) 52 | output = input 53 | else: 54 | output = input.clone() 55 | 56 | ctx.noise = cls._make_noise(input) 57 | if len(ctx.noise.size()) == 1: 58 | ctx.noise.bernoulli_(1 - ctx.p) 59 | else: 60 | ctx.noise[0].bernoulli_(1 - ctx.p) 61 | ctx.noise = ctx.noise[0].repeat(input.size()[0], 1) 62 | ctx.noise.expand_as(input) 63 | 64 | if ctx.p == 1: 65 | output = target 66 | else: 67 | output = ((1 - ctx.noise) * target + ctx.noise * output - ctx.p * target) / (1 - ctx.p) 68 | return output 69 | 70 | @staticmethod 71 | def backward(ctx, grad_output): 72 | if ctx.p > 0 and ctx.training: 73 | return grad_output * ctx.noise, None, None, None, None 74 | else: 75 | return grad_output, None, None, None, None 76 | 77 | 78 | def mixout(input, target=None, p=0.0, training=False, inplace=False): 79 | return Mixout.apply(input, target, p, training, inplace) 80 | 81 | 82 | class MixLinear(torch.nn.Module): 83 | __constants__ = ["bias", "in_features", "out_features"] 84 | # If target is None, nn.Sequential(nn.Linear(m, n), MixLinear(m', n', p)) 85 | # is equivalent to nn.Sequential(nn.Linear(m, n), nn.Dropout(p), nn.Linear(m', n')). 
86 | # If you want to change a dropout layer to a mixout layer, 87 | # you should replace nn.Linear right after nn.Dropout(p) with Mixout(p) 88 | def __init__(self, in_features, out_features, bias=True, target=None, p=0.0): 89 | super(MixLinear, self).__init__() 90 | self.in_features = in_features 91 | self.out_features = out_features 92 | self.weight = Parameter(torch.Tensor(out_features, in_features)) 93 | if bias: 94 | self.bias = Parameter(torch.Tensor(out_features)) 95 | else: 96 | self.register_parameter("bias", None) 97 | self.reset_parameters() 98 | self.target = target 99 | self.p = p 100 | 101 | def reset_parameters(self): 102 | init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 103 | if self.bias is not None: 104 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) 105 | bound = 1 / math.sqrt(fan_in) 106 | init.uniform_(self.bias, -bound, bound) 107 | 108 | def forward(self, input): 109 | return F.linear(input, mixout(self.weight, self.target, self.p, self.training), self.bias) 110 | 111 | def extra_repr(self): 112 | type = "drop" if self.target is None else "mix" 113 | return "{}={}, in_features={}, out_features={}, bias={}".format( 114 | type + "out", self.p, self.in_features, self.out_features, self.bias is not None 115 | ) 116 | 117 | -------------------------------------------------------------------------------- /model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import transformers 4 | from transformers.modeling_electra import ElectraPreTrainedModel, ElectraModel 5 | from torch.nn import CrossEntropyLoss, MSELoss 6 | 7 | 8 | class ElectraForSequenceClassification(ElectraPreTrainedModel): 9 | def __init__(self, config): 10 | super().__init__(config) 11 | self.num_labels = config.num_labels 12 | 13 | self.electra = ElectraModel(config) 14 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 15 | self.classifier = nn.Linear(config.hidden_size, self.config.num_labels) 16 | 17 | self.init_weights() 18 | 19 | def forward( 20 | self, 21 | input_ids=None, 22 | attention_mask=None, 23 | token_type_ids=None, 24 | position_ids=None, 25 | head_mask=None, 26 | inputs_embeds=None, 27 | labels=None, 28 | ): 29 | r""" 30 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): 31 | Labels for computing the token classification loss. 32 | Indices should be in ``[0, ..., config.num_labels - 1]``. 33 | 34 | Returns: 35 | :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs: 36 | loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) : 37 | Classification loss. 38 | scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`) 39 | Classification scores (before SoftMax). 40 | hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): 41 | Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) 42 | of shape :obj:`(batch_size, sequence_length, hidden_size)`. 43 | 44 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 
45 | attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): 46 | Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape 47 | :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. 48 | 49 | Attention weights after the attention softmax, used to compute the weighted average in the self-attention 50 | heads. 51 | 52 | Examples:: 53 | 54 | from transformers import ElectraTokenizer, ElectraForSequenceClassification 55 | import torch 56 | 57 | tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator') 58 | model = ElectraForSequenceClassification.from_pretrained('google/electra-small-discriminator') 59 | 60 | input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 61 | labels = torch.tensor([1]).unsqueeze(0) # Batch size 1 62 | outputs = model(input_ids, labels=labels) 63 | 64 | loss, scores = outputs[:2] 65 | 66 | """ 67 | 68 | discriminator_hidden_states = self.electra( 69 | input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds 70 | ) 71 | pooled_output = discriminator_hidden_states[0][:, 0, :] 72 | 73 | pooled_output = self.dropout(pooled_output) 74 | logits = self.classifier(pooled_output) 75 | 76 | outputs = (logits,) + discriminator_hidden_states[1:] # add hidden states and attention if they are here 77 | 78 | if labels is not None: 79 | if self.num_labels == 1: 80 | # We are doing regression 81 | loss_fct = MSELoss() 82 | loss = loss_fct(logits.view(-1), labels.view(-1)) 83 | else: 84 | loss_fct = CrossEntropyLoss() 85 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 86 | outputs = (loss,) + outputs 87 | 88 | return outputs # (loss), logits, (hidden_states), (attentions) 89 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from transformers import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING 3 | from transformers import glue_processors as processors 4 | 5 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys()) 6 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 7 | 8 | ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), (),) 9 | 10 | 11 | def get_parser(): 12 | parser = argparse.ArgumentParser() 13 | 14 | # Required parameters 15 | parser.add_argument( 16 | "--data_dir", 17 | default=None, 18 | type=str, 19 | required=True, 20 | help="The input data dir.
Should contain the .tsv files (or other data files) for the task.", 21 | ) 22 | parser.add_argument( 23 | "--model_type", 24 | default=None, 25 | type=str, 26 | required=True, 27 | help="Model type selected in the list: " + ", ".join(MODEL_TYPES), 28 | ) 29 | parser.add_argument( 30 | "--model_name_or_path", 31 | default=None, 32 | type=str, 33 | required=True, 34 | help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS), 35 | ) 36 | parser.add_argument( 37 | "--task_name", 38 | default=None, 39 | type=str, 40 | required=True, 41 | help="The name of the task to train selected in the list: " + ", ".join(processors.keys()), 42 | ) 43 | parser.add_argument( 44 | "--output_dir", 45 | default=None, 46 | type=str, 47 | required=True, 48 | help="The output directory where the model predictions and checkpoints will be written.", 49 | ) 50 | 51 | # Other parameters 52 | parser.add_argument( 53 | "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name", 54 | ) 55 | parser.add_argument( 56 | "--tokenizer_name", 57 | default="", 58 | type=str, 59 | help="Pretrained tokenizer name or path if not the same as model_name", 60 | ) 61 | parser.add_argument( 62 | "--cache_dir", 63 | default="", 64 | type=str, 65 | help="Where do you want to store the pre-trained models downloaded from s3", 66 | ) 67 | parser.add_argument( 68 | "--max_seq_length", 69 | default=128, 70 | type=int, 71 | help="The maximum total input sequence length after tokenization. Sequences longer " 72 | "than this will be truncated, sequences shorter will be padded.", 73 | ) 74 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.") 75 | parser.add_argument( 76 | "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.", 77 | ) 78 | parser.add_argument( 79 | "--save_best", action="store_true", help="Set this flag if you want to save the early stop model.", 80 | ) 81 | parser.add_argument( 82 | "--save_last", action="store_true", help="Set this flag if you want to save the last model.", 83 | ) 84 | parser.add_argument( 85 | "--train_batch_size", 86 | default=0, 87 | type=int, 88 | help="Batch size per GPU/CPU for training to override per_gpu_train_batch_size", 89 | ) 90 | parser.add_argument( 91 | "--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.", 92 | ) 93 | parser.add_argument( 94 | "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation.", 95 | ) 96 | parser.add_argument( 97 | "--gradient_accumulation_steps", 98 | type=int, 99 | default=1, 100 | help="Number of updates steps to accumulate before performing a backward/update pass.", 101 | ) 102 | parser.add_argument( 103 | "--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.", 104 | ) 105 | parser.add_argument( 106 | "--layerwise_learning_rate_decay", default=1.0, type=float, help="layerwise learning rate decay", 107 | ) 108 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 109 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 110 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 111 | parser.add_argument( 112 | "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.", 113 | ) 114 | parser.add_argument( 115 | "--max_steps", 116 
| default=-1, 117 | type=int, 118 | help="If > 0: set total number of training steps to perform. Override num_train_epochs.", 119 | ) 120 | parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") 121 | parser.add_argument("--warmup_ratio", default=0, type=float, help="Linear ratio over total steps.") 122 | parser.add_argument("--weight_logging_steps", type=int, default=10, help="Log every X updates steps.") 123 | parser.add_argument("--logging_steps", type=int, default=0, help="Log every X updates steps.") 124 | parser.add_argument("--num_loggings", type=int, default=0, help="Total amount of evaluations in training.") 125 | parser.add_argument( 126 | "--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.", 127 | ) 128 | parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available") 129 | parser.add_argument( 130 | "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory", 131 | ) 132 | parser.add_argument( 133 | "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets", 134 | ) 135 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 136 | 137 | parser.add_argument( 138 | "--fp16", 139 | action="store_true", 140 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", 141 | ) 142 | parser.add_argument( 143 | "--fp16_opt_level", 144 | type=str, 145 | default="O1", 146 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 147 | "See details at https://nvidia.github.io/apex/amp.html", 148 | ) 149 | parser.add_argument( 150 | "--local_rank", type=int, default=-1, help="For distributed training: local_rank", 151 | ) 152 | parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.") 153 | parser.add_argument("--server_port", type=str, default="", help="For distant debugging.") 154 | # Added by our paper 155 | parser.add_argument("--use_bertadam", action="store_true", help="No bias correction") 156 | parser.add_argument("--use_torch_adamw", action="store_true", help="Use pytorch adamw") 157 | parser.add_argument( 158 | "--downsample_trainset", default=-1, type=int, help="down sample training set to this number.", 159 | ) 160 | parser.add_argument("--resplit_val", default=0, type=int, help="Whether to get the (simulated) test accuracy.") 161 | parser.add_argument( 162 | "--reinit_layers", 163 | type=int, 164 | default=0, 165 | help="re-initialize the last N Transformer blocks. 
reinit_pooler must be turned on.", 166 | ) 167 | parser.add_argument( 168 | "--reinit_pooler", action="store_true", help="reinitialize the pooler", 169 | ) 170 | parser.add_argument("--rezero_layers", type=int, default=0, help="re-zero layers") 171 | parser.add_argument("--mixout", type=float, default=0.0, help="mixout probability (default: 0)") 172 | parser.add_argument( 173 | "--prior_weight_decay", action="store_true", help="Weight Decaying toward the bert params", 174 | ) 175 | parser.add_argument( 176 | "--test_val_split", action="store_true", help="Split the original development set in half", 177 | ) 178 | 179 | return parser 180 | -------------------------------------------------------------------------------- /prior_wd_optim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Optimizer 3 | 4 | 5 | class PriorWD(Optimizer): 6 | def __init__(self, optim, use_prior_wd=False, exclude_last_group=True): 7 | super(PriorWD, self).__init__(optim.param_groups, optim.defaults) 8 | 9 | # python dictionary does not copy by default 10 | self.param_groups = optim.param_groups 11 | self.optim = optim 12 | self.use_prior_wd = use_prior_wd 13 | self.exclude_last_group = exclude_last_group 14 | 15 | self.weight_decay_by_group = [] 16 | for i, group in enumerate(self.param_groups): 17 | self.weight_decay_by_group.append(group["weight_decay"]) 18 | group["weight_decay"] = 0 19 | 20 | self.prior_params = {} 21 | for i, group in enumerate(self.param_groups): 22 | for p in group["params"]: 23 | self.prior_params[id(p)] = p.detach().clone() 24 | 25 | def step(self, closure=None): 26 | if self.use_prior_wd: 27 | for i, group in enumerate(self.param_groups): 28 | for p in group["params"]: 29 | if self.exclude_last_group and i == len(self.param_groups): 30 | p.data.add_(-group["lr"] * self.weight_decay_by_group[i], p.data) 31 | else: 32 | p.data.add_( 33 | -group["lr"] * self.weight_decay_by_group[i], p.data - self.prior_params[id(p)], 34 | ) 35 | loss = self.optim.step(closure) 36 | 37 | return loss 38 | 39 | def compute_distance_to_prior(self, param): 40 | """ 41 | Compute the L2-norm between the current parameter value to its initial (pre-trained) value. 42 | """ 43 | assert id(param) in self.prior_params, "parameter not in PriorWD optimizer" 44 | return (param.data - self.prior_params[id(param)]).pow(2).sum().sqrt() 45 | -------------------------------------------------------------------------------- /repo_illustration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asappresearch/revisit-bert-finetuning/0aa4f4e117ee4422f7cb9355158203e01d6730db/repo_illustration.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.4.0 2 | transformers==2.5.0 3 | tqdm 4 | numpy 5 | pandas -------------------------------------------------------------------------------- /run_glue.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # This file has been modified by ASAPP. The original file is licensed under the 3 | # Apache License Version 2.0. The modifications by ASAPP are licensed under 4 | # the MIT license. 5 | # 6 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 7 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 
8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | """ Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa, Albert, XLM-RoBERTa).""" 21 | 22 | 23 | import argparse 24 | import glob 25 | import json 26 | import logging 27 | import os 28 | import random 29 | import re 30 | from collections import defaultdict 31 | 32 | import numpy as np 33 | import torch 34 | import torch.nn as nn 35 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset 36 | from torch.utils.data.distributed import DistributedSampler 37 | from tqdm import tqdm, trange 38 | from torch.optim import Adam 39 | from options import get_parser 40 | from model_utils import ElectraForSequenceClassification 41 | 42 | from transformers import ( 43 | WEIGHTS_NAME, 44 | AdamW, 45 | AutoConfig, 46 | AutoModelForSequenceClassification, 47 | AutoTokenizer, 48 | get_linear_schedule_with_warmup, 49 | get_constant_schedule_with_warmup, 50 | get_cosine_schedule_with_warmup, 51 | ) 52 | 53 | from transformers import glue_compute_metrics as compute_metrics 54 | from transformers import glue_convert_examples_to_features as convert_examples_to_features 55 | from transformers import glue_output_modes as output_modes 56 | from transformers import glue_processors as processors 57 | 58 | try: 59 | from torch.utils.tensorboard import SummaryWriter 60 | except ImportError: 61 | from tensorboardX import SummaryWriter 62 | 63 | from prior_wd_optim import PriorWD 64 | 65 | logger = logging.getLogger(__name__) 66 | 67 | 68 | def set_seed(seed): 69 | random.seed(seed) 70 | np.random.seed(seed) 71 | torch.manual_seed(seed) 72 | torch.cuda.manual_seed_all(seed) 73 | 74 | 75 | def get_optimizer_grouped_parameters(args, model): 76 | no_decay = ["bias", "LayerNorm.weight"] 77 | if args.layerwise_learning_rate_decay == 1.0: 78 | optimizer_grouped_parameters = [ 79 | { 80 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 81 | "weight_decay": args.weight_decay, 82 | "lr": args.learning_rate, 83 | }, 84 | { 85 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 86 | "weight_decay": 0.0, 87 | "lr": args.learning_rate, 88 | }, 89 | ] 90 | else: 91 | optimizer_grouped_parameters = [ 92 | { 93 | "params": [p for n, p in model.named_parameters() if "classifier" in n or "pooler" in n], 94 | "weight_decay": 0.0, 95 | "lr": args.learning_rate, 96 | }, 97 | ] 98 | 99 | if args.model_type in ["bert", "roberta", "electra"]: 100 | num_layers = model.config.num_hidden_layers 101 | layers = [getattr(model, args.model_type).embeddings] + list(getattr(model, args.model_type).encoder.layer) 102 | layers.reverse() 103 | lr = args.learning_rate 104 | for layer in layers: 105 | lr *= args.layerwise_learning_rate_decay 106 | optimizer_grouped_parameters += [ 107 | { 108 | "params": [p for n, p in layer.named_parameters() if not any(nd in n for nd in no_decay)], 109 | "weight_decay": 
args.weight_decay, 110 | "lr": lr, 111 | }, 112 | { 113 | "params": [p for n, p in layer.named_parameters() if any(nd in n for nd in no_decay)], 114 | "weight_decay": 0.0, 115 | "lr": lr, 116 | }, 117 | ] 118 | else: 119 | raise NotImplementedError 120 | return optimizer_grouped_parameters 121 | 122 | 123 | def train(args, train_dataset, model, tokenizer): 124 | """ Train the model """ 125 | if args.local_rank in [-1, 0]: 126 | with open(f"{args.output_dir}/raw_log.txt", "w") as f: 127 | pass # create a new file 128 | 129 | if args.train_batch_size == 0: 130 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 131 | train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) 132 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 133 | 134 | eval_task_names = (args.task_name,) 135 | eval_datasets = [load_and_cache_examples(args, task, tokenizer, evaluate=True) for task in eval_task_names] 136 | if args.test_val_split: 137 | assert len(eval_datasets) == 1 138 | val_test_indices = [] 139 | for i, eval_dataset in enumerate(eval_datasets): 140 | class2idx = defaultdict(list) 141 | for i, sample in enumerate(eval_dataset): 142 | class2idx[sample[-1].item()].append(i) 143 | val_indices = [] 144 | test_indices = [] 145 | for class_num, indices in class2idx.items(): 146 | state = np.random.RandomState(1) 147 | state.shuffle(indices) 148 | class_val_indices, class_test_indices = indices[: len(indices) // 2], indices[len(indices) // 2 :] 149 | val_indices += class_val_indices 150 | test_indices += class_test_indices 151 | val_indices = torch.tensor(val_indices).long() 152 | test_indices = torch.tensor(test_indices).long() 153 | val_test_indices.append((val_indices, test_indices)) 154 | eval_dataset.tensors = [t[val_indices] for t in eval_dataset.tensors] 155 | 156 | if args.max_steps > 0: 157 | t_total = args.max_steps 158 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 159 | else: 160 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 161 | 162 | assert args.logging_steps == 0 or args.num_loggings == 0, "Can only use 1 logging option" 163 | if args.logging_steps == 0: 164 | assert args.num_loggings > 0 165 | args.logging_steps = t_total // args.num_loggings 166 | 167 | if args.warmup_ratio > 0: 168 | assert args.warmup_steps == 0 169 | args.warmup_steps = int(args.warmup_ratio * t_total) 170 | 171 | # Prepare optimizer and schedule (linear warmup and decay) 172 | optimizer_grouped_parameters = get_optimizer_grouped_parameters(args, model) 173 | 174 | if args.use_torch_adamw: 175 | optimizer = torch.optim.AdamW( 176 | optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon, weight_decay=args.weight_decay 177 | ) 178 | else: 179 | optimizer = AdamW( 180 | optimizer_grouped_parameters, 181 | lr=args.learning_rate, 182 | eps=args.adam_epsilon, 183 | correct_bias=not args.use_bertadam, 184 | ) 185 | 186 | optimizer = PriorWD(optimizer, use_prior_wd=args.prior_weight_decay) 187 | scheduler = get_linear_schedule_with_warmup( 188 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total 189 | ) 190 | 191 | # Check if saved optimizer or scheduler states exist 192 | if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile( 193 | os.path.join(args.model_name_or_path, "scheduler.pt") 194 | ): 195 | # Load in 
optimizer and scheduler states 196 | optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt"))) 197 | scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt"))) 198 | 199 | if args.fp16: 200 | try: 201 | from apex import amp 202 | except ImportError: 203 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 204 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) 205 | 206 | # multi-gpu training (should be after apex fp16 initialization) 207 | if args.n_gpu > 1: 208 | model = torch.nn.DataParallel(model) 209 | 210 | # Distributed training (should be after apex fp16 initialization) 211 | if args.local_rank != -1: 212 | model = torch.nn.parallel.DistributedDataParallel( 213 | model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True, 214 | ) 215 | 216 | # Train! 217 | logger.info("***** Running training *****") 218 | logger.info(" Num examples = %d", len(train_dataset)) 219 | logger.info(" Num Epochs = %d", args.num_train_epochs) 220 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) 221 | logger.info( 222 | " Total train batch size (w. parallel, distributed & accumulation) = %d", 223 | args.train_batch_size 224 | * args.gradient_accumulation_steps 225 | * (torch.distributed.get_world_size() if args.local_rank != -1 else 1), 226 | ) 227 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 228 | logger.info(" Total optimization steps = %d", t_total) 229 | 230 | global_step = 0 231 | epochs_trained = 0 232 | steps_trained_in_current_epoch = 0 233 | # Check if continuing training from a checkpoint 234 | if os.path.exists(args.model_name_or_path): 235 | # set global_step to global_step of last saved checkpoint from model path 236 | try: 237 | global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0]) 238 | except ValueError: 239 | global_step = 0 240 | epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps) 241 | steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps) 242 | 243 | logger.info(" Continuing training from checkpoint, will skip to saved global_step") 244 | logger.info(" Continuing training from epoch %d", epochs_trained) 245 | logger.info(" Continuing training from global step %d", global_step) 246 | logger.info( 247 | " Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch, 248 | ) 249 | 250 | tr_loss, logging_loss = 0.0, 0.0 251 | best_val_acc = -100.0 252 | best_model = None 253 | model.zero_grad() 254 | train_iterator = trange( 255 | epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0], 256 | ) 257 | set_seed(args.seed) # Added here for reproductibility 258 | for _ in train_iterator: 259 | epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) 260 | for step, batch in enumerate(epoch_iterator): 261 | 262 | # Skip past any already trained steps if resuming training 263 | if steps_trained_in_current_epoch > 0: 264 | steps_trained_in_current_epoch -= 1 265 | continue 266 | 267 | model.train() 268 | batch = tuple(t.to(args.device) for t in batch) 269 | inputs = { 270 | "input_ids": batch[0], 271 | "attention_mask": batch[1], 272 | "labels": batch[3], 273 | } 274 | if args.model_type not in {"distilbert", "bart"}: 275 | 
inputs["token_type_ids"] = ( 276 | batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None 277 | ) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids 278 | outputs = model(**inputs) 279 | loss = outputs[0] # model outputs are always tuple in transformers (see doc) 280 | 281 | if args.n_gpu > 1: 282 | loss = loss.mean() # mean() to average on multi-gpu parallel training 283 | if args.gradient_accumulation_steps > 1: 284 | loss = loss / args.gradient_accumulation_steps 285 | 286 | if args.fp16: 287 | with amp.scale_loss(loss, optimizer) as scaled_loss: 288 | scaled_loss.backward() 289 | else: 290 | loss.backward() 291 | 292 | tr_loss += loss.item() 293 | 294 | if (step + 1) % args.gradient_accumulation_steps == 0 or ( 295 | # last step in epoch but step is always smaller than gradient_accumulation_steps 296 | len(epoch_iterator) <= args.gradient_accumulation_steps 297 | and (step + 1) == len(epoch_iterator) 298 | ): 299 | if args.fp16: 300 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) 301 | else: 302 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 303 | 304 | optimizer.step() 305 | scheduler.step() # Update learning rate schedule 306 | model.zero_grad() 307 | global_step += 1 308 | 309 | if args.local_rank in [-1, 0] and ( 310 | (args.logging_steps > 0 and global_step % args.logging_steps == 0) or (global_step == t_total) 311 | ): 312 | logs = {} 313 | if args.local_rank == -1: 314 | results = evaluate(args, model, tokenizer, eval_datasets=eval_datasets) 315 | for key, value in results.items(): 316 | eval_key = "val_{}".format(key) 317 | logs[eval_key] = value 318 | 319 | if args.local_rank in [-1, 0] and args.save_best and logs["val_acc"] > best_val_acc: 320 | output_dir = os.path.join(args.output_dir, "checkpoint-best") 321 | os.makedirs(output_dir, exist_ok=True) 322 | model_to_save = ( 323 | model.module if hasattr(model, "module") else model 324 | ) # Take care of distributed/parallel training 325 | model_to_save.save_pretrained(output_dir) 326 | tokenizer.save_pretrained(output_dir) 327 | 328 | torch.save(args, os.path.join(output_dir, "training_args.bin")) 329 | logger.info("Saving model checkpoint to %s", output_dir) 330 | 331 | torch.save( 332 | optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"), 333 | ) 334 | torch.save( 335 | scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"), 336 | ) 337 | logger.info("Saving optimizer and scheduler states to %s", output_dir) 338 | 339 | if "val_acc" in logs: 340 | if logs["val_acc"] > best_val_acc: 341 | best_val_acc = logs["val_acc"] 342 | best_model = {k: v.cpu().detach() for k, v in model.state_dict().items()} 343 | logs["best_val_acc"] = best_val_acc 344 | elif "val_mcc" in logs: 345 | if logs["val_mcc"] > best_val_acc: 346 | best_val_acc = logs["val_mcc"] 347 | best_model = {k: v.cpu().detach() for k, v in model.state_dict().items()} 348 | logs["best_val_mcc"] = best_val_acc 349 | elif "val_spearmanr": 350 | if logs["val_spearmanr"] > best_val_acc: 351 | best_val_acc = logs["val_spearmanr"] 352 | best_model = {k: v.cpu().detach() for k, v in model.state_dict().items()} 353 | logs["best_val_spearmanr"] = best_val_acc 354 | else: 355 | raise ValueError(f"logs:{logs}") 356 | 357 | learning_rate_scalar = scheduler.get_lr()[0] 358 | logs["learning_rate"] = learning_rate_scalar 359 | 360 | if args.logging_steps > 0: 361 | if global_step % args.logging_steps == 0: 362 | loss_scalar = (tr_loss - logging_loss) / args.logging_steps 
363 | else: 364 | loss_scalar = (tr_loss - logging_loss) / (global_step % args.logging_steps) 365 | else: 366 | loss_scalar = (tr_loss - logging_loss) / global_step 367 | logs["loss"] = loss_scalar 368 | logging_loss = tr_loss 369 | 370 | logs["step"] = global_step 371 | with open(f"{args.output_dir}/raw_log.txt", "a") as f: 372 | if os.stat(f"{args.output_dir}/raw_log.txt").st_size == 0: 373 | for k in logs: 374 | f.write(f"{k},") 375 | f.write("\n") 376 | for v in logs.values(): 377 | f.write(f"{v:.6f},") 378 | f.write("\n") 379 | 380 | if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: 381 | # Save model checkpoint 382 | output_dir = os.path.join(args.output_dir, "checkpoint-last".format(global_step)) 383 | os.makedirs(output_dir, exist_ok=True) 384 | model_to_save = ( 385 | model.module if hasattr(model, "module") else model 386 | ) # Take care of distributed/parallel training 387 | model_to_save.save_pretrained(output_dir) 388 | tokenizer.save_pretrained(output_dir) 389 | 390 | torch.save(args, os.path.join(output_dir, "training_args.bin")) 391 | logger.info("Saving model checkpoint to %s", output_dir) 392 | 393 | torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) 394 | torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) 395 | logger.info("Saving optimizer and scheduler states to %s", output_dir) 396 | 397 | if args.max_steps > 0 and global_step > args.max_steps: 398 | epoch_iterator.close() 399 | break 400 | if args.max_steps > 0 and global_step > args.max_steps: 401 | train_iterator.close() 402 | break 403 | 404 | args.resplit_val = 0 # test on the original test_set 405 | eval_task_names = (args.task_name,) 406 | 407 | # test the last checkpoint on the second half 408 | eval_datasets = [load_and_cache_examples(args, task, tokenizer, evaluate=True) for task in eval_task_names] 409 | if args.test_val_split: 410 | for i, eval_dataset in enumerate(eval_datasets): 411 | test_indices = val_test_indices[i][1] 412 | eval_dataset.tensors = [t[test_indices] for t in eval_dataset.tensors] 413 | 414 | result = evaluate(args, model, tokenizer, eval_datasets=eval_datasets) 415 | result["step"] = t_total 416 | # overwriting validation results 417 | with open(f"{args.output_dir}/test_last_log.txt", "w") as f: 418 | f.write(",".join(["test_" + k for k in result.keys()]) + "\n") 419 | f.write(",".join([f"{v:.4f}" for v in result.values()])) 420 | 421 | if best_model is not None: 422 | model.load_state_dict(best_model) 423 | 424 | # test on the second half 425 | eval_datasets = [load_and_cache_examples(args, task, tokenizer, evaluate=True) for task in eval_task_names] 426 | if args.test_val_split: 427 | for i, eval_dataset in enumerate(eval_datasets): 428 | test_indices = val_test_indices[i][1] 429 | eval_dataset.tensors = [t[test_indices] for t in eval_dataset.tensors] 430 | 431 | result = evaluate(args, model, tokenizer, eval_datasets=eval_datasets) 432 | result["step"] = t_total 433 | # overwriting validation results 434 | with open(f"{args.output_dir}/test_best_log.txt", "w") as f: 435 | f.write(",".join(["test_" + k for k in result.keys()]) + "\n") 436 | f.write(",".join([f"{v:.4f}" for v in result.values()])) 437 | 438 | return global_step, tr_loss / global_step 439 | 440 | 441 | def evaluate(args, model, tokenizer, prefix="", eval_datasets=None): 442 | eval_task_names = [args.task_name] 443 | eval_outputs_dirs = [args.output_dir] 444 | 445 | results = {} 446 | for i, (eval_task, eval_output_dir) in 
enumerate(zip(eval_task_names, eval_outputs_dirs)): 447 | if eval_datasets is None: 448 | eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True) 449 | elif isinstance(eval_datasets, list): 450 | eval_dataset = eval_datasets[i] 451 | else: 452 | raise ValueError("Wrong Pre-fetched Eval Set") 453 | 454 | if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]: 455 | os.makedirs(eval_output_dir) 456 | 457 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 458 | # Note that DistributedSampler samples randomly 459 | eval_sampler = SequentialSampler(eval_dataset) 460 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) 461 | 462 | # multi-gpu eval 463 | if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel): 464 | model = torch.nn.DataParallel(model) 465 | 466 | # Eval! 467 | eval_loss = 0.0 468 | nb_eval_steps = 0 469 | preds = None 470 | out_label_ids = None 471 | for batch in eval_dataloader: 472 | model.eval() 473 | batch = tuple(t.to(args.device) for t in batch) 474 | 475 | with torch.no_grad(): 476 | inputs = { 477 | "input_ids": batch[0], 478 | "attention_mask": batch[1], 479 | "labels": batch[3], 480 | } 481 | if args.model_type not in {"distilbert", "bart"}: 482 | inputs["token_type_ids"] = ( 483 | batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None 484 | ) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids 485 | outputs = model(**inputs) 486 | tmp_eval_loss, logits = outputs[:2] 487 | 488 | eval_loss += tmp_eval_loss.mean().item() 489 | 490 | nb_eval_steps += 1 491 | if preds is None: 492 | preds = logits.detach().cpu().numpy() 493 | out_label_ids = inputs["labels"].detach().cpu().numpy() 494 | else: 495 | preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) 496 | out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0) 497 | 498 | eval_loss = eval_loss / nb_eval_steps 499 | if args.output_mode == "classification": 500 | preds = np.argmax(preds, axis=1) 501 | elif args.output_mode == "regression": 502 | preds = np.squeeze(preds) 503 | result = compute_metrics(eval_task, preds, out_label_ids) 504 | results.update(result) 505 | 506 | results["loss"] = eval_loss 507 | 508 | return results 509 | 510 | 511 | def load_and_cache_examples(args, task, tokenizer, evaluate=False): 512 | if args.local_rank not in [-1, 0] and not evaluate: 513 | torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache 514 | 515 | processor = processors[task]() 516 | output_mode = output_modes[task] 517 | # Load data features from cache or dataset file 518 | cached_features_file = os.path.join( 519 | args.data_dir, 520 | "cached_{}_{}_{}_{}".format( 521 | "dev" if (evaluate and args.resplit_val <= 0) else "train", 522 | list(filter(None, args.model_name_or_path.split("/"))).pop(), 523 | str(args.max_seq_length), 524 | str(task), 525 | ), 526 | ) 527 | if os.path.exists(cached_features_file) and not args.overwrite_cache: 528 | logger.info("Loading features from cached file %s", cached_features_file) 529 | features = torch.load(cached_features_file) 530 | else: 531 | logger.info("Creating features from dataset file at %s", args.data_dir) 532 | label_list = processor.get_labels() 533 | if task in ["mnli", "mnli-mm"] and args.model_type in ["roberta", "xlmroberta"]: 534 | # HACK(label indices are swapped in RoBERTa pretrained model) 535 | 
label_list[1], label_list[2] = label_list[2], label_list[1] 536 | examples = ( 537 | processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir) 538 | ) 539 | features = convert_examples_to_features( 540 | examples, 541 | tokenizer, 542 | label_list=label_list, 543 | max_length=args.max_seq_length, 544 | output_mode=output_mode, 545 | pad_on_left=bool(args.model_type in ["xlnet"]), # pad on the left for xlnet 546 | pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0], 547 | pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0, 548 | ) 549 | if args.local_rank in [-1, 0]: 550 | logger.info("Saving features into cached file %s", cached_features_file) 551 | torch.save(features, cached_features_file) 552 | 553 | if args.local_rank == 0 and not evaluate: 554 | torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache 555 | 556 | if args.downsample_trainset > 0 and not evaluate: 557 | assert (args.downsample_trainset + args.resplit_val) <= len(features) 558 | 559 | if args.downsample_trainset > 0 or args.resplit_val > 0: 560 | set_seed(0) # use the same seed for downsample 561 | if output_mode == "classification": 562 | label_to_idx = defaultdict(list) 563 | for i, f in enumerate(features): 564 | label_to_idx[f.label].append(i) 565 | 566 | samples_per_class = args.resplit_val if evaluate else args.downsample_trainset 567 | samples_per_class = samples_per_class // len(label_to_idx) 568 | 569 | for k in label_to_idx: 570 | label_to_idx[k] = np.array(label_to_idx[k]) 571 | np.random.shuffle(label_to_idx[k]) 572 | if evaluate: 573 | if args.resplit_val > 0: 574 | label_to_idx[k] = label_to_idx[k][-samples_per_class:] 575 | else: 576 | pass 577 | else: 578 | if args.resplit_val > 0 and args.downsample_trainset <= 0: 579 | samples_per_class = len(label_to_idx[k]) - args.resplit_val // len(label_to_idx) 580 | label_to_idx[k] = label_to_idx[k][:samples_per_class] 581 | 582 | sampled_idx = np.concatenate(list(label_to_idx.values())) 583 | else: 584 | if args.downsample_trainset > 0: 585 | sampled_idx = torch.randperm(len(features))[: args.downsample_trainset] 586 | else: 587 | raise NotImplementedError 588 | set_seed(args.seed) 589 | features = [features[i] for i in sampled_idx] 590 | 591 | # Convert to Tensors and build dataset 592 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 593 | all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long) 594 | all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long) 595 | if output_mode == "classification": 596 | all_labels = torch.tensor([f.label for f in features], dtype=torch.long) 597 | elif output_mode == "regression": 598 | all_labels = torch.tensor([f.label for f in features], dtype=torch.float) 599 | 600 | dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels) 601 | return dataset 602 | 603 | 604 | def main(): 605 | parser = get_parser() 606 | args = parser.parse_args() 607 | 608 | if ( 609 | os.path.exists(args.output_dir) 610 | and os.listdir(args.output_dir) 611 | and args.do_train 612 | and not args.overwrite_output_dir 613 | ): 614 | raise ValueError( 615 | "Output directory ({}) already exists and is not empty. 
Use --overwrite_output_dir to overcome.".format( 616 | args.output_dir 617 | ) 618 | ) 619 | 620 | if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: 621 | os.makedirs(args.output_dir) 622 | 623 | # Setup distant debugging if needed 624 | if args.server_ip and args.server_port: 625 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 626 | import ptvsd 627 | 628 | print("Waiting for debugger attach") 629 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 630 | ptvsd.wait_for_attach() 631 | 632 | # Setup CUDA, GPU & distributed training 633 | if args.local_rank == -1 or args.no_cuda: 634 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 635 | args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count() 636 | else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 637 | torch.cuda.set_device(args.local_rank) 638 | device = torch.device("cuda", args.local_rank) 639 | torch.distributed.init_process_group(backend="nccl") 640 | args.n_gpu = 1 641 | args.device = device 642 | 643 | # Setup logging 644 | logging.basicConfig( 645 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 646 | datefmt="%m/%d/%Y %H:%M:%S", 647 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN, 648 | ) 649 | logger.warning( 650 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 651 | args.local_rank, 652 | device, 653 | args.n_gpu, 654 | bool(args.local_rank != -1), 655 | args.fp16, 656 | ) 657 | 658 | # Set seed 659 | set_seed(args.seed) 660 | 661 | args.task_name = args.task_name.lower() 662 | if args.task_name not in processors: 663 | raise ValueError("Task not found: %s" % (args.task_name)) 664 | processor = processors[args.task_name]() 665 | args.output_mode = output_modes[args.task_name] 666 | label_list = processor.get_labels() 667 | num_labels = len(label_list) 668 | 669 | # Load pretrained model and tokenizer 670 | if args.local_rank not in [-1, 0]: 671 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab 672 | 673 | args.model_type = args.model_type.lower() 674 | 675 | num_labels_old = AutoConfig.from_pretrained(args.model_name_or_path).num_labels 676 | config = AutoConfig.from_pretrained( 677 | args.config_name if args.config_name else args.model_name_or_path, 678 | num_labels=num_labels_old, 679 | finetuning_task=args.task_name, 680 | cache_dir=args.cache_dir if args.cache_dir else None, 681 | ) 682 | tokenizer = AutoTokenizer.from_pretrained( 683 | args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, 684 | do_lower_case=args.do_lower_case, 685 | cache_dir=args.cache_dir if args.cache_dir else None, 686 | ) 687 | if args.model_type == "electra": 688 | model = ElectraForSequenceClassification.from_pretrained( 689 | args.model_name_or_path, 690 | from_tf=bool(".ckpt" in args.model_name_or_path), 691 | config=config, 692 | cache_dir=args.cache_dir if args.cache_dir else None, 693 | ) 694 | else: 695 | model = AutoModelForSequenceClassification.from_pretrained( 696 | args.model_name_or_path, 697 | from_tf=bool(".ckpt" in args.model_name_or_path), 698 | config=config, 699 | cache_dir=args.cache_dir if args.cache_dir else None, 700 | ) 701 | if num_labels != num_labels_old: 702 | config.num_labels = num_labels 703 | model.num_labels = num_labels 704 | if args.model_type in 
["roberta", "bert", "electra"]: 705 | from transformers.modeling_roberta import RobertaClassificationHead 706 | 707 | model.classifier = ( 708 | RobertaClassificationHead(config) 709 | if args.model_type == "roberta" 710 | else nn.Linear(config.hidden_size, config.num_labels) 711 | ) 712 | for module in model.classifier.modules(): 713 | if isinstance(module, (nn.Linear, nn.Embedding)): 714 | # Slightly different from the TF version which uses truncated_normal for initialization 715 | # cf https://github.com/pytorch/pytorch/pull/5617 716 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 717 | if isinstance(module, nn.Linear) and module.bias is not None: 718 | module.bias.data.zero_() 719 | elif args.model_type == "bart": 720 | from transformers.modeling_bart import BartClassificationHead 721 | 722 | model.classification_head = BartClassificationHead( 723 | config.d_model, config.d_model, config.num_labels, config.classif_dropout, 724 | ) 725 | model.model._init_weights(model.classification_head.dense) 726 | model.model._init_weights(model.classification_head.out_proj) 727 | elif args.model_type == "xlnet": 728 | model.logits_proj = nn.Linear(config.d_model, config.num_labels) 729 | model.transformer._init_weights(model.logits_proj) 730 | else: 731 | raise NotImplementedError 732 | 733 | if args.local_rank == 0: 734 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab 735 | 736 | if args.reinit_pooler: 737 | if args.model_type in ["bert", "roberta"]: 738 | encoder_temp = getattr(model, args.model_type) 739 | encoder_temp.pooler.dense.weight.data.normal_(mean=0.0, std=encoder_temp.config.initializer_range) 740 | encoder_temp.pooler.dense.bias.data.zero_() 741 | for p in encoder_temp.pooler.parameters(): 742 | p.requires_grad = True 743 | elif args.model_type in ["xlnet", "bart", "electra"]: 744 | raise ValueError(f"{args.model_type} does not have a pooler at the end") 745 | else: 746 | raise NotImplementedError 747 | 748 | if args.reinit_layers > 0: 749 | if args.model_type in ["bert", "roberta", "electra"]: 750 | assert args.reinit_pooler or args.model_type == "electra" 751 | from transformers.modeling_bert import BertLayerNorm 752 | 753 | encoder_temp = getattr(model, args.model_type) 754 | for layer in encoder_temp.encoder.layer[-args.reinit_layers :]: 755 | for module in layer.modules(): 756 | if isinstance(module, (nn.Linear, nn.Embedding)): 757 | # Slightly different from the TF version which uses truncated_normal for initialization 758 | # cf https://github.com/pytorch/pytorch/pull/5617 759 | module.weight.data.normal_(mean=0.0, std=encoder_temp.config.initializer_range) 760 | elif isinstance(module, BertLayerNorm): 761 | module.bias.data.zero_() 762 | module.weight.data.fill_(1.0) 763 | if isinstance(module, nn.Linear) and module.bias is not None: 764 | module.bias.data.zero_() 765 | elif args.model_type == "xlnet": 766 | from transformers.modeling_xlnet import XLNetLayerNorm, XLNetRelativeAttention 767 | 768 | for layer in model.transformer.layer[-args.reinit_layers :]: 769 | for module in layer.modules(): 770 | if isinstance(module, (nn.Linear, nn.Embedding)): 771 | # Slightly different from the TF version which uses truncated_normal for initialization 772 | # cf https://github.com/pytorch/pytorch/pull/5617 773 | module.weight.data.normal_(mean=0.0, std=model.transformer.config.initializer_range) 774 | if isinstance(module, nn.Linear) and module.bias is not None: 775 | module.bias.data.zero_() 776 | 
elif isinstance(module, XLNetLayerNorm): 777 | module.bias.data.zero_() 778 | module.weight.data.fill_(1.0) 779 | elif isinstance(module, XLNetRelativeAttention): 780 | for param in [ 781 | module.q, 782 | module.k, 783 | module.v, 784 | module.o, 785 | module.r, 786 | module.r_r_bias, 787 | module.r_s_bias, 788 | module.r_w_bias, 789 | module.seg_embed, 790 | ]: 791 | param.data.normal_(mean=0.0, std=model.transformer.config.initializer_range) 792 | elif args.model_type == "bart": 793 | for layer in model.model.decoder.layers[-args.reinit_layers :]: 794 | for module in layer.modules(): 795 | model.model._init_weights(module) 796 | 797 | else: 798 | raise NotImplementedError 799 | 800 | if args.mixout > 0: 801 | from mixout import MixLinear 802 | 803 | for sup_module in model.modules(): 804 | for name, module in sup_module.named_children(): 805 | if isinstance(module, nn.Dropout): 806 | module.p = 0.0 807 | if isinstance(module, nn.Linear): 808 | target_state_dict = module.state_dict() 809 | bias = True if module.bias is not None else False 810 | new_module = MixLinear( 811 | module.in_features, module.out_features, bias, target_state_dict["weight"], args.mixout 812 | ) 813 | new_module.load_state_dict(target_state_dict) 814 | setattr(sup_module, name, new_module) 815 | print(model) 816 | 817 | model.to(args.device) 818 | 819 | logger.info("Training/evaluation parameters %s", args) 820 | 821 | # Training 822 | if args.do_train: 823 | train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False) 824 | global_step, tr_loss = train(args, train_dataset, model, tokenizer) 825 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) 826 | 827 | # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained() 828 | if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 829 | logger.info("Saving model checkpoint to %s", args.output_dir) 830 | torch.save(args, os.path.join(args.output_dir, "training_args.bin")) 831 | 832 | 833 | if __name__ == "__main__": 834 | main() 835 | -------------------------------------------------------------------------------- /sample_commands/debiased_adam_baseline.sh: -------------------------------------------------------------------------------- 1 | python run_glue.py \ 2 | --model_type bert --model_name_or_path bert-large-uncased --task_name RTE \ 3 | --do_train --data_dir /persist/data/glue_data/RTE --max_seq_length 128 \ 4 | --per_gpu_eval_batch_size 64 --weight_decay 0 --seed 0 --fp16 \ 5 | --overwrite_output_dir --do_lower_case --per_gpu_train_batch_size 32 \ 6 | --gradient_accumulation_steps 1 --logging_steps 0 --num_loggings 10 \ 7 | --save_steps 0 --test_val_split --use_torch_adamw --cache_dir /home/ubuntu/hf-transformers-cache \ 8 | --num_train_epochs 3.0 --warmup_ratio 0.1 --learning_rate 2e-05 \ 9 | --output_dir bert_output/ORIGINAL/RTE/SEED0 -------------------------------------------------------------------------------- /sample_commands/debiased_adam_longer.sh: -------------------------------------------------------------------------------- 1 | # using bias correction 2 | python run_glue.py \ 3 | --model_type bert --model_name_or_path bert-large-uncased --task_name RTE \ 4 | --do_train --data_dir /persist/data/glue_data/RTE --max_seq_length 128 \ 5 | --per_gpu_eval_batch_size 64 --weight_decay 0 --seed 0 --fp16 \ 6 | --overwrite_output_dir --do_lower_case --per_gpu_train_batch_size 32 \ 7 | --gradient_accumulation_steps 1 --logging_steps 0 
--num_loggings 10 \ 8 | --save_steps 0 --test_val_split --use_torch_adamw --cache_dir /home/ubuntu/hf-transformers-cache \ 9 | --max_steps 400 --num_train_epochs 0 --warmup_ratio 0.1 --learning_rate 2e-05 \ 10 | --output_dir bert_output/LONGER/RTE/SEED0 -------------------------------------------------------------------------------- /sample_commands/llrd.sh: -------------------------------------------------------------------------------- 1 | python run_glue.py \ 2 | --model_type bert --model_name_or_path bert-large-uncased --task_name RTE \ 3 | --do_train --data_dir /persist/data/glue_data/RTE --max_seq_length 128 \ 4 | --per_gpu_eval_batch_size 64 --weight_decay 0 --seed 0 --fp16 \ 5 | --overwrite_output_dir --do_lower_case --per_gpu_train_batch_size 32 \ 6 | --gradient_accumulation_steps 1 --logging_steps 0 --num_loggings 10 \ 7 | --save_steps 0 --test_val_split --use_torch_adamw --cache_dir /home/ubuntu/hf-transformers-cache \ 8 | --num_train_epochs 3.0 --warmup_ratio 0.1 --learning_rate 2e-05 \ 9 | --output_dir bert_output/LLRD/RTE/SEED0 --layerwise_learning_rate_decay 0.95 -------------------------------------------------------------------------------- /sample_commands/mixout.sh: -------------------------------------------------------------------------------- 1 | python run_glue.py \ 2 | --model_type bert --model_name_or_path bert-large-uncased --task_name RTE \ 3 | --do_train --data_dir /persist/data/glue_data/RTE --max_seq_length 128 \ 4 | --per_gpu_eval_batch_size 64 --weight_decay 0 --seed 0 --fp16 \ 5 | --overwrite_output_dir --do_lower_case --per_gpu_train_batch_size 16 \ 6 | --gradient_accumulation_steps 2 --logging_steps 0 --num_loggings 10 \ 7 | --save_steps 0 --test_val_split --use_torch_adamw --cache_dir /home/ubuntu/hf-transformers-cache \ 8 | --num_train_epochs 3.0 --warmup_ratio 0.1 --learning_rate 2e-05 \ 9 | --output_dir bert_output/MIXOUT/RTE/SEED0 --mixout 0.1 -------------------------------------------------------------------------------- /sample_commands/pretrained_wd.sh: -------------------------------------------------------------------------------- 1 | python run_glue.py \ 2 | --model_type bert --model_name_or_path bert-large-uncased --task_name RTE \ 3 | --do_train --data_dir /persist/data/glue_data/RTE --max_seq_length 128 \ 4 | --per_gpu_eval_batch_size 64 --weight_decay 0 --seed 0 --fp16 \ 5 | --overwrite_output_dir --do_lower_case --per_gpu_train_batch_size 32 \ 6 | --gradient_accumulation_steps 1 --logging_steps 0 --num_loggings 10 \ 7 | --save_steps 0 --test_val_split --use_torch_adamw --cache_dir /home/ubuntu/hf-transformers-cache \ 8 | --num_train_epochs 3.0 --warmup_ratio 0.1 --learning_rate 2e-05 \ 9 | --output_dir bert_output/ORIGINAL/RTE/SEED0 --prior_weight_decay --weight_decay 1e-3 -------------------------------------------------------------------------------- /sample_commands/reinit.sh: -------------------------------------------------------------------------------- 1 | # using bias correction 2 | python run_glue.py \ 3 | --model_type bert --model_name_or_path bert-large-uncased --task_name RTE \ 4 | --do_train --data_dir /persist/data/glue_data/RTE --max_seq_length 128 \ 5 | --per_gpu_eval_batch_size 64 --weight_decay 0 --seed 0 --fp16 \ 6 | --overwrite_output_dir --do_lower_case --per_gpu_train_batch_size 32 \ 7 | --gradient_accumulation_steps 1 --logging_steps 0 --num_loggings 10 \ 8 | --save_steps 0 --test_val_split --use_torch_adamw --cache_dir /home/ubuntu/hf-transformers-cache \ 9 | --num_train_epochs 3.0 --warmup_ratio 0.1 
--learning_rate 2e-05 \ 10 | --output_dir bert_output/REINIT5/RTE/SEED0 \ 11 | --reinit_pooler --reinit_layers 5 --------------------------------------------------------------------------------
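
Note: the reinit.sh command above exercises the --reinit_pooler / --reinit_layers path of run_glue.py. The following is a minimal standalone sketch of that partial re-initialization, not a file in this repository; it assumes a transformers release in which BERT's LayerNorm is plain torch.nn.LayerNorm, and uses bert-base-uncased only to keep the example small:

import torch.nn as nn
from transformers import BertForSequenceClassification

def reinit_bert_top_layers(model, num_layers):
    """Re-initialize the pooler and the top `num_layers` encoder layers,
    mirroring the --reinit_pooler / --reinit_layers branch of run_glue.py."""
    std = model.config.initializer_range

    # Pooler: fresh normal(0, initializer_range) weights, zero bias.
    model.bert.pooler.dense.weight.data.normal_(mean=0.0, std=std)
    model.bert.pooler.dense.bias.data.zero_()

    # Top-most encoder layers: re-initialize Linear, Embedding, and LayerNorm modules.
    for layer in model.bert.encoder.layer[-num_layers:]:
        for module in layer.modules():
            if isinstance(module, (nn.Linear, nn.Embedding)):
                module.weight.data.normal_(mean=0.0, std=std)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
    return model

model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
model = reinit_bert_top_layers(model, num_layers=5)

With --reinit_layers 5 as in reinit.sh, only the five top-most encoder layers and the pooler lose their pretrained weights; the lower layers keep their pretrained parameters, which is the configuration the REINIT5 output directory name refers to.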