├── .gitignore ├── LICENSE ├── README.md ├── invariant_distilbert.py ├── invariant_roberta.py ├── invariant_trainer.py ├── requirements.txt └── run_invariant_mlm.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | ilm/ 113 | tmp/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Invariant Language Modeling
2 | Implementation of the training procedure for invariant language models.
3 |
4 | ## Motivation
5 |
6 | Modern pretrained language models are critical components of NLP pipelines. Yet, they suffer from spurious correlations, poor out-of-domain generalization, and biases.
7 | Inspired by recent progress in causal machine learning, we propose __invariant language modeling__, a framework to learn invariant representations that should generalize across training environments.
8 | In particular, we adapt [IRM-games](https://arxiv.org/abs/2002.04692) to language models, where invariance emerges from a specific training schedule in which environments compete to optimize their environment-specific loss by updating subsets of the model in a round-robin fashion.
9 |
10 | ## Model Description
11 |
12 | The data is assumed to come from `n` distinct environments, and we aim to learn a language model that focuses on correlations that generalize across environments.
13 |
14 | The model is decomposed into two components:
15 | * `ϕ`, the main body of the transformer language model,
16 | * `w`, the language modeling head that predicts the missing token.
17 |
18 | In our implementation, there are as many heads as environments: `n`.
19 | For each data point, all heads make their predictions, which are then averaged.
20 | During training, however, we sample one batch from each environment in a round-robin fashion.
21 | When seeing a batch from environment `e`, only the head `w_e` and the main body `ϕ` receive a gradient update.
22 |
23 | ## Usage
24 |
25 | To get started with the code:
26 | ```
27 | pip install -r requirements.txt
28 | ```
29 |
30 | PyTorch with CUDA support is required to run this framework.
31 | Please find all useful installation information [here](https://pytorch.org/).
32 |
33 | Then, to continue the training of a language model from a [huggingface](https://huggingface.co/models) checkpoint:
34 | ```
35 | CUDA_VISIBLE_DEVICES=0 python3 run_invariant_mlm.py \
36 |     --model_name_or_path roberta-base \
37 |     --validation_file data-folder/validation_file.txt \
38 |     --do_train \
39 |     --do_eval \
40 |     --nb_steps 5000 \
41 |     --learning_rate 1e-5 \
42 |     --output_dir folder-to-save-model \
43 |     --seed 123 \
44 |     --train_file data-folder/training-environments \
45 |     --overwrite_cache
46 | ```
47 | If the machine on which the code is executed has several GPUs, we recommend using the `CUDA_VISIBLE_DEVICES` environment variable to
48 | restrict training to a single GPU, since multiple GPUs are currently not supported by the implementation.
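The `--train_file` argument points to a folder that contains one plain-text file per training environment: `run_invariant_mlm.py` loads every `*.txt` file in this folder as a separate environment, named after the file (without the extension), while `--validation_file` points to a single text file. For illustration only (the environment file names below are arbitrary), the data folder used in the command above could be organized as:
```
data-folder/
├── training-environments/
│   ├── environment-1.txt
│   └── environment-2.txt
└── validation_file.txt
```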
49 |
50 | Currently, the supported base models are:
51 | * `roberta`: [checkpoints](https://huggingface.co/models?sort=downloads&search=roberta)
52 | * `distilbert`: [checkpoints](https://huggingface.co/models?sort=downloads&search=distilbert)
53 |
54 | ## Implementation
55 |
56 | To train language models according to the [IRM-games](https://arxiv.org/abs/2002.04692) dynamics, one needs to modify:
57 | * the training schedule, to perform batch updates for each environment in a round-robin fashion.
58 | This logic is implemented by the `InvariantTrainer` in `invariant_trainer.py`, a class that inherits from the huggingface `Trainer`.
59 | * the language modeling heads of the model.
60 | One head per environment is needed.
61 | This is done by creating variations of the base model classes, implemented in `invariant_roberta.py` for `roberta` and in `invariant_distilbert.py` for `distilbert`.
62 |
63 | #### Contact
64 |
65 | Maxime Peyrard, maxime.peyrard@epfl.ch
66 |
--------------------------------------------------------------------------------
/invariant_distilbert.py:
--------------------------------------------------------------------------------
1 | import copy, warnings
2 | import torch.nn as nn
3 | from torch.nn import CrossEntropyLoss
4 | from transformers.modeling_outputs import MaskedLMOutput
5 | from transformers.models.distilbert.modeling_distilbert import DistilBertPreTrainedModel, DistilBertModel, gelu
6 | from transformers.models.distilbert.configuration_distilbert import DistilBertConfig
7 | from transformers.utils import logging
8 | logger = logging.get_logger(__name__)
9 | class DistilBertLMHead(nn.Module):
10 |     """DistilBert Head for masked language modeling."""
11 |
12 |     def __init__(self, config):
13 |         super().__init__()
14 |         self.vocab_transform = nn.Linear(config.dim, config.dim)
15 |         self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
16 |         self.vocab_projector = nn.Linear(config.dim, config.vocab_size)
17 |
18 |     def forward(self, features, **kwargs):
19 |         x = self.vocab_transform(features)  # (bs, seq_length, dim)
20 |         x = gelu(x)  # (bs, seq_length, dim)
21 |         x = self.vocab_layer_norm(x)  # (bs, seq_length, dim)
22 |         x = self.vocab_projector(x)  # (bs, seq_length, vocab_size)
23 |
24 |         return x
25 |
26 |
27 | class InvariantDistilBertConfig(DistilBertConfig):
28 |     model_type = "invariant-distilbert"
29 |
30 |     def __init__(self, envs=1, **kwargs):
31 |         """Constructs InvariantDistilBertConfig."""
32 |         super().__init__(**kwargs)
33 |         self.envs = envs
34 |
35 |
36 | class InvariantDistilBertForMaskedLM(DistilBertPreTrainedModel):
37 |     authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"]
38 |     authorized_unexpected_keys = [r"pooler"]
39 |
40 |     def __init__(self, config, model=None):  # , model, envs):
41 |         super().__init__(config)
42 |
43 |         self.config = config
44 |         if config.is_decoder:
45 |             logger.warning(
46 |                 "If you want to use `InvariantDistilBertForMaskedLM` make sure `config.is_decoder=False` for "
47 |                 "bi-directional self-attention."
48 | ) 49 | 50 | self.encoder = DistilBertModel(config) 51 | self.encoder.to('cuda') 52 | 53 | if len(config.envs) == 0: 54 | self.envs = ['erm'] 55 | else: 56 | self.envs = config.envs 57 | 58 | self.lm_heads = {} 59 | for env_name in self.envs: 60 | self.lm_heads[env_name] = DistilBertLMHead(config) 61 | 62 | if model is not None: 63 | self.encoder = copy.deepcopy(model.distilbert) 64 | self.lm_heads = {} 65 | for env_name in self.envs: 66 | self.lm_heads[env_name] = DistilBertLMHead(config) 67 | self.lm_heads[env_name].vocab_transform = copy.deepcopy(model.vocab_transform) 68 | self.lm_heads[env_name].vocab_layer_norm = copy.deepcopy(model.vocab_layer_norm) 69 | self.lm_heads[env_name].vocab_projector = copy.deepcopy(model.vocab_projector) 70 | # self.register_parameter(env_name + '-head', self.lm_heads[env_name]) 71 | 72 | for env_name, lm_head in self.lm_heads.items(): 73 | self.__setattr__(env_name + '_head', self.lm_heads[env_name]) 74 | 75 | self.encoder.to('cuda') 76 | for _, lm_head in self.lm_heads.items(): 77 | lm_head.to('cuda') 78 | 79 | self.n_environments = len(self.lm_heads) 80 | 81 | def print_lm_w(self): 82 | for env, lm_h in self.lm_heads.items(): 83 | print(lm_h.dense.weight) 84 | 85 | def init_head(self): 86 | for env_name in self.envs: 87 | self.lm_heads[env_name] = DistilBertLMHead(config) 88 | self.lm_heads[env_name].to('cuda') 89 | 90 | def init_base(self): 91 | self.encoder.init_weights() 92 | self.init_head() 93 | 94 | def get_input_embeddings(self): 95 | return self.encoder.get_input_embeddings() 96 | 97 | def set_input_embeddings(self, value): 98 | self.encoder.set_input_embeddings(value) 99 | # self.embeddings.word_embeddings = value 100 | 101 | def get_output_embeddings(self): 102 | for env, lm_head in self.lm_heads.items(): 103 | return lm_head.vocab_projector 104 | 105 | def set_output_embeddings(self, new_embeddings): 106 | for env, lm_head in self.lm_heads.items(): 107 | lm_head.decoder = new_embeddings 108 | 109 | # @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 110 | # @add_code_sample_docstrings( 111 | # tokenizer_class=_TOKENIZER_FOR_DOC, 112 | # checkpoint="roberta-base", 113 | # output_type=MaskedLMOutput, 114 | # config_class=_CONFIG_FOR_DOC, 115 | # mask="", 116 | # ) 117 | def forward( 118 | self, 119 | input_ids=None, 120 | attention_mask=None, 121 | head_mask=None, 122 | inputs_embeds=None, 123 | labels=None, 124 | output_attentions=None, 125 | output_hidden_states=None, 126 | return_dict=None, 127 | **kwargs 128 | ): 129 | r""" 130 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 131 | Labels for computing the masked language modeling loss. 132 | Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) 133 | Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels 134 | in ``[0, ..., config.vocab_size]`` 135 | kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): 136 | Used to hide legacy arguments that have been deprecated. 137 | """ 138 | if "masked_lm_labels" in kwargs: 139 | warnings.warn( 140 | "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.", 141 | FutureWarning, 142 | ) 143 | labels = kwargs.pop("masked_lm_labels") 144 | assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}." 
145 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 146 | 147 | outputs = self.encoder( 148 | input_ids=input_ids, 149 | attention_mask=attention_mask, 150 | head_mask=head_mask, 151 | inputs_embeds=inputs_embeds, 152 | output_attentions=output_attentions, 153 | output_hidden_states=output_hidden_states, 154 | return_dict=return_dict, 155 | ) 156 | sequence_output = outputs[0] 157 | if self.n_environments == 1: 158 | lm_head = list(self.lm_heads.values())[0] 159 | prediction_scores = lm_head(sequence_output) 160 | else: 161 | prediction_scores = 0. 162 | for env, lm_head in self.lm_heads.items(): 163 | prediction_scores += 1. / self.n_environments * lm_head(sequence_output) 164 | 165 | masked_lm_loss = None 166 | if labels is not None: 167 | loss_fct = CrossEntropyLoss() 168 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 169 | 170 | if not return_dict: 171 | output = (prediction_scores,) + outputs[2:] 172 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 173 | 174 | return MaskedLMOutput( 175 | loss=masked_lm_loss, 176 | logits=prediction_scores, 177 | hidden_states=outputs.hidden_states, 178 | attentions=outputs.attentions, 179 | ) 180 | -------------------------------------------------------------------------------- /invariant_roberta.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from torch.nn import CrossEntropyLoss 3 | from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaModel, RobertaLMHead 4 | from transformers.modeling_outputs import MaskedLMOutput 5 | from transformers.models.roberta.configuration_roberta import RobertaConfig 6 | 7 | 8 | class InvariantRobertaConfig(RobertaConfig): 9 | model_type = "invariant-roberta" 10 | 11 | def __init__(self, envs=1, **kwargs): 12 | """Constructs InvariantRobertaConfig.""" 13 | super().__init__(**kwargs) 14 | self.envs = envs 15 | 16 | 17 | # TODO: This could inherit from an InvariantRobertaModel class 18 | class InvariantRobertaForMaskedLM(RobertaPreTrainedModel): 19 | authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"] 20 | authorized_unexpected_keys = [r"pooler"] 21 | 22 | def __init__(self, config, model=None): # , model, envs): 23 | super().__init__(config) 24 | 25 | self.config = config 26 | if config.is_decoder: 27 | logger.warning( 28 | "If you want to use `RobertaForMaskedLM` make sure `config.is_decoder=False` for " 29 | "bi-directional self-attention." 
30 | ) 31 | 32 | self.encoder = RobertaModel(config, add_pooling_layer=False) 33 | 34 | if len(config.envs) == 0: 35 | self.envs = ['erm'] 36 | else: 37 | self.envs = config.envs 38 | 39 | self.lm_heads = {} 40 | for env_name in self.envs: 41 | self.lm_heads[env_name] = RobertaLMHead(config) 42 | 43 | if model is not None: 44 | self.encoder = copy.deepcopy(model.roberta) 45 | self.lm_heads = {} 46 | for env_name in self.envs: 47 | self.lm_heads[env_name] = copy.deepcopy(model.lm_head) 48 | # self.register_parameter(env_name + '-head', self.lm_heads[env_name]) 49 | 50 | for env_name, lm_head in self.lm_heads.items(): 51 | self.__setattr__(env_name + '_head', self.lm_heads[env_name]) 52 | # self.register_parameter(name=env_name, param=lm_head) 53 | 54 | self.encoder.to('cuda') 55 | for _, lm_head in self.lm_heads.items(): 56 | lm_head.to('cuda') 57 | 58 | self.n_environments = len(self.lm_heads) 59 | 60 | def print_lm_w(self): 61 | for env, lm_h in self.lm_heads.items(): 62 | print(lm_h.dense.weight) 63 | 64 | def init_head(self): 65 | for env_name in self.envs: 66 | self.lm_heads[env_name] = RobertaLMHead(self.config) 67 | self.lm_heads[env_name].to('cuda') 68 | 69 | def init_base(self): 70 | self.encoder.init_weights() 71 | self.init_head() 72 | 73 | def get_input_embeddings(self): 74 | return self.encoder.get_input_embeddings() 75 | 76 | def set_input_embeddings(self, value): 77 | self.encoder.set_input_embeddings(value) 78 | # self.embeddings.word_embeddings = value 79 | 80 | def get_output_embeddings(self): 81 | for env, lm_head in self.lm_heads.items(): 82 | return lm_head.decoder 83 | # return self.lm_heads.decoder 84 | 85 | def set_output_embeddings(self, new_embeddings): 86 | for env, lm_head in self.lm_heads.items(): 87 | lm_head.decoder = new_embeddings 88 | # self.lm_head.decoder = new_embeddings 89 | 90 | # @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 91 | # @add_code_sample_docstrings( 92 | # tokenizer_class=_TOKENIZER_FOR_DOC, 93 | # checkpoint="roberta-base", 94 | # output_type=MaskedLMOutput, 95 | # config_class=_CONFIG_FOR_DOC, 96 | # mask="", 97 | # ) 98 | def forward( 99 | self, 100 | input_ids=None, 101 | attention_mask=None, 102 | token_type_ids=None, 103 | position_ids=None, 104 | head_mask=None, 105 | inputs_embeds=None, 106 | encoder_hidden_states=None, 107 | encoder_attention_mask=None, 108 | labels=None, 109 | output_attentions=None, 110 | output_hidden_states=None, 111 | return_dict=None, 112 | **kwargs 113 | ): 114 | r""" 115 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 116 | Labels for computing the masked language modeling loss. 117 | Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) 118 | Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels 119 | in ``[0, ..., config.vocab_size]`` 120 | kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): 121 | Used to hide legacy arguments that have been deprecated. 122 | """ 123 | if "masked_lm_labels" in kwargs: 124 | warnings.warn( 125 | "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.", 126 | FutureWarning, 127 | ) 128 | labels = kwargs.pop("masked_lm_labels") 129 | assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}." 
130 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 131 | 132 | outputs = self.encoder( 133 | input_ids, 134 | attention_mask=attention_mask, 135 | token_type_ids=token_type_ids, 136 | position_ids=position_ids, 137 | head_mask=head_mask, 138 | inputs_embeds=inputs_embeds, 139 | encoder_hidden_states=encoder_hidden_states, 140 | encoder_attention_mask=encoder_attention_mask, 141 | output_attentions=output_attentions, 142 | output_hidden_states=output_hidden_states, 143 | return_dict=return_dict, 144 | ) 145 | sequence_output = outputs[0] 146 | if self.n_environments == 1: 147 | prediction_scores = list(self.lm_heads.values())[0](sequence_output) 148 | else: 149 | prediction_scores = 0. 150 | for env, lm_head in self.lm_heads.items(): 151 | prediction_scores += 1. / self.n_environments * lm_head(sequence_output) 152 | 153 | masked_lm_loss = None 154 | if labels is not None: 155 | loss_fct = CrossEntropyLoss() 156 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 157 | 158 | if not return_dict: 159 | output = (prediction_scores,) + outputs[2:] 160 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 161 | 162 | return MaskedLMOutput( 163 | loss=masked_lm_loss, 164 | logits=prediction_scores, 165 | hidden_states=outputs.hidden_states, 166 | attentions=outputs.attentions, 167 | ) 168 | -------------------------------------------------------------------------------- /invariant_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data.dataloader import DataLoader 3 | from torch.utils.data.distributed import DistributedSampler 4 | from torch.utils.data.sampler import RandomSampler 5 | 6 | import transformers 7 | from transformers.optimization import Adafactor, AdamW, get_scheduler 8 | from transformers.trainer_callback import TrainerState 9 | from transformers.utils import logging 10 | 11 | from tqdm import tqdm 12 | 13 | import math 14 | import os 15 | import numpy as np 16 | from typing import List, Union, Dict, Optional 17 | 18 | logger = logging.get_logger(__name__) 19 | 20 | 21 | class InvariantTrainer(transformers.Trainer): 22 | 23 | def create_optimizer_and_scheduler(self, model, num_training_steps: int): 24 | """ 25 | Setup the optimizer and the learning rate scheduler. 26 | 27 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the 28 | Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass. 
29 | """ 30 | optimizer, lr_scheduler = None, None 31 | # if self.optimizer is None: 32 | no_decay = ["bias", "LayerNorm.weight"] 33 | optimizer_grouped_parameters = [ 34 | { 35 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 36 | "weight_decay": self.args.weight_decay, 37 | }, 38 | { 39 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 40 | "weight_decay": 0.0, 41 | }, 42 | ] 43 | optimizer_cls = Adafactor if self.args.adafactor else AdamW 44 | if self.args.adafactor: 45 | optimizer_cls = Adafactor 46 | optimizer_kwargs = {"scale_parameter": False, "relative_step": False} 47 | else: 48 | optimizer_cls = AdamW 49 | optimizer_kwargs = { 50 | "betas": (self.args.adam_beta1, self.args.adam_beta2), 51 | "eps": self.args.adam_epsilon, 52 | } 53 | optimizer_kwargs["lr"] = self.args.learning_rate 54 | if self.sharded_dpp: 55 | optimizer = OSS( 56 | params=optimizer_grouped_parameters, 57 | optim=optimizer_cls, 58 | **optimizer_kwargs, 59 | ) 60 | else: 61 | optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) 62 | 63 | lr_scheduler = get_scheduler( 64 | self.args.lr_scheduler_type, 65 | optimizer, 66 | num_warmup_steps=self.args.warmup_steps, 67 | num_training_steps=num_training_steps, 68 | ) 69 | 70 | return optimizer, lr_scheduler 71 | 72 | def remove_dataparallel_wrapper(self): 73 | if hasattr(self.model, 'module'): 74 | self.model = self.model.module 75 | 76 | def invariant_train( 77 | self, 78 | training_set, 79 | nb_steps: Optional[int] = None, 80 | nb_steps_heads_saving: Optional[int] = 0, 81 | resume_from_checkpoint: Optional[str] = None, 82 | num_train_epochs: Optional[int] = 1, 83 | nb_steps_model_saving: Optional[int] = 0, 84 | **kwargs, 85 | ): 86 | """ 87 | Main training entry point. 88 | 89 | Args: 90 | resume_from_checkpoint (:obj:`str`, `optional`): 91 | Local path to a saved checkpoint as saved by a previous instance of :class:`~transformers.Trainer`. If 92 | present, training will resume from the model/optimizer/scheduler states loaded here. 93 | trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`): 94 | The trial run or the hyperparameter dictionary for hyperparameter search. 95 | kwargs: 96 | Additional keyword arguments used to hide deprecated arguments 97 | """ 98 | if "model_path" in kwargs: 99 | resume_from_checkpoint = kwargs.pop("model_path") 100 | warnings.warn( 101 | "`model_path` is deprecated and will be removed in a future version. 
Use `resume_from_checkpoint` " 102 | "instead.", 103 | FutureWarning, 104 | ) 105 | 106 | if nb_steps is None and num_train_epochs is None: 107 | raise ValueError("Both nb_steps and num_train_epochs can't be None at the same time") 108 | 109 | if len(kwargs) > 0: 110 | raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.") 111 | 112 | min_train_set_size = min([len(data["train"]) for _, data in training_set.items()]) 113 | 114 | if nb_steps is not None: 115 | max_steps = nb_steps 116 | num_update_steps_per_epoch = math.floor( 117 | min_train_set_size / (self.args.gradient_accumulation_steps * self.args.train_batch_size)) 118 | num_train_epochs = max(1, math.floor(max_steps / num_update_steps_per_epoch)) 119 | else: 120 | num_update_steps_per_epoch = math.floor( 121 | min_train_set_size / (self.args.gradient_accumulation_steps * self.args.train_batch_size)) 122 | max_steps = num_update_steps_per_epoch * num_train_epochs 123 | 124 | dataloaders, optimizers, lr_schedulers = {}, {}, {} 125 | for env_name, data_features in training_set.items(): 126 | dataloaders[env_name] = self.get_single_train_dataloader(env_name, data_features["train"]) 127 | optimizer, lr_scheduler = self.create_optimizer_and_scheduler(self.model.lm_heads[env_name], 128 | num_training_steps=max_steps) 129 | optimizers[env_name] = optimizer 130 | lr_schedulers[env_name] = lr_scheduler 131 | 132 | optimizer, lr_scheduler = self.create_optimizer_and_scheduler(self.model.encoder, num_training_steps=max_steps) 133 | 134 | self.state = TrainerState() 135 | 136 | if self.args.n_gpu > 0: 137 | self.model.to('cuda') 138 | 139 | if self.args.n_gpu > 1: 140 | self.model = torch.nn.DataParallel(self.model) 141 | 142 | total_train_batch_size = self.args.train_batch_size * self.args.gradient_accumulation_steps 143 | num_examples = total_train_batch_size * max_steps 144 | 145 | logger.info("***** Running training *****") 146 | logger.info(f" Num examples = {num_examples}") 147 | logger.info(f" Num Epochs = {num_train_epochs}") 148 | logger.info(f" num_update_steps_per_epoch = {num_update_steps_per_epoch}") 149 | logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}") 150 | logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_train_batch_size}") 151 | logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}") 152 | logger.info(f" Total optimization steps = {max_steps}") 153 | 154 | saving_heads = bool(nb_steps_heads_saving > 0) 155 | saving_intermediary_models = bool(nb_steps_model_saving > 0) 156 | total_trained_steps = 0 157 | 158 | for epoch in range(num_train_epochs): 159 | logger.info(f" Epoch: {epoch}") 160 | 161 | # make all dataloader iterateable 162 | iter_loaders = {} 163 | for env_name in training_set.keys(): 164 | train_loader = dataloaders[env_name] 165 | iter_loaders[env_name] = iter(train_loader) 166 | 167 | for steps_trained_in_current_epoch in tqdm(range(num_update_steps_per_epoch)): 168 | if total_trained_steps >= max_steps: 169 | break 170 | 171 | for env_name in training_set.keys(): 172 | logger.info(f" Update on environement {env_name}") 173 | # get a batch 174 | optimizer.zero_grad() 175 | optimizers[env_name].zero_grad() 176 | 177 | inputs = next(iter_loaders[env_name]) 178 | 179 | # make an update 180 | loss = self.training_step(self.model, inputs) 181 | 182 | if self.args.max_grad_norm is not None and self.args.max_grad_norm > 0: 183 | if self.use_amp: 184 | # AMP: gradients need unscaling 185 | self.scaler.unscale_(optimizer) 186 | self.scaler.unscale_(optimizers[env_name]) 187 | 188 | if hasattr(optimizer, "clip_grad_norm"): 189 | # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping 190 | optimizer.clip_grad_norm(self.args.max_grad_norm) 191 | optimizers[env_name].clip_grad_norm(self.args.max_grad_norm) 192 | else: 193 | # Revert to normal clipping otherwise, handling Apex or full precision 194 | torch.nn.utils.clip_grad_norm_( 195 | self.model.parameters(), 196 | self.args.max_grad_norm, 197 | ) 198 | 199 | if self.use_amp: 200 | self.scaler.step(optimizer) 201 | self.scaler.step(optimizers[env_name]) 202 | self.scaler.update() 203 | else: 204 | optimizer.step() 205 | optimizers[env_name].step() 206 | 207 | lr_scheduler.step() 208 | lr_schedulers[env_name].step() 209 | 210 | total_trained_steps += 1 211 | if saving_heads: 212 | if total_trained_steps % nb_steps_heads_saving == 0: 213 | self.save_heads(total_trained_steps) 214 | if saving_intermediary_models: 215 | if total_trained_steps % nb_steps_model_saving == 0: 216 | self.save_intermediary_model(total_trained_steps) 217 | 218 | def ensemble_train( 219 | self, 220 | training_set, 221 | nb_steps: Optional[int] = None, 222 | nb_steps_heads_saving: Optional[int] = 0, 223 | resume_from_checkpoint: Optional[str] = None, 224 | num_train_epochs: Optional[int] = 1, 225 | nb_steps_model_saving: Optional[int] = 0, 226 | **kwargs, 227 | ): 228 | """ 229 | Training the heads as en ensemble instead of following the IRM-games dynamic 230 | 231 | Args: 232 | resume_from_checkpoint (:obj:`str`, `optional`): 233 | Local path to a saved checkpoint as saved by a previous instance of :class:`~transformers.Trainer`. If 234 | present, training will resume from the model/optimizer/scheduler states loaded here. 235 | trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`): 236 | The trial run or the hyperparameter dictionary for hyperparameter search. 237 | kwargs: 238 | Additional keyword arguments used to hide deprecated arguments 239 | """ 240 | if "model_path" in kwargs: 241 | resume_from_checkpoint = kwargs.pop("model_path") 242 | warnings.warn( 243 | "`model_path` is deprecated and will be removed in a future version. 
Use `resume_from_checkpoint` " 244 | "instead.", 245 | FutureWarning, 246 | ) 247 | 248 | if nb_steps is None and num_train_epochs is None: 249 | raise ValueError("Both nb_steps and num_train_epochs can't be None at the same time") 250 | 251 | if len(kwargs) > 0: 252 | raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.") 253 | 254 | min_train_set_size = min([len(data["train"]) for _, data in training_set.items()]) 255 | 256 | if nb_steps is not None: 257 | max_steps = nb_steps 258 | num_update_steps_per_epoch = math.floor( 259 | min_train_set_size / (self.args.gradient_accumulation_steps * self.args.train_batch_size)) 260 | num_train_epochs = max(1, math.floor(max_steps / num_update_steps_per_epoch)) 261 | else: 262 | num_update_steps_per_epoch = math.floor( 263 | min_train_set_size / (self.args.gradient_accumulation_steps * self.args.train_batch_size)) 264 | max_steps = num_update_steps_per_epoch * num_train_epochs 265 | 266 | dataloaders, optimizers, lr_schedulers = {}, {}, {} 267 | for env_name, data_features in training_set.items(): 268 | dataloaders[env_name] = self.get_single_train_dataloader(env_name, data_features["train"]) 269 | optimizer, lr_scheduler = self.create_optimizer_and_scheduler(self.model.lm_heads[env_name], 270 | num_training_steps=max_steps) 271 | optimizers[env_name] = optimizer 272 | lr_schedulers[env_name] = lr_scheduler 273 | 274 | optimizer, lr_scheduler = self.create_optimizer_and_scheduler(self.model.encoder, num_training_steps=max_steps) 275 | 276 | self.state = TrainerState() 277 | 278 | if self.args.n_gpu > 0: 279 | self.model.to('cuda') 280 | 281 | if self.args.n_gpu > 1: 282 | self.model = torch.nn.DataParallel(self.model) 283 | 284 | total_train_batch_size = self.args.train_batch_size * self.args.gradient_accumulation_steps 285 | num_examples = total_train_batch_size * max_steps 286 | 287 | logger.info("***** Running training *****") 288 | logger.info(f" Num examples = {num_examples}") 289 | logger.info(f" Num Epochs = {num_train_epochs}") 290 | logger.info(f" num_update_steps_per_epoch = {num_update_steps_per_epoch}") 291 | logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}") 292 | logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_train_batch_size}") 293 | logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}") 294 | logger.info(f" Total optimization steps = {max_steps}") 295 | 296 | saving_heads = bool(nb_steps_heads_saving > 0) 297 | saving_intermediary_models = bool(nb_steps_model_saving > 0) 298 | total_trained_steps = 0 299 | 300 | print("Num train epoch: ", num_train_epochs) 301 | print("Batch size: ", total_train_batch_size) 302 | print("Train data size: ", min_train_set_size) 303 | print("num_update_steps_per_epoch: ", num_update_steps_per_epoch) 304 | 305 | for epoch in range(num_train_epochs): 306 | logger.info(f" Epoch: {epoch}") 307 | print("epoch: ", epoch) 308 | # make all dataloader iterateable 309 | iter_loaders = {} 310 | for env_name in training_set.keys(): 311 | train_loader = dataloaders[env_name] 312 | iter_loaders[env_name] = iter(train_loader) 313 | 314 | for steps_trained_in_current_epoch in tqdm(range(num_update_steps_per_epoch)): 315 | if total_trained_steps >= max_steps: 316 | break 317 | 318 | for env_name in training_set.keys(): 319 | logger.info(f" Update on environement {env_name}") 320 | # get a batch 321 | optimizer.zero_grad() 322 | for e_n in training_set.keys(): 323 | optimizers[e_n].zero_grad() 324 | 325 | batch = next(iter_loaders[env_name]) 326 | # uncomment it, for CPU only run 327 | if self.args.n_gpu > 0: 328 | batch = batch.to('cuda') 329 | 330 | # loss.backward() is done inside training step 331 | loss = self.training_step(self.model, batch) 332 | 333 | if self.args.max_grad_norm is not None and self.args.max_grad_norm > 0: 334 | if self.use_amp: 335 | # AMP: gradients need unscaling 336 | self.scaler.unscale_(optimizer) 337 | for env_name in training_set.keys(): 338 | self.scaler.unscale_(optimizers[env_name]) 339 | 340 | if hasattr(optimizer, "clip_grad_norm"): 341 | # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping 342 | optimizer.clip_grad_norm(self.args.max_grad_norm) 343 | for e_n in training_set.keys(): 344 | optimizers[e_n].clip_grad_norm(self.args.max_grad_norm) 345 | else: 346 | # Revert to normal clipping otherwise, handling Apex or full precision 347 | torch.nn.utils.clip_grad_norm_( 348 | self.model.parameters(), 349 | self.args.max_grad_norm, 350 | ) 351 | 352 | if self.use_amp: 353 | self.scaler.step(optimizer) 354 | for e_n in training_set.keys(): 355 | self.scaler.step(optimizers[e_n]) 356 | self.scaler.update() 357 | else: 358 | optimizer.step() 359 | for e_n in training_set.keys(): 360 | optimizers[e_n].step() 361 | 362 | lr_scheduler.step() 363 | for e_n in training_set.keys(): 364 | lr_schedulers[e_n].step() 365 | 366 | total_trained_steps += 1 367 | if saving_heads: 368 | if total_trained_steps % nb_steps_heads_saving == 0: 369 | self.save_heads(total_trained_steps) 370 | if saving_intermediary_models: 371 | if total_trained_steps % nb_steps_model_saving == 0: 372 | self.save_intermediary_model(total_trained_steps) 373 | 374 | def save_intermediary_model(self, n_steps): 375 | fname = os.path.join(self.args.output_dir, f"model-{n_steps}") 376 | self.save_model(output_dir=fname) 377 | 378 | def save_heads(self, step_count): 379 | print("saving-heads") 380 | if not os.path.exists("lm_heads"): 381 | os.makedirs("lm_heads") 382 | 383 | for env, lm_head in self.model.lm_heads.items(): 384 | filepath = os.path.join("lm_heads", "{}-{}".format(env, step_count)) 385 | np.save(filepath, lm_head.dense.weight.data.cpu().numpy()) 386 | 387 | def 
get_single_train_dataloader(self, env_name, train_dataset): 388 | """ 389 | Create a single-task data loader that also yields task names 390 | """ 391 | if train_dataset is None: 392 | raise ValueError("Trainer: training requires a train_dataset.") 393 | # if is_tpu_available(): 394 | # train_sampler = get_tpu_sampler(train_dataset) 395 | # else: 396 | train_sampler = ( 397 | RandomSampler(train_dataset) 398 | if self.args.local_rank == -1 399 | else DistributedSampler(train_dataset) 400 | ) 401 | 402 | return DataLoader( 403 | train_dataset, 404 | batch_size=self.args.train_batch_size, 405 | sampler=train_sampler, 406 | collate_fn=self.data_collator 407 | ) 408 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp==3.7.4.post0 2 | async-timeout==3.0.1 3 | attrs==21.2.0 4 | certifi==2021.10.8 5 | chardet==4.0.0 6 | charset-normalizer==2.0.7 7 | click==8.0.3 8 | dataclasses==0.6 9 | datasets==1.13.2 10 | dill==0.3.4 11 | filelock==3.3.0 12 | fsspec==2021.10.0 13 | future==0.18.2 14 | huggingface-hub==0.0.19 15 | idna==3.3 16 | importlib-metadata==4.8.1 17 | joblib==1.1.0 18 | multidict==5.2.0 19 | multiprocess==0.70.12.2 20 | numpy==1.19.2 21 | packaging==21.0 22 | pandas==1.3.3 23 | Pillow==8.3.2 24 | pyarrow==5.0.0 25 | pyparsing==2.4.7 26 | python-dateutil==2.8.2 27 | pytz==2021.3 28 | PyYAML==6.0 29 | regex==2021.10.8 30 | requests==2.26.0 31 | sacremoses==0.0.46 32 | six==1.16.0 33 | tokenizers==0.10.3 34 | torch==1.7.0 35 | torchvision==0.8.1 36 | tqdm==4.62.3 37 | transformers==4.3.3 38 | typing-extensions==3.10.0.2 39 | urllib3==1.26.7 40 | xxhash==2.0.2 41 | yarl==1.7.0 42 | zipp==3.6.0 43 | -------------------------------------------------------------------------------- /run_invariant_mlm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Team All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) on a text file or a dataset. 18 | Here is the full list of checkpoints on the hub that can be fine-tuned by this script: 19 | https://huggingface.co/models?filter=masked-lm 20 | """ 21 | # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments. 
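# High-level flow of this script:
#   1. Register the invariant classes (InvariantRobertaConfig / InvariantRobertaForMaskedLM and
#      InvariantDistilBertConfig / InvariantDistilBertForMaskedLM) into CONFIG_MAPPING,
#      MODEL_FOR_MASKED_LM_MAPPING and TOKENIZER_MAPPING so the Auto* classes can resolve them.
#   2. Build one "text" dataset per *.txt file found in --train_file: each file is one training
#      environment, named after the file; --validation_file (if given) is loaded as a single dataset.
#   3. Load the pretrained checkpoint with AutoModelForMaskedLM and, unless its config already
#      defines `envs`, wrap it into an Invariant*ForMaskedLM with one language modeling head per environment.
#   4. Train with the InvariantTrainer (see invariant_trainer.py): round-robin IRM-games updates by default,
#      or joint updates of all heads when --ensembling is set.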
22 | 23 | import logging 24 | import math 25 | import os 26 | import sys 27 | from dataclasses import dataclass, field 28 | from typing import Optional 29 | 30 | from datasets import load_dataset 31 | 32 | from invariant_trainer import InvariantTrainer 33 | from invariant_roberta import InvariantRobertaForMaskedLM, InvariantRobertaConfig 34 | from invariant_distilbert import InvariantDistilBertForMaskedLM, InvariantDistilBertConfig 35 | 36 | import transformers 37 | from transformers import ( 38 | CONFIG_MAPPING, 39 | TOKENIZER_MAPPING, 40 | MODEL_FOR_MASKED_LM_MAPPING, 41 | AutoConfig, 42 | AutoModel, 43 | AutoModelForMaskedLM, 44 | AutoTokenizer, 45 | DataCollatorForLanguageModeling, 46 | HfArgumentParser, 47 | # Trainer, 48 | TrainingArguments, 49 | set_seed, 50 | DistilBertTokenizer, 51 | DistilBertTokenizerFast, 52 | RobertaTokenizer, 53 | RobertaTokenizerFast 54 | ) 55 | from transformers.trainer_utils import get_last_checkpoint, is_main_process 56 | 57 | 58 | logger = logging.getLogger(__name__) 59 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys()) 60 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 61 | 62 | CONFIG_MAPPING.update({'invariant-distilbert': InvariantDistilBertConfig}) 63 | CONFIG_MAPPING.update({'invariant-roberta': InvariantRobertaConfig}) 64 | 65 | MODEL_FOR_MASKED_LM_MAPPING.update({InvariantDistilBertConfig: InvariantDistilBertForMaskedLM}) 66 | MODEL_FOR_MASKED_LM_MAPPING.update({InvariantRobertaConfig: InvariantRobertaForMaskedLM}) 67 | 68 | TOKENIZER_MAPPING.update({InvariantDistilBertConfig: (DistilBertTokenizer, DistilBertTokenizerFast)}) 69 | TOKENIZER_MAPPING.update({InvariantRobertaConfig: (RobertaTokenizer, RobertaTokenizerFast)}) 70 | 71 | @dataclass 72 | class ModelArguments: 73 | """ 74 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. 75 | """ 76 | 77 | model_name_or_path: Optional[str] = field( 78 | default=None, 79 | metadata={ 80 | "help": "The model checkpoint for weights initialization." 81 | "Don't set if you want to train a model from scratch." 82 | }, 83 | ) 84 | model_type: Optional[str] = field( 85 | default=None, 86 | metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, 87 | ) 88 | config_name: Optional[str] = field( 89 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 90 | ) 91 | tokenizer_name: Optional[str] = field( 92 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 93 | ) 94 | cache_dir: Optional[str] = field( 95 | default=None, 96 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 97 | ) 98 | use_fast_tokenizer: bool = field( 99 | default=True, 100 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 101 | ) 102 | model_revision: str = field( 103 | default="main", 104 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 105 | ) 106 | use_auth_token: bool = field( 107 | default=False, 108 | metadata={ 109 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 110 | "with private models)." 
111 | }, 112 | ) 113 | init_head: Optional[bool] = field( 114 | default=False, 115 | metadata={"help": "Re-initialize the language modeling heads to random weights before training"} 116 | ) 117 | init_base: Optional[bool] = field( 118 | default=False, 119 | metadata={"help": "Re-initialize the base language model (and thus the language modeling heads) before training"} 120 | ) 121 | ensembling: Optional[bool] = field( 122 | default=False, 123 | metadata={ 124 | "help": "Whether to train the heads as an ensemble instead of following the IRM-games dynamics"} 125 | ) 126 | nb_steps_heads_saving: Optional[int] = field( 127 | default=0, 128 | metadata={"help": "Number of training steps between saving the head weights (if 0, the heads are not saved regularly)."}, 129 | ) 130 | nb_steps_model_saving: Optional[int] = field( 131 | default=0, 132 | metadata={ 133 | "help": "Number of training steps between saving the full model (if 0, the heads are not saved regularly)."}, 134 | ) 135 | do_lower_case: Optional[bool] = field( 136 | default=True, 137 | metadata={"help": "Lower-case during tokenization."}, 138 | ) 139 | 140 | 141 | @dataclass 142 | class DataTrainingArguments: 143 | """ 144 | Arguments pertaining to what data we are going to input our model for training and eval. 145 | """ 146 | 147 | dataset_name: Optional[str] = field( 148 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 149 | ) 150 | dataset_config_name: Optional[str] = field( 151 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 152 | ) 153 | train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) 154 | validation_file: Optional[str] = field( 155 | default=None, 156 | metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, 157 | ) 158 | overwrite_cache: bool = field( 159 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 160 | ) 161 | validation_split_percentage: Optional[int] = field( 162 | default=5, 163 | metadata={ 164 | "help": "The percentage of the train set used as validation set in case there's no validation split" 165 | }, 166 | ) 167 | max_seq_length: Optional[int] = field( 168 | default=None, 169 | metadata={ 170 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 171 | "than this will be truncated." 172 | }, 173 | ) 174 | preprocessing_num_workers: Optional[int] = field( 175 | default=None, 176 | metadata={"help": "The number of processes to use for the preprocessing."}, 177 | ) 178 | mlm_probability: float = field( 179 | default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"} 180 | ) 181 | line_by_line: bool = field( 182 | default=False, 183 | metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."}, 184 | ) 185 | pad_to_max_length: bool = field( 186 | default=False, 187 | metadata={ 188 | "help": "Whether to pad all samples to `max_seq_length`. " 189 | "If False, will pad the samples dynamically when batching to the maximum length in the batch." 
190 | }, 191 | ) 192 | nb_steps: Optional[int] = field( 193 | default=0, 194 | metadata={"help": "Number of training steps."}, 195 | ) 196 | 197 | def __post_init__(self): 198 | if self.dataset_name is None and self.train_file is None and self.validation_file is None: 199 | raise ValueError("Need either a dataset name or a training/validation file.") 200 | # else: 201 | # continue 202 | 203 | # if self.train_file is not None: 204 | # extension = self.train_file.split(".")[-1] 205 | # assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file." 206 | # if self.validation_file is not None: 207 | # extension = self.validation_file.split(".")[-1] 208 | # assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file." 209 | 210 | 211 | def main(): 212 | # See all possible arguments in src/transformers/training_args.py 213 | # or by passing the --help flag to this script. 214 | # We now keep distinct sets of args, for a cleaner separation of concerns. 215 | 216 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 217 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 218 | # If we pass only one argument to the script and it's the path to a json file, 219 | # let's parse it to get our arguments. 220 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 221 | else: 222 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 223 | 224 | nb_steps = data_args.nb_steps 225 | 226 | # Detecting last checkpoint. 227 | last_checkpoint = None 228 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 229 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 230 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 231 | raise ValueError( 232 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 233 | "Use --overwrite_output_dir to overcome." 234 | ) 235 | elif last_checkpoint is not None: 236 | logger.info( 237 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 238 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 239 | ) 240 | 241 | # Setup logging 242 | logging.basicConfig( 243 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 244 | datefmt="%m/%d/%Y %H:%M:%S", 245 | handlers=[logging.StreamHandler(sys.stdout)], 246 | ) 247 | logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) 248 | 249 | # Log on each process the small summary: 250 | logger.info( 251 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 252 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 253 | ) 254 | 255 | # Set the verbosity to info of the Transformers logger (on main process only): 256 | if is_main_process(training_args.local_rank): 257 | transformers.utils.logging.set_verbosity_info() 258 | transformers.utils.logging.enable_default_handler() 259 | transformers.utils.logging.enable_explicit_format() 260 | logger.info("Training/evaluation parameters %s", training_args) 261 | 262 | # Set seed before initializing model. 
290 |     #
291 |     # Distributed training:
292 |     # The .from_pretrained methods guarantee that only one local process can concurrently
293 |     # download model & vocab.
294 |     config_kwargs = {
295 |         "cache_dir": model_args.cache_dir,
296 |         "revision": model_args.model_revision,
297 |         "use_auth_token": True if model_args.use_auth_token else None,
298 |     }
299 |     if model_args.config_name:
300 |         config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
301 |     elif model_args.model_name_or_path:
302 |         config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
303 |     else:
304 |         config = CONFIG_MAPPING[model_args.model_type]()
305 |         logger.warning("You are instantiating a new config instance from scratch.")
306 |
307 |     tokenizer_kwargs = {
308 |         "cache_dir": model_args.cache_dir,
309 |         "use_fast": model_args.use_fast_tokenizer,
310 |         "revision": model_args.model_revision,
311 |         "use_auth_token": True if model_args.use_auth_token else None,
312 |     }
313 |     if model_args.tokenizer_name:
314 |         tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
315 |     elif model_args.model_name_or_path:
316 |         tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
317 |     else:
318 |         raise ValueError(
319 |             "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
320 |             "You can do it from another script, save it, and load it from here, using --tokenizer_name."
321 |         )
322 |
323 |     if model_args.model_name_or_path:
324 |         model = AutoModelForMaskedLM.from_pretrained(
325 |             model_args.model_name_or_path,
326 |             from_tf=bool(".ckpt" in model_args.model_name_or_path),
327 |             config=config,
328 |             cache_dir=model_args.cache_dir,
329 |             revision=model_args.model_revision,
330 |             use_auth_token=True if model_args.use_auth_token else None,
331 |         )
332 |     else:
333 |         logger.info("Training new model from scratch")
334 |         model = AutoModelForMaskedLM.from_config(config)
335 |
336 |     envs = [k for k in irm_datasets.keys() if 'validation' not in k]
337 |
338 |     def is_jsonable(x):
339 |         import json
340 |         try:
341 |             json.dumps(x)
342 |             return True
343 |         except (TypeError, ValueError, OverflowError):
344 |             return False
345 |
346 |     if 'envs' not in config.to_dict():  # if we did not already load a pretrained invariant (IRM) model
347 |         if 'distil' in model_args.model_name_or_path:
348 |             inv_config = InvariantDistilBertConfig(envs=envs, **config.to_dict())
349 |             irm_model = InvariantDistilBertForMaskedLM(inv_config, model)
350 |         else:
351 |             inv_config = InvariantRobertaConfig(envs=envs, **config.to_dict())
352 |             irm_model = InvariantRobertaForMaskedLM(inv_config, model)
353 |     else:
354 |         irm_model = model
355 |
356 |     irm_model.resize_token_embeddings(len(tokenizer))
357 |
358 |     if model_args.init_head:
359 |         irm_model.init_head()
360 |     if model_args.init_base:
361 |         irm_model.init_base()
362 |
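The preprocessing below asks the tokenizer for the special-tokens mask. For intuition, a minimal sketch of what that call returns, using distilbert-base-uncased purely as an example checkpoint:

# Hedged sketch of the tokenizer output used by the preprocessing below.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # example checkpoint
enc = tokenizer("Invariant risk minimization.", return_special_tokens_mask=True)
print(enc["input_ids"])            # token ids, including [CLS] ... [SEP]
print(enc["special_tokens_mask"])  # 1 for [CLS]/[SEP], 0 for regular tokens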
363 |     # Preprocessing the datasets.
364 |     # First we tokenize all the texts.
365 |     irm_tokenized_datasets = {}
366 |     for env_name, datasets in irm_datasets.items():
367 |         if training_args.do_train and 'validation' not in env_name:
368 |             column_names = datasets["train"].column_names
369 |         elif training_args.do_eval and 'validation' in env_name:
370 |             column_names = datasets["validation"].column_names
371 |         text_column_name = "text" if "text" in column_names else column_names[0]
372 |
373 |         if data_args.max_seq_length is None:
374 |             max_seq_length = tokenizer.model_max_length
375 |             if max_seq_length > 1024:
376 |                 logger.warning(
377 |                     f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
378 |                     "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx."
379 |                 )
380 |                 max_seq_length = 1024
381 |         else:
382 |             if data_args.max_seq_length > tokenizer.model_max_length:
383 |                 logger.warning(
384 |                     f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
385 |                     f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
386 |                 )
387 |             max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
388 |
389 |         if data_args.line_by_line:
390 |             # When using line_by_line, we just tokenize each nonempty line.
391 |             padding = "max_length" if data_args.pad_to_max_length else False
392 |
393 |             def tokenize_function(examples):
394 |                 # Remove empty lines
395 |                 examples["text"] = [line for line in examples["text"] if len(line) > 0 and not line.isspace()]
396 |                 return tokenizer(
397 |                     examples["text"],
398 |                     padding=padding,
399 |                     truncation=True,
400 |                     max_length=max_seq_length,
401 |                     # We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it
402 |                     # receives the `special_tokens_mask`.
403 |                     return_special_tokens_mask=True,
404 |                 )
405 |
406 |             tokenized_datasets = datasets.map(
407 |                 tokenize_function,
408 |                 batched=True,
409 |                 num_proc=data_args.preprocessing_num_workers,
410 |                 remove_columns=[text_column_name],
411 |                 load_from_cache_file=not data_args.overwrite_cache,
412 |             )
413 |             irm_tokenized_datasets[env_name] = tokenized_datasets
414 |         else:
415 |             # Otherwise, we tokenize every text, then concatenate them together before splitting them into smaller parts.
416 |             # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
417 |             # efficient when it receives the `special_tokens_mask`.
418 |             def tokenize_function(examples):
419 |                 return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
420 |
421 |             tokenized_datasets = datasets.map(
422 |                 tokenize_function,
423 |                 batched=True,
424 |                 num_proc=data_args.preprocessing_num_workers,
425 |                 remove_columns=column_names,
426 |                 load_from_cache_file=not data_args.overwrite_cache,
427 |             )
428 |
429 |             # Main data processing function that will concatenate all texts from our dataset and generate chunks of
430 |             # max_seq_length.
431 |             def group_texts(examples):
432 |                 # Concatenate all texts.
433 |                 concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
434 |                 total_length = len(concatenated_examples[list(examples.keys())[0]])
435 |                 # We drop the small remainder; we could add padding instead if the model supported it. You can
436 |                 # customize this part to your needs.
437 |                 total_length = (total_length // max_seq_length) * max_seq_length
438 |                 # Split by chunks of max_len.
439 |                 result = {
440 |                     k: [t[i: i + max_seq_length] for i in range(0, total_length, max_seq_length)]
441 |                     for k, t in concatenated_examples.items()
442 |                 }
443 |                 return result
444 |
445 |             # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
446 |             # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
447 |             # might be slower to preprocess.
448 |             #
449 |             # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
450 |             # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
451 |             tokenized_datasets = tokenized_datasets.map(
452 |                 group_texts,
453 |                 batched=True,
454 |                 num_proc=data_args.preprocessing_num_workers,
455 |                 load_from_cache_file=not data_args.overwrite_cache,
456 |             )
457 |             irm_tokenized_datasets[env_name] = tokenized_datasets
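For intuition, a self-contained toy run of the same chunking logic as `group_texts`, with made-up token ids and max_seq_length = 4:

# Toy illustration of the chunking performed by group_texts above; only list lengths matter here.
max_seq_length = 4

def group_texts(examples):
    # Concatenate all texts, then cut into chunks of max_seq_length and drop the remainder.
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = (len(concatenated[list(examples.keys())[0]]) // max_seq_length) * max_seq_length
    return {k: [t[i: i + max_seq_length] for i in range(0, total_length, max_seq_length)]
            for k, t in concatenated.items()}

batch = {"input_ids": [[101, 7, 8, 102], [101, 9, 10, 11, 12, 102]]}  # two tokenized texts, 10 ids total
print(group_texts(batch))
# {'input_ids': [[101, 7, 8, 102], [101, 9, 10, 11]]} -- the 2-id remainder is dropped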
458 |
459 |     # Data collator
460 |     # This one will take care of randomly masking the tokens.
461 |     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
462 |
463 |     train_tokenized_datasets = {k: v for k, v in irm_tokenized_datasets.items() if 'validation-file' not in k}
464 |     eval_tokenized_datasets = irm_tokenized_datasets['validation-file']['validation'] if 'validation-file' in irm_tokenized_datasets else None
465 |
466 |     # Initialize our Trainer
467 |     trainer = InvariantTrainer(
468 |         model=irm_model,
469 |         args=training_args,
470 |         # train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
471 |         eval_dataset=eval_tokenized_datasets if training_args.do_eval else None,
472 |         tokenizer=tokenizer,
473 |         data_collator=data_collator,
474 |     )
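A minimal sketch of what this collator produces on a single example (the checkpoint name is illustrative): roughly `mlm_probability` of the tokens are replaced, and `labels` keeps the original ids only at those positions, with -100 everywhere else so the loss ignores them:

# Hedged sketch of the masking done by DataCollatorForLanguageModeling.
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tok = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # example checkpoint
collator = DataCollatorForLanguageModeling(tokenizer=tok, mlm_probability=0.15)
features = [tok("invariant language modeling", return_special_tokens_mask=True)]
batch = collator(features)
print(batch["input_ids"])  # some ids randomly replaced (mostly by tok.mask_token_id)
print(batch["labels"])     # original ids at masked positions, -100 everywhere else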
475 |
476 |     # Training
477 |     if training_args.do_train:
478 |         if last_checkpoint is not None:
479 |             checkpoint = last_checkpoint
480 |         elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
481 |             checkpoint = model_args.model_name_or_path
482 |         else:
483 |             checkpoint = None
484 |
485 |         if model_args.ensembling:
486 |             logger.info("TRAINING WITH ENSEMBLE -- NOT FOLLOWING IRM-GAMES DYNAMIC")
487 |             train_result = trainer.ensemble_train(training_set=train_tokenized_datasets,
488 |                                                   nb_steps=nb_steps,
489 |                                                   nb_steps_heads_saving=model_args.nb_steps_heads_saving,
490 |                                                   nb_steps_model_saving=model_args.nb_steps_model_saving,
491 |                                                   resume_from_checkpoint=checkpoint)
492 |         else:
493 |             train_result = trainer.invariant_train(training_set=train_tokenized_datasets,
494 |                                                    nb_steps=nb_steps,
495 |                                                    nb_steps_heads_saving=model_args.nb_steps_heads_saving,
496 |                                                    nb_steps_model_saving=model_args.nb_steps_model_saving,
497 |                                                    resume_from_checkpoint=checkpoint)
498 |         trainer.save_model()  # Saves the tokenizer too for easy upload
499 |
500 |         output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
501 |         if trainer.is_world_process_zero():
502 |             with open(output_train_file, "w") as writer:
503 |                 logger.info("***** Train results *****")
504 |                 # for key, value in sorted(train_result.metrics.items()):
505 |                 #     logger.info(f"  {key} = {value}")
506 |                 #     writer.write(f"{key} = {value}\n")
507 |
508 |             # Need to save the state, since Trainer.save_model saves only the tokenizer with the model
509 |             trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
510 |
511 |         # metrics = train_result.metrics
512 |
513 |         # trainer.log_metrics("train", metrics)
514 |         # trainer.save_metrics("train", metrics)
515 |         # trainer.save_state()
516 |
517 |     # Evaluation
518 |     results = {}
519 |     if training_args.do_eval:
520 |         logger.info("*** Evaluate ***")
521 |
522 |         eval_output = trainer.evaluate()
523 |
524 |         perplexity = math.exp(eval_output["eval_loss"])
525 |         results["perplexity"] = perplexity
526 |         output_eval_file = os.path.join(training_args.output_dir, "eval_results_mlm.txt")
527 |         if trainer.is_world_process_zero():
528 |             with open(output_eval_file, "w") as writer:
529 |                 logger.info("***** Eval results *****")
530 |                 for key, value in sorted(results.items()):
531 |                     logger.info(f"  {key} = {value}")
532 |                     # writer.write(f"{key} = {value}\n")
533 |
534 |     # trainer.log_metrics("eval", results)
535 |     # trainer.save_metrics("eval", results)
536 |
537 |     return results
538 |
539 |
540 | def _mp_fn(index):
541 |     # For xla_spawn (TPUs)
542 |     main()
543 |
544 |
545 | if __name__ == "__main__":
546 |     main()
547 |
--------------------------------------------------------------------------------
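As a worked example of the perplexity computation in the evaluation block above: the reported value is simply the exponential of the average masked-LM cross-entropy, so an eval_loss of 2.0 corresponds to a perplexity of about 7.39 (the loss value below is a placeholder for illustration):

# Perplexity from the evaluation loss, as computed in the script above.
import math

eval_loss = 2.0             # placeholder value
print(math.exp(eval_loss))  # ~7.389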