├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── mutransformers ├── __init__.py └── models │ ├── __init__.py │ ├── bert │ ├── __init__.py │ ├── _original_configuration_bert.py │ ├── _original_modeling_bert.py │ ├── configuration_bert.py │ └── modeling_bert.py │ ├── gpt2 │ ├── __init__.py │ ├── _original_configuration_gpt2.py │ ├── _original_modeling_gpt2.py │ ├── configuration_gpt2.py │ └── modeling_gpt2.py │ └── roberta │ ├── __init__.py │ ├── _original_configuration_roberta.py │ ├── _original_modeling_roberta.py │ ├── configuration_roberta.py │ └── modeling_roberta.py ├── requirements.txt ├── setup.py └── tests └── coordcheck ├── CoordCheck.ipynb ├── bert_mup_dhead_coord_check.png ├── bert_mup_nhead_coord_check.png ├── bert_sp_dhead_coord_check.png ├── bert_sp_nhead_coord_check.png ├── coordcheck.py ├── gpt2_mup_dhead_coord_check.png ├── gpt2_mup_nhead_coord_check.png ├── gpt2_sp_dhead_coord_check.png ├── gpt2_sp_nhead_coord_check.png ├── roberta_mup_dhead_coord_check.png ├── roberta_mup_nhead_coord_check.png ├── roberta_sp_dhead_coord_check.png └── roberta_sp_nhead_coord_check.png /.gitignore: -------------------------------------------------------------------------------- 1 | **/.ipynb_checkpoints 2 | 3 | # Compiled python modules. 4 | *.pyc 5 | 6 | # Byte-compiled 7 | _pycache__/ 8 | .cache/ 9 | 10 | # Python egg metadata, regenerated from source files by setuptools. 11 | *.egg-info 12 | .eggs/ 13 | 14 | # PyPI distribution artifacts. 15 | build/ 16 | dist/ 17 | 18 | # Environments 19 | .env 20 | .venv 21 | env/ 22 | venv/ 23 | ENV/ 24 | env.bak/ 25 | venv.bak/ 26 | 27 | # PyCharm/vscode 28 | .idea 29 | .vscode 30 | 31 | # Vim 32 | .*.swp 33 | 34 | # Other 35 | *.DS_Store -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # muTransformers
2 |
3 | This repo implements [muP](https://github.com/microsoft/mup) ([paper](https://arxiv.org/abs/2203.03466)) for selected PyTorch models in [Huggingface Transformers](https://github.com/huggingface/transformers).
4 | The primary purpose of this repo is to serve as a clean demonstration of how to inject muP into different variants of transformers.
5 | As a secondary purpose, one can also use the models here as provided.
6 |
7 | ## Installation
8 |
9 | Go to this project directory and run
10 | ```
11 | pip install -r requirements.txt
12 | pip install -e .
13 | ```
14 |
15 | ## Injecting muP into Existing Transformers
16 |
17 | Taking BERT as an example, the two files [`modeling_bert.py`](mutransformers/models/bert/modeling_bert.py) and [`configuration_bert.py`](mutransformers/models/bert/configuration_bert.py) in `mutransformers/models/bert/` were copied from [Huggingface Transformers](https://github.com/huggingface/transformers) and modified slightly to implement muP.
18 | All of our modifications in these files can be found by searching for `### muP`.
19 |
20 | These files are copied from [Huggingface Transformers v4.16.2](https://github.com/huggingface/transformers/tree/v4.16.2). We provide the original files as `_original_*.py` for easy comparison, for example, [`_original_modeling_bert.py`](mutransformers/models/bert/_original_modeling_bert.py).
21 |
22 | ## Coord Check
23 |
24 | [Coordinate checking](https://github.com/microsoft/mup#coord-check) is a way of verifying that muP is implemented correctly, just as gradient checking is a way of verifying that autograd is implemented correctly.
25 | You can find the coord check results in [`tests/coordcheck/CoordCheck.ipynb`](tests/coordcheck/CoordCheck.ipynb).
26 | You can also rerun the notebook yourself after installation.
27 |
28 | For example, the coord check for BERT in standard parametrization (SP) shows that many activations blow up with width,
29 | ![](tests/coordcheck/bert_sp_dhead_coord_check.png)
30 | but the same check for BERT in muP shows activation scales that stay consistent across width.
31 | ![](tests/coordcheck/bert_mup_dhead_coord_check.png)
32 |
33 | ## Basic Usage of Models
34 | The models here can also be used for your own training purposes, though we have not verified that they replicate the original results of each of these transformer models.
35 | The models in this package can be used as follows, taking BERT as an example:
36 | ```python
37 | from mutransformers import BertConfig, BertForMaskedLM
38 | from mup import make_base_shapes, set_base_shapes, MuAdamW
39 | from functools import partial
40 | # define a base model
41 | base_config = BertConfig(
42 |     hidden_size=256,
43 |     intermediate_size=256,
44 |     num_attention_heads=16,
45 | )
46 | base_model = BertForMaskedLM(config=base_config)
47 | # define a delta model in which we vary all the "widths" we want to scale
48 | delta_config = BertConfig(
49 |     hidden_size=200,
50 |     intermediate_size=300,
51 |     num_attention_heads=5,
52 | )
53 | delta_model = BertForMaskedLM(config=delta_config)
54 | # define a base shapes object by comparing delta_model against base_model
55 | base_shapes = make_base_shapes(base_model, delta_model, savefile='bert256.bsh')
56 |
57 | # define the target model
58 | target_config = BertConfig(
59 |     hidden_size=1024,
60 |     intermediate_size=1024*4,
61 |     num_attention_heads=32,
62 | )
63 | target_model = BertForMaskedLM(config=target_config)
64 |
65 | # set base shapes
66 | set_base_shapes(target_model, base_shapes)
67 | # alternatively, load the base shapes from file
68 | # set_base_shapes(target_model, 'bert256.bsh')
69 |
70 | # re-initialize the weights now that base shapes are set
71 | target_model.apply(target_model._init_weights)
72 |
73 | # make sure to use mup optimizers for training
74 | optimizer = MuAdamW(target_model.parameters(), lr=1e-3)
75 |
76 | # train
77 | ...
78 | ```
79 |
80 | For more general information on how to use `mup`, see [the muP package documentation](https://github.com/microsoft/mup#basic-usage). A minimal sketch of a single training step appears at the end of this README.
81 |
82 | ## Contributing
83 |
84 | This project welcomes contributions and suggestions. Most contributions require you to agree to a
85 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
86 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
87 |
88 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide
89 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
90 | provided by the bot. You will only need to do this once across all repos using our CLA.
91 |
92 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
93 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
94 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
95 |
96 | ## Trademarks
97 |
98 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
99 | trademarks or logos is subject to and must follow
100 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
101 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
102 | Any use of third-party trademarks or logos is subject to those third parties' policies.
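As a supplement to the "Basic Usage of Models" example above, here is a minimal, hypothetical sketch of a single training step, assuming the `target_config`, `target_model`, and `optimizer` defined there. The random batch below is purely illustrative; a real masked-language-modeling setup would mask a fraction of the input tokens and set the labels of unmasked positions to `-100` so the loss ignores them.

```python
import torch

# assumes `target_config`, `target_model`, and `optimizer` from the usage example above
target_model.train()

# purely illustrative batch: 8 random sequences of length 128 drawn from the target vocabulary
input_ids = torch.randint(0, target_config.vocab_size, (8, 128))
labels = input_ids.clone()  # real MLM training would mask inputs and set unmasked labels to -100

# BertForMaskedLM returns an output object with a `.loss` when `labels` are provided
outputs = target_model(input_ids=input_ids, labels=labels)
loss = outputs.loss

loss.backward()
optimizer.step()
optimizer.zero_grad()
```

Because `MuAdamW` already applies the muP per-parameter learning-rate scaling internally, the training loop itself looks the same as one written for a standard Huggingface model.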
103 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 40 | 41 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # Support 2 | 3 | ## How to file issues and get help 4 | 5 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 6 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 7 | feature request as a new Issue. 
8 |
9 | For help and questions about using this project, please use GitHub Discussions in this repo.
10 |
11 | ## Microsoft Support Policy
12 |
13 | Support for this project is limited to the resources listed above.
14 |
--------------------------------------------------------------------------------
/mutransformers/__init__.py:
--------------------------------------------------------------------------------
1 | from mutransformers.models.bert.modeling_bert import *
2 | from mutransformers.models.bert.configuration_bert import *
3 | from mutransformers.models.roberta.modeling_roberta import *
4 | from mutransformers.models.roberta.configuration_roberta import *
5 | from mutransformers.models.gpt2.modeling_gpt2 import *
6 | from mutransformers.models.gpt2.configuration_gpt2 import *
--------------------------------------------------------------------------------
/mutransformers/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/mutransformers/480287ce7b18a07a3432e8f2fbc0f0e5b71e2599/mutransformers/models/__init__.py
--------------------------------------------------------------------------------
/mutransformers/models/bert/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/mutransformers/480287ce7b18a07a3432e8f2fbc0f0e5b71e2599/mutransformers/models/bert/__init__.py
--------------------------------------------------------------------------------
/mutransformers/models/bert/_original_configuration_bert.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ BERT model configuration""" 17 | from collections import OrderedDict 18 | from typing import Mapping 19 | 20 | from ...configuration_utils import PretrainedConfig 21 | from ...onnx import OnnxConfig 22 | from ...utils import logging 23 | 24 | 25 | logger = logging.get_logger(__name__) 26 | 27 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 28 | "bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/config.json", 29 | "bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/config.json", 30 | "bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/config.json", 31 | "bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/config.json", 32 | "bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/config.json", 33 | "bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/config.json", 34 | "bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/config.json", 35 | "bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/config.json", 36 | "bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/config.json", 37 | "bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/config.json", 38 | "bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/config.json", 39 | "bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/config.json", 40 | "bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/config.json", 41 | "bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/config.json", 42 | "bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/config.json", 43 | "cl-tohoku/bert-base-japanese": "https://huggingface.co/cl-tohoku/bert-base-japanese/resolve/main/config.json", 44 | "cl-tohoku/bert-base-japanese-whole-word-masking": "https://huggingface.co/cl-tohoku/bert-base-japanese-whole-word-masking/resolve/main/config.json", 45 | "cl-tohoku/bert-base-japanese-char": "https://huggingface.co/cl-tohoku/bert-base-japanese-char/resolve/main/config.json", 46 | "cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://huggingface.co/cl-tohoku/bert-base-japanese-char-whole-word-masking/resolve/main/config.json", 47 | "TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/config.json", 48 | "TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/config.json", 49 | "wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/config.json", 50 | # See all BERT models at https://huggingface.co/models?filter=bert 51 | } 52 | 53 | 54 | class BertConfig(PretrainedConfig): 55 | r""" 56 | This is the configuration class to store the configuration of a [`BertModel`] or a [`TFBertModel`]. It is used to 57 | instantiate a BERT model according to the specified arguments, defining the model architecture. 
Instantiating a 58 | configuration with the defaults will yield a similar configuration to that of the BERT 59 | [bert-base-uncased](https://huggingface.co/bert-base-uncased) architecture. 60 | 61 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 62 | documentation from [`PretrainedConfig`] for more information. 63 | 64 | 65 | Args: 66 | vocab_size (`int`, *optional*, defaults to 30522): 67 | Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the 68 | `inputs_ids` passed when calling [`BertModel`] or [`TFBertModel`]. 69 | hidden_size (`int`, *optional*, defaults to 768): 70 | Dimensionality of the encoder layers and the pooler layer. 71 | num_hidden_layers (`int`, *optional*, defaults to 12): 72 | Number of hidden layers in the Transformer encoder. 73 | num_attention_heads (`int`, *optional*, defaults to 12): 74 | Number of attention heads for each attention layer in the Transformer encoder. 75 | intermediate_size (`int`, *optional*, defaults to 3072): 76 | Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. 77 | hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): 78 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, 79 | `"relu"`, `"silu"` and `"gelu_new"` are supported. 80 | hidden_dropout_prob (`float`, *optional*, defaults to 0.1): 81 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 82 | attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): 83 | The dropout ratio for the attention probabilities. 84 | max_position_embeddings (`int`, *optional*, defaults to 512): 85 | The maximum sequence length that this model might ever be used with. Typically set this to something large 86 | just in case (e.g., 512 or 1024 or 2048). 87 | type_vocab_size (`int`, *optional*, defaults to 2): 88 | The vocabulary size of the `token_type_ids` passed when calling [`BertModel`] or [`TFBertModel`]. 89 | initializer_range (`float`, *optional*, defaults to 0.02): 90 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 91 | layer_norm_eps (`float`, *optional*, defaults to 1e-12): 92 | The epsilon used by the layer normalization layers. 93 | position_embedding_type (`str`, *optional*, defaults to `"absolute"`): 94 | Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For 95 | positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to 96 | [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). 97 | For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models 98 | with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658). 99 | use_cache (`bool`, *optional*, defaults to `True`): 100 | Whether or not the model should return the last key/values attentions (not used by all models). Only 101 | relevant if `config.is_decoder=True`. 102 | classifier_dropout (`float`, *optional*): 103 | The dropout ratio for the classification head. 
104 | 105 | Examples: 106 | 107 | ```python 108 | >>> from transformers import BertModel, BertConfig 109 | 110 | >>> # Initializing a BERT bert-base-uncased style configuration 111 | >>> configuration = BertConfig() 112 | 113 | >>> # Initializing a model from the bert-base-uncased style configuration 114 | >>> model = BertModel(configuration) 115 | 116 | >>> # Accessing the model configuration 117 | >>> configuration = model.config 118 | ```""" 119 | model_type = "bert" 120 | 121 | def __init__( 122 | self, 123 | vocab_size=30522, 124 | hidden_size=768, 125 | num_hidden_layers=12, 126 | num_attention_heads=12, 127 | intermediate_size=3072, 128 | hidden_act="gelu", 129 | hidden_dropout_prob=0.1, 130 | attention_probs_dropout_prob=0.1, 131 | max_position_embeddings=512, 132 | type_vocab_size=2, 133 | initializer_range=0.02, 134 | layer_norm_eps=1e-12, 135 | pad_token_id=0, 136 | position_embedding_type="absolute", 137 | use_cache=True, 138 | classifier_dropout=None, 139 | **kwargs 140 | ): 141 | super().__init__(pad_token_id=pad_token_id, **kwargs) 142 | 143 | self.vocab_size = vocab_size 144 | self.hidden_size = hidden_size 145 | self.num_hidden_layers = num_hidden_layers 146 | self.num_attention_heads = num_attention_heads 147 | self.hidden_act = hidden_act 148 | self.intermediate_size = intermediate_size 149 | self.hidden_dropout_prob = hidden_dropout_prob 150 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 151 | self.max_position_embeddings = max_position_embeddings 152 | self.type_vocab_size = type_vocab_size 153 | self.initializer_range = initializer_range 154 | self.layer_norm_eps = layer_norm_eps 155 | self.position_embedding_type = position_embedding_type 156 | self.use_cache = use_cache 157 | self.classifier_dropout = classifier_dropout 158 | 159 | 160 | class BertOnnxConfig(OnnxConfig): 161 | @property 162 | def inputs(self) -> Mapping[str, Mapping[int, str]]: 163 | return OrderedDict( 164 | [ 165 | ("input_ids", {0: "batch", 1: "sequence"}), 166 | ("attention_mask", {0: "batch", 1: "sequence"}), 167 | ("token_type_ids", {0: "batch", 1: "sequence"}), 168 | ] 169 | ) -------------------------------------------------------------------------------- /mutransformers/models/bert/configuration_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # Copyright 2022 Microsoft Corporation. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | """ BERT model configuration""" 18 | from collections import OrderedDict 19 | from typing import Mapping 20 | 21 | from transformers.configuration_utils import PretrainedConfig 22 | from transformers.onnx import OnnxConfig 23 | from transformers.utils import logging 24 | 25 | 26 | logger = logging.get_logger(__name__) 27 | 28 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 29 | "bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/config.json", 30 | "bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/config.json", 31 | "bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/config.json", 32 | "bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/config.json", 33 | "bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/config.json", 34 | "bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/config.json", 35 | "bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/config.json", 36 | "bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/config.json", 37 | "bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/config.json", 38 | "bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/config.json", 39 | "bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/config.json", 40 | "bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/config.json", 41 | "bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/config.json", 42 | "bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/config.json", 43 | "bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/config.json", 44 | "cl-tohoku/bert-base-japanese": "https://huggingface.co/cl-tohoku/bert-base-japanese/resolve/main/config.json", 45 | "cl-tohoku/bert-base-japanese-whole-word-masking": "https://huggingface.co/cl-tohoku/bert-base-japanese-whole-word-masking/resolve/main/config.json", 46 | "cl-tohoku/bert-base-japanese-char": "https://huggingface.co/cl-tohoku/bert-base-japanese-char/resolve/main/config.json", 47 | "cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://huggingface.co/cl-tohoku/bert-base-japanese-char-whole-word-masking/resolve/main/config.json", 48 | "TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/config.json", 49 | "TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/config.json", 50 | "wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/config.json", 51 | # See all BERT models at https://huggingface.co/models?filter=bert 52 | } 53 | 54 | 55 | class BertConfig(PretrainedConfig): 56 | r""" 57 | This is the configuration class to store the configuration of a [`BertModel`] or a [`TFBertModel`]. It is used to 58 | instantiate a BERT model according to the specified arguments, defining the model architecture. 
Instantiating a 59 | configuration with the defaults will yield a similar configuration to that of the BERT 60 | [bert-base-uncased](https://huggingface.co/bert-base-uncased) architecture. 61 | 62 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 63 | documentation from [`PretrainedConfig`] for more information. 64 | 65 | 66 | Args: 67 | vocab_size (`int`, *optional*, defaults to 30522): 68 | Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the 69 | `inputs_ids` passed when calling [`BertModel`] or [`TFBertModel`]. 70 | hidden_size (`int`, *optional*, defaults to 768): 71 | Dimensionality of the encoder layers and the pooler layer. 72 | num_hidden_layers (`int`, *optional*, defaults to 12): 73 | Number of hidden layers in the Transformer encoder. 74 | num_attention_heads (`int`, *optional*, defaults to 12): 75 | Number of attention heads for each attention layer in the Transformer encoder. 76 | intermediate_size (`int`, *optional*, defaults to 3072): 77 | Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. 78 | hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): 79 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, 80 | `"relu"`, `"silu"` and `"gelu_new"` are supported. 81 | hidden_dropout_prob (`float`, *optional*, defaults to 0.1): 82 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 83 | attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): 84 | The dropout ratio for the attention probabilities. 85 | max_position_embeddings (`int`, *optional*, defaults to 512): 86 | The maximum sequence length that this model might ever be used with. Typically set this to something large 87 | just in case (e.g., 512 or 1024 or 2048). 88 | type_vocab_size (`int`, *optional*, defaults to 2): 89 | The vocabulary size of the `token_type_ids` passed when calling [`BertModel`] or [`TFBertModel`]. 90 | initializer_range (`float`, *optional*, defaults to 0.02): 91 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 92 | layer_norm_eps (`float`, *optional*, defaults to 1e-12): 93 | The epsilon used by the layer normalization layers. 94 | position_embedding_type (`str`, *optional*, defaults to `"absolute"`): 95 | Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For 96 | positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to 97 | [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). 98 | For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models 99 | with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658). 100 | use_cache (`bool`, *optional*, defaults to `True`): 101 | Whether or not the model should return the last key/values attentions (not used by all models). Only 102 | relevant if `config.is_decoder=True`. 103 | classifier_dropout (`float`, *optional*): 104 | The dropout ratio for the classification head. 
105 | 106 | Examples: 107 | 108 | ```python 109 | >>> from transformers import BertModel, BertConfig 110 | 111 | >>> # Initializing a BERT bert-base-uncased style configuration 112 | >>> configuration = BertConfig() 113 | 114 | >>> # Initializing a model from the bert-base-uncased style configuration 115 | >>> model = BertModel(configuration) 116 | 117 | >>> # Accessing the model configuration 118 | >>> configuration = model.config 119 | ```""" 120 | model_type = "bert" 121 | 122 | def __init__( 123 | self, 124 | vocab_size=30522, 125 | hidden_size=768, 126 | num_hidden_layers=12, 127 | num_attention_heads=12, 128 | intermediate_size=3072, 129 | hidden_act="gelu", 130 | hidden_dropout_prob=0.1, 131 | attention_probs_dropout_prob=0.1, 132 | max_position_embeddings=512, 133 | type_vocab_size=2, 134 | initializer_range=0.02, 135 | layer_norm_eps=1e-12, 136 | pad_token_id=0, 137 | position_embedding_type="absolute", 138 | use_cache=True, 139 | classifier_dropout=None, 140 | ### muP 141 | attn_mult=None, 142 | **kwargs 143 | ): 144 | super().__init__(pad_token_id=pad_token_id, **kwargs) 145 | 146 | self.vocab_size = vocab_size 147 | self.hidden_size = hidden_size 148 | self.num_hidden_layers = num_hidden_layers 149 | self.num_attention_heads = num_attention_heads 150 | self.hidden_act = hidden_act 151 | self.intermediate_size = intermediate_size 152 | self.hidden_dropout_prob = hidden_dropout_prob 153 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 154 | self.max_position_embeddings = max_position_embeddings 155 | self.type_vocab_size = type_vocab_size 156 | self.initializer_range = initializer_range 157 | self.layer_norm_eps = layer_norm_eps 158 | self.position_embedding_type = position_embedding_type 159 | self.use_cache = use_cache 160 | self.classifier_dropout = classifier_dropout 161 | ### muP 162 | if attn_mult is None: 163 | # defaults back to 1/sqrt(d) attn 164 | self.attn_mult = (self.hidden_size / self.num_attention_heads)**0.5 165 | else: 166 | self.attn_mult = attn_mult 167 | 168 | 169 | class BertOnnxConfig(OnnxConfig): 170 | @property 171 | def inputs(self) -> Mapping[str, Mapping[int, str]]: 172 | return OrderedDict( 173 | [ 174 | ("input_ids", {0: "batch", 1: "sequence"}), 175 | ("attention_mask", {0: "batch", 1: "sequence"}), 176 | ("token_type_ids", {0: "batch", 1: "sequence"}), 177 | ] 178 | ) -------------------------------------------------------------------------------- /mutransformers/models/gpt2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mutransformers/480287ce7b18a07a3432e8f2fbc0f0e5b71e2599/mutransformers/models/gpt2/__init__.py -------------------------------------------------------------------------------- /mutransformers/models/gpt2/_original_configuration_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ OpenAI GPT-2 configuration""" 17 | from collections import OrderedDict 18 | from typing import Any, List, Mapping, Optional 19 | 20 | from transformers import PreTrainedTokenizer, TensorType, is_torch_available 21 | 22 | from ...configuration_utils import PretrainedConfig 23 | from ...onnx import OnnxConfigWithPast, PatchingSpec 24 | from ...utils import logging 25 | 26 | 27 | logger = logging.get_logger(__name__) 28 | 29 | GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = { 30 | "gpt2": "https://huggingface.co/gpt2/resolve/main/config.json", 31 | "gpt2-medium": "https://huggingface.co/gpt2-medium/resolve/main/config.json", 32 | "gpt2-large": "https://huggingface.co/gpt2-large/resolve/main/config.json", 33 | "gpt2-xl": "https://huggingface.co/gpt2-xl/resolve/main/config.json", 34 | "distilgpt2": "https://huggingface.co/distilgpt2/resolve/main/config.json", 35 | } 36 | 37 | 38 | class GPT2Config(PretrainedConfig): 39 | """ 40 | This is the configuration class to store the configuration of a [`GPT2Model`] or a [`TFGPT2Model`]. It is used to 41 | instantiate a GPT-2 model according to the specified arguments, defining the model architecture. Instantiating a 42 | configuration with the defaults will yield a similar configuration to that of the GPT-2 43 | [small](https://huggingface.co/gpt2) architecture. 44 | 45 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 46 | documentation from [`PretrainedConfig`] for more information. 47 | 48 | 49 | Args: 50 | vocab_size (`int`, *optional*, defaults to 50257): 51 | Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the 52 | `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`]. 53 | n_positions (`int`, *optional*, defaults to 1024): 54 | The maximum sequence length that this model might ever be used with. Typically set this to something large 55 | just in case (e.g., 512 or 1024 or 2048). 56 | n_embd (`int`, *optional*, defaults to 768): 57 | Dimensionality of the embeddings and hidden states. 58 | n_layer (`int`, *optional*, defaults to 12): 59 | Number of hidden layers in the Transformer encoder. 60 | n_head (`int`, *optional*, defaults to 12): 61 | Number of attention heads for each attention layer in the Transformer encoder. 62 | n_inner (`int`, *optional*, defaults to None): 63 | Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd 64 | activation_function (`str`, *optional*, defaults to `"gelu"`): 65 | Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`. 66 | resid_pdrop (`float`, *optional*, defaults to 0.1): 67 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 68 | embd_pdrop (`int`, *optional*, defaults to 0.1): 69 | The dropout ratio for the embeddings. 70 | attn_pdrop (`float`, *optional*, defaults to 0.1): 71 | The dropout ratio for the attention. 
72 | layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): 73 | The epsilon to use in the layer normalization layers. 74 | initializer_range (`float`, *optional*, defaults to 0.02): 75 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 76 | summary_type (`string`, *optional*, defaults to `"cls_index"`): 77 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and 78 | [`TFGPT2DoubleHeadsModel`]. 79 | 80 | Has to be one of the following options: 81 | 82 | - `"last"`: Take the last token hidden state (like XLNet). 83 | - `"first"`: Take the first token hidden state (like BERT). 84 | - `"mean"`: Take the mean of all tokens hidden states. 85 | - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2). 86 | - `"attn"`: Not implemented now, use multi-head attention. 87 | summary_use_proj (`bool`, *optional*, defaults to `True`): 88 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and 89 | [`TFGPT2DoubleHeadsModel`]. 90 | 91 | Whether or not to add a projection after the vector extraction. 92 | summary_activation (`str`, *optional*): 93 | Argument used when doing sequence summary. Used in for the multiple choice head in 94 | [`GPT2DoubleHeadsModel`]. 95 | 96 | Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation. 97 | summary_proj_to_labels (`bool`, *optional*, defaults to `True`): 98 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and 99 | [`TFGPT2DoubleHeadsModel`]. 100 | 101 | Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes. 102 | summary_first_dropout (`float`, *optional*, defaults to 0.1): 103 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and 104 | [`TFGPT2DoubleHeadsModel`]. 105 | 106 | The dropout ratio to be used after the projection and activation. 107 | scale_attn_weights (`bool`, *optional*, defaults to `True`): 108 | Scale attention weights by dividing by sqrt(hidden_size).. 109 | use_cache (`bool`, *optional*, defaults to `True`): 110 | Whether or not the model should return the last key/values attentions (not used by all models). 111 | scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`): 112 | Whether to additionally scale attention weights by `1 / layer_idx + 1`. 113 | reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): 114 | Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention 115 | dot-product/softmax to float() when training with mixed precision. 
116 | 117 | Example: 118 | 119 | ```python 120 | >>> from transformers import GPT2Model, GPT2Config 121 | 122 | >>> # Initializing a GPT2 configuration 123 | >>> configuration = GPT2Config() 124 | 125 | >>> # Initializing a model from the configuration 126 | >>> model = GPT2Model(configuration) 127 | 128 | >>> # Accessing the model configuration 129 | >>> configuration = model.config 130 | ```""" 131 | 132 | model_type = "gpt2" 133 | keys_to_ignore_at_inference = ["past_key_values"] 134 | attribute_map = { 135 | "hidden_size": "n_embd", 136 | "max_position_embeddings": "n_positions", 137 | "num_attention_heads": "n_head", 138 | "num_hidden_layers": "n_layer", 139 | } 140 | 141 | def __init__( 142 | self, 143 | vocab_size=50257, 144 | n_positions=1024, 145 | n_embd=768, 146 | n_layer=12, 147 | n_head=12, 148 | n_inner=None, 149 | activation_function="gelu_new", 150 | resid_pdrop=0.1, 151 | embd_pdrop=0.1, 152 | attn_pdrop=0.1, 153 | layer_norm_epsilon=1e-5, 154 | initializer_range=0.02, 155 | summary_type="cls_index", 156 | summary_use_proj=True, 157 | summary_activation=None, 158 | summary_proj_to_labels=True, 159 | summary_first_dropout=0.1, 160 | scale_attn_weights=True, 161 | use_cache=True, 162 | bos_token_id=50256, 163 | eos_token_id=50256, 164 | scale_attn_by_inverse_layer_idx=False, 165 | reorder_and_upcast_attn=False, 166 | **kwargs, 167 | ): 168 | self.vocab_size = vocab_size 169 | self.n_positions = n_positions 170 | self.n_embd = n_embd 171 | self.n_layer = n_layer 172 | self.n_head = n_head 173 | self.n_inner = n_inner 174 | self.activation_function = activation_function 175 | self.resid_pdrop = resid_pdrop 176 | self.embd_pdrop = embd_pdrop 177 | self.attn_pdrop = attn_pdrop 178 | self.layer_norm_epsilon = layer_norm_epsilon 179 | self.initializer_range = initializer_range 180 | self.summary_type = summary_type 181 | self.summary_use_proj = summary_use_proj 182 | self.summary_activation = summary_activation 183 | self.summary_first_dropout = summary_first_dropout 184 | self.summary_proj_to_labels = summary_proj_to_labels 185 | self.scale_attn_weights = scale_attn_weights 186 | self.use_cache = use_cache 187 | self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx 188 | self.reorder_and_upcast_attn = reorder_and_upcast_attn 189 | 190 | self.bos_token_id = bos_token_id 191 | self.eos_token_id = eos_token_id 192 | 193 | super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) 194 | 195 | 196 | class GPT2OnnxConfig(OnnxConfigWithPast): 197 | def __init__( 198 | self, 199 | config: PretrainedConfig, 200 | task: str = "default", 201 | patching_specs: List[PatchingSpec] = None, 202 | use_past: bool = False, 203 | ): 204 | super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past) 205 | if not getattr(self._config, "pad_token_id", None): 206 | # TODO: how to do that better? 
207 | self._config.pad_token_id = 0 208 | 209 | @property 210 | def inputs(self) -> Mapping[str, Mapping[int, str]]: 211 | common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}}) 212 | if self.use_past: 213 | self.fill_with_past_key_values_(common_inputs, direction="inputs") 214 | common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"} 215 | else: 216 | common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} 217 | 218 | return common_inputs 219 | 220 | @property 221 | def num_layers(self) -> int: 222 | return self._config.n_layer 223 | 224 | @property 225 | def num_attention_heads(self) -> int: 226 | return self._config.n_head 227 | 228 | def generate_dummy_inputs( 229 | self, 230 | tokenizer: PreTrainedTokenizer, 231 | batch_size: int = -1, 232 | seq_length: int = -1, 233 | is_pair: bool = False, 234 | framework: Optional[TensorType] = None, 235 | ) -> Mapping[str, Any]: 236 | common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( 237 | tokenizer, batch_size, seq_length, is_pair, framework 238 | ) 239 | 240 | # We need to order the input in the way they appears in the forward() 241 | ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]}) 242 | 243 | # Need to add the past_keys 244 | if self.use_past: 245 | if not is_torch_available(): 246 | raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") 247 | else: 248 | import torch 249 | 250 | batch, seqlen = common_inputs["input_ids"].shape 251 | # Not using the same length for past_key_values 252 | past_key_values_length = seqlen + 2 253 | past_shape = ( 254 | batch, 255 | self.num_attention_heads, 256 | past_key_values_length, 257 | self._config.hidden_size // self.num_attention_heads, 258 | ) 259 | ordered_inputs["past_key_values"] = [ 260 | (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers) 261 | ] 262 | 263 | ordered_inputs["attention_mask"] = common_inputs["attention_mask"] 264 | if self.use_past: 265 | ordered_inputs["attention_mask"] = torch.cat( 266 | [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1 267 | ) 268 | 269 | return ordered_inputs 270 | 271 | @property 272 | def default_onnx_opset(self) -> int: 273 | return 13 -------------------------------------------------------------------------------- /mutransformers/models/gpt2/_original_modeling_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """PyTorch OpenAI GPT-2 model.""" 17 | 18 | import math 19 | import os 20 | from dataclasses import dataclass 21 | from typing import Optional, Tuple 22 | 23 | import torch 24 | import torch.utils.checkpoint 25 | from packaging import version 26 | from torch import nn 27 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 28 | 29 | 30 | if version.parse(torch.__version__) >= version.parse("1.6"): 31 | is_amp_available = True 32 | from torch.cuda.amp import autocast 33 | else: 34 | is_amp_available = False 35 | 36 | from ...activations import ACT2FN 37 | from ...file_utils import ( 38 | ModelOutput, 39 | add_code_sample_docstrings, 40 | add_start_docstrings, 41 | add_start_docstrings_to_model_forward, 42 | replace_return_docstrings, 43 | ) 44 | from ...modeling_outputs import ( 45 | BaseModelOutputWithPastAndCrossAttentions, 46 | CausalLMOutputWithCrossAttentions, 47 | SequenceClassifierOutputWithPast, 48 | TokenClassifierOutput, 49 | ) 50 | from ...modeling_utils import ( 51 | Conv1D, 52 | PreTrainedModel, 53 | SequenceSummary, 54 | find_pruneable_heads_and_indices, 55 | prune_conv1d_layer, 56 | ) 57 | from ...utils import logging 58 | from ...utils.model_parallel_utils import assert_device_map, get_device_map 59 | from .configuration_gpt2 import GPT2Config 60 | 61 | 62 | logger = logging.get_logger(__name__) 63 | 64 | _CHECKPOINT_FOR_DOC = "gpt2" 65 | _CONFIG_FOR_DOC = "GPT2Config" 66 | _TOKENIZER_FOR_DOC = "GPT2Tokenizer" 67 | 68 | GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [ 69 | "gpt2", 70 | "gpt2-medium", 71 | "gpt2-large", 72 | "gpt2-xl", 73 | "distilgpt2", 74 | # See all GPT-2 models at https://huggingface.co/models?filter=gpt2 75 | ] 76 | 77 | 78 | def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): 79 | """Load tf checkpoints in a pytorch model""" 80 | try: 81 | import re 82 | 83 | import tensorflow as tf 84 | except ImportError: 85 | logger.error( 86 | "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " 87 | "https://www.tensorflow.org/install/ for installation instructions." 
88 | ) 89 | raise 90 | tf_path = os.path.abspath(gpt2_checkpoint_path) 91 | logger.info(f"Converting TensorFlow checkpoint from {tf_path}") 92 | # Load weights from TF model 93 | init_vars = tf.train.list_variables(tf_path) 94 | names = [] 95 | arrays = [] 96 | for name, shape in init_vars: 97 | logger.info(f"Loading TF weight {name} with shape {shape}") 98 | array = tf.train.load_variable(tf_path, name) 99 | names.append(name) 100 | arrays.append(array.squeeze()) 101 | 102 | for name, array in zip(names, arrays): 103 | name = name[6:] # skip "model/" 104 | name = name.split("/") 105 | pointer = model 106 | for m_name in name: 107 | if re.fullmatch(r"[A-Za-z]+\d+", m_name): 108 | scope_names = re.split(r"(\d+)", m_name) 109 | else: 110 | scope_names = [m_name] 111 | if scope_names[0] == "w" or scope_names[0] == "g": 112 | pointer = getattr(pointer, "weight") 113 | elif scope_names[0] == "b": 114 | pointer = getattr(pointer, "bias") 115 | elif scope_names[0] == "wpe" or scope_names[0] == "wte": 116 | pointer = getattr(pointer, scope_names[0]) 117 | pointer = getattr(pointer, "weight") 118 | else: 119 | pointer = getattr(pointer, scope_names[0]) 120 | if len(scope_names) >= 2: 121 | num = int(scope_names[1]) 122 | pointer = pointer[num] 123 | try: 124 | assert ( 125 | pointer.shape == array.shape 126 | ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" 127 | except AssertionError as e: 128 | e.args += (pointer.shape, array.shape) 129 | raise 130 | logger.info(f"Initialize PyTorch weight {name}") 131 | pointer.data = torch.from_numpy(array) 132 | return model 133 | 134 | 135 | class GPT2Attention(nn.Module): 136 | def __init__(self, config, is_cross_attention=False, layer_idx=None): 137 | super().__init__() 138 | 139 | max_positions = config.max_position_embeddings 140 | self.register_buffer( 141 | "bias", 142 | torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( 143 | 1, 1, max_positions, max_positions 144 | ), 145 | ) 146 | self.register_buffer("masked_bias", torch.tensor(-1e4)) 147 | 148 | self.embed_dim = config.hidden_size 149 | self.num_heads = config.num_attention_heads 150 | self.head_dim = self.embed_dim // self.num_heads 151 | self.split_size = self.embed_dim 152 | if self.head_dim * self.num_heads != self.embed_dim: 153 | raise ValueError( 154 | f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})." 
155 | ) 156 | 157 | self.scale_attn_weights = config.scale_attn_weights 158 | self.is_cross_attention = is_cross_attention 159 | 160 | # Layer-wise attention scaling, reordering, and upcasting 161 | self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx 162 | self.layer_idx = layer_idx 163 | self.reorder_and_upcast_attn = config.reorder_and_upcast_attn 164 | 165 | if self.is_cross_attention: 166 | self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) 167 | self.q_attn = Conv1D(self.embed_dim, self.embed_dim) 168 | else: 169 | self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) 170 | self.c_proj = Conv1D(self.embed_dim, self.embed_dim) 171 | 172 | self.attn_dropout = nn.Dropout(config.attn_pdrop) 173 | self.resid_dropout = nn.Dropout(config.resid_pdrop) 174 | 175 | self.pruned_heads = set() 176 | 177 | def prune_heads(self, heads): 178 | if len(heads) == 0: 179 | return 180 | heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) 181 | index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) 182 | 183 | # Prune conv1d layers 184 | self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) 185 | self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) 186 | 187 | # Update hyper params 188 | self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads)) 189 | self.num_heads = self.num_heads - len(heads) 190 | self.pruned_heads = self.pruned_heads.union(heads) 191 | 192 | def _attn(self, query, key, value, attention_mask=None, head_mask=None): 193 | attn_weights = torch.matmul(query, key.transpose(-1, -2)) 194 | 195 | if self.scale_attn_weights: 196 | attn_weights = attn_weights / (float(value.size(-1)) ** 0.5) 197 | 198 | # Layer-wise attention scaling 199 | if self.scale_attn_by_inverse_layer_idx: 200 | attn_weights = attn_weights / float(self.layer_idx + 1) 201 | 202 | if not self.is_cross_attention: 203 | # if only "normal" attention layer implements causal mask 204 | query_length, key_length = query.size(-2), key.size(-2) 205 | causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() 206 | attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) 207 | 208 | if attention_mask is not None: 209 | # Apply the attention mask 210 | attn_weights = attn_weights + attention_mask 211 | 212 | attn_weights = nn.functional.softmax(attn_weights, dim=-1) 213 | 214 | # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise 215 | attn_weights = attn_weights.type(value.dtype) 216 | attn_weights = self.attn_dropout(attn_weights) 217 | 218 | # Mask heads if we want to 219 | if head_mask is not None: 220 | attn_weights = attn_weights * head_mask 221 | 222 | attn_output = torch.matmul(attn_weights, value) 223 | 224 | return attn_output, attn_weights 225 | 226 | def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): 227 | # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) 228 | bsz, num_heads, q_seq_len, dk = query.size() 229 | _, _, k_seq_len, _ = key.size() 230 | 231 | # Preallocate attn_weights for `baddbmm` 232 | attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device) 233 | 234 | # Compute Scale Factor 235 | scale_factor = 1.0 236 | if self.scale_attn_weights: 237 | scale_factor /= float(value.size(-1)) ** 0.5 238 | 239 | if 
self.scale_attn_by_inverse_layer_idx: 240 | scale_factor /= float(self.layer_idx + 1) 241 | 242 | # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) 243 | if is_amp_available: 244 | with autocast(enabled=False): 245 | q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) 246 | attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) 247 | attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) 248 | else: 249 | q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) 250 | attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) 251 | attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) 252 | 253 | if not self.is_cross_attention: 254 | # if only "normal" attention layer implements causal mask 255 | query_length, key_length = query.size(-2), key.size(-2) 256 | causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() 257 | attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) 258 | 259 | if attention_mask is not None: 260 | # Apply the attention mask 261 | attn_weights = attn_weights + attention_mask 262 | 263 | attn_weights = nn.functional.softmax(attn_weights, dim=-1) 264 | 265 | # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise 266 | if attn_weights.dtype != torch.float32: 267 | raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32") 268 | attn_weights = attn_weights.type(value.dtype) 269 | attn_weights = self.attn_dropout(attn_weights) 270 | 271 | # Mask heads if we want to 272 | if head_mask is not None: 273 | attn_weights = attn_weights * head_mask 274 | 275 | attn_output = torch.matmul(attn_weights, value) 276 | 277 | return attn_output, attn_weights 278 | 279 | def _split_heads(self, tensor, num_heads, attn_head_size): 280 | """ 281 | Splits hidden_size dim into attn_head_size and num_heads 282 | """ 283 | new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) 284 | tensor = tensor.view(*new_shape) 285 | return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) 286 | 287 | def _merge_heads(self, tensor, num_heads, attn_head_size): 288 | """ 289 | Merges attn_head_size dim and num_attn_heads dim into hidden_size 290 | """ 291 | tensor = tensor.permute(0, 2, 1, 3).contiguous() 292 | new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) 293 | return tensor.view(new_shape) 294 | 295 | def forward( 296 | self, 297 | hidden_states, 298 | layer_past=None, 299 | attention_mask=None, 300 | head_mask=None, 301 | encoder_hidden_states=None, 302 | encoder_attention_mask=None, 303 | use_cache=False, 304 | output_attentions=False, 305 | ): 306 | if encoder_hidden_states is not None: 307 | if not hasattr(self, "q_attn"): 308 | raise ValueError( 309 | "If class is used as cross attention, the weights `q_attn` have to be defined. " 310 | "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." 
311 | ) 312 | 313 | query = self.q_attn(hidden_states) 314 | key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) 315 | attention_mask = encoder_attention_mask 316 | else: 317 | query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) 318 | 319 | query = self._split_heads(query, self.num_heads, self.head_dim) 320 | key = self._split_heads(key, self.num_heads, self.head_dim) 321 | value = self._split_heads(value, self.num_heads, self.head_dim) 322 | 323 | if layer_past is not None: 324 | past_key, past_value = layer_past 325 | key = torch.cat((past_key, key), dim=-2) 326 | value = torch.cat((past_value, value), dim=-2) 327 | 328 | if use_cache is True: 329 | present = (key, value) 330 | else: 331 | present = None 332 | 333 | if self.reorder_and_upcast_attn: 334 | attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) 335 | else: 336 | attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) 337 | 338 | attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) 339 | attn_output = self.c_proj(attn_output) 340 | attn_output = self.resid_dropout(attn_output) 341 | 342 | outputs = (attn_output, present) 343 | if output_attentions: 344 | outputs += (attn_weights,) 345 | 346 | return outputs # a, present, (attentions) 347 | 348 | 349 | class GPT2MLP(nn.Module): 350 | def __init__(self, intermediate_size, config): 351 | super().__init__() 352 | embed_dim = config.hidden_size 353 | self.c_fc = Conv1D(intermediate_size, embed_dim) 354 | self.c_proj = Conv1D(embed_dim, intermediate_size) 355 | self.act = ACT2FN[config.activation_function] 356 | self.dropout = nn.Dropout(config.resid_pdrop) 357 | 358 | def forward(self, hidden_states): 359 | hidden_states = self.c_fc(hidden_states) 360 | hidden_states = self.act(hidden_states) 361 | hidden_states = self.c_proj(hidden_states) 362 | hidden_states = self.dropout(hidden_states) 363 | return hidden_states 364 | 365 | 366 | class GPT2Block(nn.Module): 367 | def __init__(self, config, layer_idx=None): 368 | super().__init__() 369 | hidden_size = config.hidden_size 370 | inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size 371 | 372 | self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 373 | self.attn = GPT2Attention(config, layer_idx=layer_idx) 374 | self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 375 | 376 | if config.add_cross_attention: 377 | self.crossattention = GPT2Attention(config, is_cross_attention=True) 378 | self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 379 | 380 | self.mlp = GPT2MLP(inner_dim, config) 381 | 382 | def forward( 383 | self, 384 | hidden_states, 385 | layer_past=None, 386 | attention_mask=None, 387 | head_mask=None, 388 | encoder_hidden_states=None, 389 | encoder_attention_mask=None, 390 | use_cache=False, 391 | output_attentions=False, 392 | ): 393 | residual = hidden_states 394 | hidden_states = self.ln_1(hidden_states) 395 | attn_outputs = self.attn( 396 | hidden_states, 397 | layer_past=layer_past, 398 | attention_mask=attention_mask, 399 | head_mask=head_mask, 400 | use_cache=use_cache, 401 | output_attentions=output_attentions, 402 | ) 403 | attn_output = attn_outputs[0] # output_attn: a, present, (attentions) 404 | outputs = attn_outputs[1:] 405 | # residual connection 406 | hidden_states = attn_output + residual 407 | 408 | if encoder_hidden_states is not None: 409 | # add one self-attention 
block for cross-attention 410 | if not hasattr(self, "crossattention"): 411 | raise ValueError( 412 | f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " 413 | "cross-attention layers by setting `config.add_cross_attention=True`" 414 | ) 415 | residual = hidden_states 416 | hidden_states = self.ln_cross_attn(hidden_states) 417 | cross_attn_outputs = self.crossattention( 418 | hidden_states, 419 | attention_mask=attention_mask, 420 | head_mask=head_mask, 421 | encoder_hidden_states=encoder_hidden_states, 422 | encoder_attention_mask=encoder_attention_mask, 423 | output_attentions=output_attentions, 424 | ) 425 | attn_output = cross_attn_outputs[0] 426 | # residual connection 427 | hidden_states = residual + attn_output 428 | outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights 429 | 430 | residual = hidden_states 431 | hidden_states = self.ln_2(hidden_states) 432 | feed_forward_hidden_states = self.mlp(hidden_states) 433 | # residual connection 434 | hidden_states = residual + feed_forward_hidden_states 435 | 436 | if use_cache: 437 | outputs = (hidden_states,) + outputs 438 | else: 439 | outputs = (hidden_states,) + outputs[1:] 440 | 441 | return outputs # hidden_states, present, (attentions, cross_attentions) 442 | 443 | 444 | class GPT2PreTrainedModel(PreTrainedModel): 445 | """ 446 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 447 | models. 448 | """ 449 | 450 | config_class = GPT2Config 451 | load_tf_weights = load_tf_weights_in_gpt2 452 | base_model_prefix = "transformer" 453 | is_parallelizable = True 454 | supports_gradient_checkpointing = True 455 | 456 | def __init__(self, *inputs, **kwargs): 457 | super().__init__(*inputs, **kwargs) 458 | 459 | def _init_weights(self, module): 460 | """Initialize the weights.""" 461 | if isinstance(module, (nn.Linear, Conv1D)): 462 | # Slightly different from the TF version which uses truncated_normal for initialization 463 | # cf https://github.com/pytorch/pytorch/pull/5617 464 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 465 | if module.bias is not None: 466 | module.bias.data.zero_() 467 | elif isinstance(module, nn.Embedding): 468 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 469 | if module.padding_idx is not None: 470 | module.weight.data[module.padding_idx].zero_() 471 | elif isinstance(module, nn.LayerNorm): 472 | module.bias.data.zero_() 473 | module.weight.data.fill_(1.0) 474 | 475 | # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: 476 | # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale 477 | # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
478 | # > -- GPT-2 :: https://openai.com/blog/better-language-models/ 479 | # 480 | # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py 481 | for name, p in module.named_parameters(): 482 | if "c_proj" in name and "weight" in name: 483 | # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block 484 | p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) 485 | 486 | def _set_gradient_checkpointing(self, module, value=False): 487 | if isinstance(module, GPT2Model): 488 | module.gradient_checkpointing = value 489 | 490 | 491 | @dataclass 492 | class GPT2DoubleHeadsModelOutput(ModelOutput): 493 | """ 494 | Base class for outputs of models predicting if two sentences are consecutive or not. 495 | 496 | Args: 497 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): 498 | Language modeling loss. 499 | mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided): 500 | Multiple choice classification loss. 501 | logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`): 502 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 503 | mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): 504 | Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). 505 | past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 506 | Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads, 507 | sequence_length, embed_size_per_head)`). 508 | 509 | Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see 510 | `past_key_values` input) to speed up sequential decoding. 511 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 512 | Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of 513 | shape `(batch_size, sequence_length, hidden_size)`. 514 | 515 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 516 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 517 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 518 | sequence_length)`. 519 | 520 | GPT2Attentions weights after the attention softmax, used to compute the weighted average in the 521 | self-attention heads. 522 | """ 523 | 524 | loss: Optional[torch.FloatTensor] = None 525 | mc_loss: Optional[torch.FloatTensor] = None 526 | logits: torch.FloatTensor = None 527 | mc_logits: torch.FloatTensor = None 528 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None 529 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 530 | attentions: Optional[Tuple[torch.FloatTensor]] = None 531 | 532 | 533 | GPT2_START_DOCSTRING = r""" 534 | 535 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 536 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 537 | etc.) 
538 | 539 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 540 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 541 | and behavior. 542 | 543 | Parameters: 544 | config ([`GPT2Config`]): Model configuration class with all the parameters of the model. 545 | Initializing with a config file does not load the weights associated with the model, only the 546 | configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 547 | """ 548 | 549 | GPT2_INPUTS_DOCSTRING = r""" 550 | Args: 551 | input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): 552 | `input_ids_length` = `sequence_length` if `past_key_values` is `None` else 553 | `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input 554 | sequence tokens in the vocabulary. 555 | 556 | If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as 557 | `input_ids`. 558 | 559 | Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and 560 | [`PreTrainedTokenizer.__call__`] for details. 561 | 562 | [What are input IDs?](../glossary#input-ids) 563 | past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`): 564 | Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see 565 | `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have 566 | their past given to this model should not be passed as `input_ids` as they have already been computed. 567 | attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): 568 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 569 | 570 | - 1 for tokens that are **not masked**, 571 | - 0 for tokens that are **masked**. 572 | 573 | [What are attention masks?](../glossary#attention-mask) 574 | token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): 575 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 576 | 1]`: 577 | 578 | - 0 corresponds to a *sentence A* token, 579 | - 1 corresponds to a *sentence B* token. 580 | 581 | [What are token type IDs?](../glossary#token-type-ids) 582 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 583 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 584 | config.max_position_embeddings - 1]`. 585 | 586 | [What are position IDs?](../glossary#position-ids) 587 | head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): 588 | Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: 589 | 590 | - 1 indicates the head is **not masked**, 591 | - 0 indicates the head is **masked**. 592 | 593 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 594 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 595 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 596 | model's internal embedding lookup matrix. 
597 | 598 | If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see 599 | `past_key_values`). 600 | use_cache (`bool`, *optional*): 601 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 602 | `past_key_values`). 603 | output_attentions (`bool`, *optional*): 604 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 605 | tensors for more detail. 606 | output_hidden_states (`bool`, *optional*): 607 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 608 | more detail. 609 | return_dict (`bool`, *optional*): 610 | Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. 611 | """ 612 | PARALLELIZE_DOCSTRING = r""" 613 | This is an experimental feature and is a subject to change at a moment's notice. 614 | 615 | Uses a device map to distribute attention modules of the model across several devices. If no device map is given, 616 | it will evenly distribute blocks across all devices. 617 | 618 | Args: 619 | device_map (`Dict[int, list]`, optional, defaults to None): 620 | A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always 621 | automatically mapped to the first device (for esoteric reasons). That means that the first device should 622 | have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the 623 | following number of attention modules: 624 | 625 | - gpt2: 12 626 | - gpt2-medium: 24 627 | - gpt2-large: 36 628 | - gpt2-xl: 48 629 | 630 | Example: 631 | 632 | ```python 633 | # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: 634 | model = GPT2LMHeadModel.from_pretrained("gpt2-xl") 635 | device_map = { 636 | 0: [0, 1, 2, 3, 4, 5, 6, 7, 8], 637 | 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], 638 | 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], 639 | 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], 640 | } 641 | model.parallelize(device_map) 642 | ``` 643 | """ 644 | DEPARALLELIZE_DOCSTRING = r""" 645 | Moves the model to cpu from a model parallel state. 
646 | 647 | Example: 648 | 649 | ```python 650 | # On a 4 GPU machine with gpt2-large: 651 | model = GPT2LMHeadModel.from_pretrained("gpt2-large") 652 | device_map = { 653 | 0: [0, 1, 2, 3, 4, 5, 6, 7], 654 | 1: [8, 9, 10, 11, 12, 13, 14, 15], 655 | 2: [16, 17, 18, 19, 20, 21, 22, 23], 656 | 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], 657 | } 658 | model.parallelize(device_map) # Splits the model across several devices 659 | model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() 660 | ``` 661 | """ 662 | 663 | 664 | @add_start_docstrings( 665 | "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", 666 | GPT2_START_DOCSTRING, 667 | ) 668 | class GPT2Model(GPT2PreTrainedModel): 669 | _keys_to_ignore_on_load_missing = ["attn.masked_bias"] 670 | 671 | def __init__(self, config): 672 | super().__init__(config) 673 | 674 | self.embed_dim = config.hidden_size 675 | 676 | self.wte = nn.Embedding(config.vocab_size, self.embed_dim) 677 | self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) 678 | 679 | self.drop = nn.Dropout(config.embd_pdrop) 680 | self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) 681 | self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) 682 | 683 | # Model parallel 684 | self.model_parallel = False 685 | self.device_map = None 686 | self.gradient_checkpointing = False 687 | 688 | # Initialize weights and apply final processing 689 | self.post_init() 690 | 691 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 692 | def parallelize(self, device_map=None): 693 | # Check validity of device_map 694 | self.device_map = ( 695 | get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map 696 | ) 697 | assert_device_map(self.device_map, len(self.h)) 698 | self.model_parallel = True 699 | self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) 700 | self.last_device = "cuda:" + str(max(self.device_map.keys())) 701 | self.wte = self.wte.to(self.first_device) 702 | self.wpe = self.wpe.to(self.first_device) 703 | # Load onto devices 704 | for k, v in self.device_map.items(): 705 | for block in v: 706 | cuda_device = "cuda:" + str(k) 707 | self.h[block] = self.h[block].to(cuda_device) 708 | # ln_f to last 709 | self.ln_f = self.ln_f.to(self.last_device) 710 | 711 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING) 712 | def deparallelize(self): 713 | self.model_parallel = False 714 | self.device_map = None 715 | self.first_device = "cpu" 716 | self.last_device = "cpu" 717 | self.wte = self.wte.to("cpu") 718 | self.wpe = self.wpe.to("cpu") 719 | for index in range(len(self.h)): 720 | self.h[index] = self.h[index].to("cpu") 721 | self.ln_f = self.ln_f.to("cpu") 722 | torch.cuda.empty_cache() 723 | 724 | def get_input_embeddings(self): 725 | return self.wte 726 | 727 | def set_input_embeddings(self, new_embeddings): 728 | self.wte = new_embeddings 729 | 730 | def _prune_heads(self, heads_to_prune): 731 | """ 732 | Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} 733 | """ 734 | for layer, heads in heads_to_prune.items(): 735 | self.h[layer].attn.prune_heads(heads) 736 | 737 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) 738 | @add_code_sample_docstrings( 739 | processor_class=_TOKENIZER_FOR_DOC, 740 | checkpoint=_CHECKPOINT_FOR_DOC, 741 | output_type=BaseModelOutputWithPastAndCrossAttentions, 742 | config_class=_CONFIG_FOR_DOC, 743 | ) 744 | def forward( 745 | self, 746 | input_ids=None, 747 | past_key_values=None, 748 | attention_mask=None, 749 | token_type_ids=None, 750 | position_ids=None, 751 | head_mask=None, 752 | inputs_embeds=None, 753 | encoder_hidden_states=None, 754 | encoder_attention_mask=None, 755 | use_cache=None, 756 | output_attentions=None, 757 | output_hidden_states=None, 758 | return_dict=None, 759 | ): 760 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 761 | output_hidden_states = ( 762 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 763 | ) 764 | use_cache = use_cache if use_cache is not None else self.config.use_cache 765 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 766 | 767 | if input_ids is not None and inputs_embeds is not None: 768 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 769 | elif input_ids is not None: 770 | input_shape = input_ids.size() 771 | input_ids = input_ids.view(-1, input_shape[-1]) 772 | batch_size = input_ids.shape[0] 773 | elif inputs_embeds is not None: 774 | input_shape = inputs_embeds.size()[:-1] 775 | batch_size = inputs_embeds.shape[0] 776 | else: 777 | raise ValueError("You have to specify either input_ids or inputs_embeds") 778 | 779 | device = input_ids.device if input_ids is not None else inputs_embeds.device 780 | 781 | if token_type_ids is not None: 782 | token_type_ids = token_type_ids.view(-1, input_shape[-1]) 783 | if position_ids is not None: 784 | position_ids = position_ids.view(-1, input_shape[-1]) 785 | 786 | if past_key_values is None: 787 | past_length = 0 788 | past_key_values = tuple([None] * len(self.h)) 789 | else: 790 | past_length = past_key_values[0][0].size(-2) 791 | if position_ids is None: 792 | position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) 793 | position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) 794 | 795 | # GPT2Attention mask. 796 | if attention_mask is not None: 797 | if batch_size <= 0: 798 | raise ValueError("batch_size has to be defined and > 0") 799 | attention_mask = attention_mask.view(batch_size, -1) 800 | # We create a 3D attention mask from a 2D tensor mask. 801 | # Sizes are [batch_size, 1, 1, to_seq_length] 802 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 803 | # this attention mask is more simple than the triangular masking of causal attention 804 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 805 | attention_mask = attention_mask[:, None, None, :] 806 | 807 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 808 | # masked positions, this operation will create a tensor which is 0.0 for 809 | # positions we want to attend and -10000.0 for masked positions. 810 | # Since we are adding it to the raw scores before the softmax, this is 811 | # effectively the same as removing these entirely. 
812 | attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility 813 | attention_mask = (1.0 - attention_mask) * -10000.0 814 | 815 | # If a 2D or 3D attention mask is provided for the cross-attention 816 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 817 | if self.config.add_cross_attention and encoder_hidden_states is not None: 818 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 819 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 820 | if encoder_attention_mask is None: 821 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 822 | encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) 823 | else: 824 | encoder_attention_mask = None 825 | 826 | # Prepare head mask if needed 827 | # 1.0 in head_mask indicate we keep the head 828 | # attention_probs has shape bsz x n_heads x N x N 829 | # head_mask has shape n_layer x batch x n_heads x N x N 830 | head_mask = self.get_head_mask(head_mask, self.config.n_layer) 831 | 832 | if inputs_embeds is None: 833 | inputs_embeds = self.wte(input_ids) 834 | position_embeds = self.wpe(position_ids) 835 | hidden_states = inputs_embeds + position_embeds 836 | 837 | if token_type_ids is not None: 838 | token_type_embeds = self.wte(token_type_ids) 839 | hidden_states = hidden_states + token_type_embeds 840 | 841 | hidden_states = self.drop(hidden_states) 842 | 843 | output_shape = input_shape + (hidden_states.size(-1),) 844 | 845 | presents = () if use_cache else None 846 | all_self_attentions = () if output_attentions else None 847 | all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None 848 | all_hidden_states = () if output_hidden_states else None 849 | for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): 850 | 851 | # Model parallel 852 | if self.model_parallel: 853 | torch.cuda.set_device(hidden_states.device) 854 | # Ensure layer_past is on same device as hidden_states (might not be correct) 855 | if layer_past is not None: 856 | layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) 857 | # Ensure that attention_mask is always on the same device as hidden_states 858 | if attention_mask is not None: 859 | attention_mask = attention_mask.to(hidden_states.device) 860 | if isinstance(head_mask, torch.Tensor): 861 | head_mask = head_mask.to(hidden_states.device) 862 | if output_hidden_states: 863 | all_hidden_states = all_hidden_states + (hidden_states,) 864 | 865 | if self.gradient_checkpointing and self.training: 866 | 867 | if use_cache: 868 | logger.warning( 869 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
870 | ) 871 | use_cache = False 872 | 873 | def create_custom_forward(module): 874 | def custom_forward(*inputs): 875 | # None for past_key_value 876 | return module(*inputs, use_cache, output_attentions) 877 | 878 | return custom_forward 879 | 880 | outputs = torch.utils.checkpoint.checkpoint( 881 | create_custom_forward(block), 882 | hidden_states, 883 | None, 884 | attention_mask, 885 | head_mask[i], 886 | encoder_hidden_states, 887 | encoder_attention_mask, 888 | ) 889 | else: 890 | outputs = block( 891 | hidden_states, 892 | layer_past=layer_past, 893 | attention_mask=attention_mask, 894 | head_mask=head_mask[i], 895 | encoder_hidden_states=encoder_hidden_states, 896 | encoder_attention_mask=encoder_attention_mask, 897 | use_cache=use_cache, 898 | output_attentions=output_attentions, 899 | ) 900 | 901 | hidden_states = outputs[0] 902 | if use_cache is True: 903 | presents = presents + (outputs[1],) 904 | 905 | if output_attentions: 906 | all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) 907 | if self.config.add_cross_attention: 908 | all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) 909 | 910 | # Model Parallel: If it's the last layer for that device, put things on the next device 911 | if self.model_parallel: 912 | for k, v in self.device_map.items(): 913 | if i == v[-1] and "cuda:" + str(k) != self.last_device: 914 | hidden_states = hidden_states.to("cuda:" + str(k + 1)) 915 | 916 | hidden_states = self.ln_f(hidden_states) 917 | 918 | hidden_states = hidden_states.view(*output_shape) 919 | # Add last hidden state 920 | if output_hidden_states: 921 | all_hidden_states = all_hidden_states + (hidden_states,) 922 | 923 | if not return_dict: 924 | return tuple( 925 | v 926 | for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] 927 | if v is not None 928 | ) 929 | 930 | return BaseModelOutputWithPastAndCrossAttentions( 931 | last_hidden_state=hidden_states, 932 | past_key_values=presents, 933 | hidden_states=all_hidden_states, 934 | attentions=all_self_attentions, 935 | cross_attentions=all_cross_attentions, 936 | ) 937 | 938 | 939 | @add_start_docstrings( 940 | """ 941 | The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input 942 | embeddings). 
943 | """, 944 | GPT2_START_DOCSTRING, 945 | ) 946 | class GPT2LMHeadModel(GPT2PreTrainedModel): 947 | _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"] 948 | 949 | def __init__(self, config): 950 | super().__init__(config) 951 | self.transformer = GPT2Model(config) 952 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 953 | 954 | # Model parallel 955 | self.model_parallel = False 956 | self.device_map = None 957 | 958 | # Initialize weights and apply final processing 959 | self.post_init() 960 | 961 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 962 | def parallelize(self, device_map=None): 963 | self.device_map = ( 964 | get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) 965 | if device_map is None 966 | else device_map 967 | ) 968 | assert_device_map(self.device_map, len(self.transformer.h)) 969 | self.transformer.parallelize(self.device_map) 970 | self.lm_head = self.lm_head.to(self.transformer.first_device) 971 | self.model_parallel = True 972 | 973 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING) 974 | def deparallelize(self): 975 | self.transformer.deparallelize() 976 | self.transformer = self.transformer.to("cpu") 977 | self.lm_head = self.lm_head.to("cpu") 978 | self.model_parallel = False 979 | torch.cuda.empty_cache() 980 | 981 | def get_output_embeddings(self): 982 | return self.lm_head 983 | 984 | def set_output_embeddings(self, new_embeddings): 985 | self.lm_head = new_embeddings 986 | 987 | def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): 988 | token_type_ids = kwargs.get("token_type_ids", None) 989 | # only last token for inputs_ids if past is defined in kwargs 990 | if past: 991 | input_ids = input_ids[:, -1].unsqueeze(-1) 992 | if token_type_ids is not None: 993 | token_type_ids = token_type_ids[:, -1].unsqueeze(-1) 994 | 995 | attention_mask = kwargs.get("attention_mask", None) 996 | position_ids = kwargs.get("position_ids", None) 997 | 998 | if attention_mask is not None and position_ids is None: 999 | # create position_ids on the fly for batch generation 1000 | position_ids = attention_mask.long().cumsum(-1) - 1 1001 | position_ids.masked_fill_(attention_mask == 0, 1) 1002 | if past: 1003 | position_ids = position_ids[:, -1].unsqueeze(-1) 1004 | else: 1005 | position_ids = None 1006 | return { 1007 | "input_ids": input_ids, 1008 | "past_key_values": past, 1009 | "use_cache": kwargs.get("use_cache"), 1010 | "position_ids": position_ids, 1011 | "attention_mask": attention_mask, 1012 | "token_type_ids": token_type_ids, 1013 | } 1014 | 1015 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) 1016 | @add_code_sample_docstrings( 1017 | processor_class=_TOKENIZER_FOR_DOC, 1018 | checkpoint=_CHECKPOINT_FOR_DOC, 1019 | output_type=CausalLMOutputWithCrossAttentions, 1020 | config_class=_CONFIG_FOR_DOC, 1021 | ) 1022 | def forward( 1023 | self, 1024 | input_ids=None, 1025 | past_key_values=None, 1026 | attention_mask=None, 1027 | token_type_ids=None, 1028 | position_ids=None, 1029 | head_mask=None, 1030 | inputs_embeds=None, 1031 | encoder_hidden_states=None, 1032 | encoder_attention_mask=None, 1033 | labels=None, 1034 | use_cache=None, 1035 | output_attentions=None, 1036 | output_hidden_states=None, 1037 | return_dict=None, 1038 | ): 1039 | r""" 1040 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1041 | Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set 1042 | `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` 1043 | are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` 1044 | """ 1045 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1046 | 1047 | transformer_outputs = self.transformer( 1048 | input_ids, 1049 | past_key_values=past_key_values, 1050 | attention_mask=attention_mask, 1051 | token_type_ids=token_type_ids, 1052 | position_ids=position_ids, 1053 | head_mask=head_mask, 1054 | inputs_embeds=inputs_embeds, 1055 | encoder_hidden_states=encoder_hidden_states, 1056 | encoder_attention_mask=encoder_attention_mask, 1057 | use_cache=use_cache, 1058 | output_attentions=output_attentions, 1059 | output_hidden_states=output_hidden_states, 1060 | return_dict=return_dict, 1061 | ) 1062 | hidden_states = transformer_outputs[0] 1063 | 1064 | # Set device for model parallelism 1065 | if self.model_parallel: 1066 | torch.cuda.set_device(self.transformer.first_device) 1067 | hidden_states = hidden_states.to(self.lm_head.weight.device) 1068 | 1069 | lm_logits = self.lm_head(hidden_states) 1070 | 1071 | loss = None 1072 | if labels is not None: 1073 | # Shift so that tokens < n predict n 1074 | shift_logits = lm_logits[..., :-1, :].contiguous() 1075 | shift_labels = labels[..., 1:].contiguous() 1076 | # Flatten the tokens 1077 | loss_fct = CrossEntropyLoss() 1078 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 1079 | 1080 | if not return_dict: 1081 | output = (lm_logits,) + transformer_outputs[1:] 1082 | return ((loss,) + output) if loss is not None else output 1083 | 1084 | return CausalLMOutputWithCrossAttentions( 1085 | loss=loss, 1086 | logits=lm_logits, 1087 | past_key_values=transformer_outputs.past_key_values, 1088 | hidden_states=transformer_outputs.hidden_states, 1089 | attentions=transformer_outputs.attentions, 1090 | cross_attentions=transformer_outputs.cross_attentions, 1091 | ) 1092 | 1093 | @staticmethod 1094 | def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: 1095 | """ 1096 | This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or 1097 | [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct 1098 | beam_idx at every generation step. 1099 | """ 1100 | return tuple( 1101 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) 1102 | for layer_past in past 1103 | ) 1104 | 1105 | 1106 | @add_start_docstrings( 1107 | """ 1108 | The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for 1109 | RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the 1110 | input embeddings, the classification head takes as input the input of a specified classification token index in the 1111 | input sequence). 
1112 | """, 1113 | GPT2_START_DOCSTRING, 1114 | ) 1115 | class GPT2DoubleHeadsModel(GPT2PreTrainedModel): 1116 | _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"] 1117 | 1118 | def __init__(self, config): 1119 | super().__init__(config) 1120 | config.num_labels = 1 1121 | self.transformer = GPT2Model(config) 1122 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 1123 | self.multiple_choice_head = SequenceSummary(config) 1124 | 1125 | # Model parallel 1126 | self.model_parallel = False 1127 | self.device_map = None 1128 | 1129 | # Initialize weights and apply final processing 1130 | self.post_init() 1131 | 1132 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 1133 | def parallelize(self, device_map=None): 1134 | self.device_map = ( 1135 | get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) 1136 | if device_map is None 1137 | else device_map 1138 | ) 1139 | assert_device_map(self.device_map, len(self.transformer.h)) 1140 | self.transformer.parallelize(self.device_map) 1141 | self.lm_head = self.lm_head.to(self.transformer.first_device) 1142 | self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device) 1143 | self.model_parallel = True 1144 | 1145 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING) 1146 | def deparallelize(self): 1147 | self.transformer.deparallelize() 1148 | self.transformer = self.transformer.to("cpu") 1149 | self.lm_head = self.lm_head.to("cpu") 1150 | self.multiple_choice_head = self.multiple_choice_head.to("cpu") 1151 | self.model_parallel = False 1152 | torch.cuda.empty_cache() 1153 | 1154 | def get_output_embeddings(self): 1155 | return self.lm_head 1156 | 1157 | def set_output_embeddings(self, new_embeddings): 1158 | self.lm_head = new_embeddings 1159 | 1160 | def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): 1161 | token_type_ids = kwargs.get("token_type_ids", None) 1162 | # only last token for inputs_ids if past is defined in kwargs 1163 | if past: 1164 | input_ids = input_ids[:, -1].unsqueeze(-1) 1165 | if token_type_ids is not None: 1166 | token_type_ids = token_type_ids[:, -1].unsqueeze(-1) 1167 | 1168 | attention_mask = kwargs.get("attention_mask", None) 1169 | position_ids = kwargs.get("position_ids", None) 1170 | 1171 | if attention_mask is not None and position_ids is None: 1172 | # create position_ids on the fly for batch generation 1173 | position_ids = attention_mask.long().cumsum(-1) - 1 1174 | position_ids.masked_fill_(attention_mask == 0, 1) 1175 | if past: 1176 | position_ids = position_ids[:, -1].unsqueeze(-1) 1177 | else: 1178 | position_ids = None 1179 | 1180 | return { 1181 | "input_ids": input_ids, 1182 | "past_key_values": past, 1183 | "use_cache": kwargs.get("use_cache"), 1184 | "position_ids": position_ids, 1185 | "attention_mask": attention_mask, 1186 | "token_type_ids": token_type_ids, 1187 | } 1188 | 1189 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) 1190 | @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) 1191 | def forward( 1192 | self, 1193 | input_ids=None, 1194 | past_key_values=None, 1195 | attention_mask=None, 1196 | token_type_ids=None, 1197 | position_ids=None, 1198 | head_mask=None, 1199 | inputs_embeds=None, 1200 | mc_token_ids=None, 1201 | labels=None, 1202 | mc_labels=None, 1203 | use_cache=None, 1204 | output_attentions=None, 1205 | output_hidden_states=None, 1206 | return_dict=None, 1207 | **kwargs, 1208 | ): 1209 | r""" 1210 | 
mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): 1211 | Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - 1212 | 1[`. 1213 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1214 | Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set 1215 | `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size - 1]` All labels set to 1216 | `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]` 1217 | mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): 1218 | Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` 1219 | where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above) 1220 | 1221 | Return: 1222 | 1223 | Example: 1224 | 1225 | ```python 1226 | >>> import torch 1227 | >>> from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel 1228 | 1229 | >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2") 1230 | >>> model = GPT2DoubleHeadsModel.from_pretrained("gpt2") 1231 | 1232 | >>> # Add a [CLS] to the vocabulary (we should train it also!) 1233 | >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"}) 1234 | 1235 | >>> embedding_layer = model.resize_token_embeddings( 1236 | ... len(tokenizer) 1237 | >>> ) # Update the model embeddings with the new vocabulary size 1238 | 1239 | >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] 1240 | >>> encoded_choices = [tokenizer.encode(s) for s in choices] 1241 | >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices] 1242 | 1243 | >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2 1244 | >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1 1245 | 1246 | >>> outputs = model(input_ids, mc_token_ids=mc_token_ids) 1247 | >>> lm_logits = outputs.logits 1248 | >>> mc_logits = outputs.mc_logits 1249 | ```""" 1250 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1251 | 1252 | transformer_outputs = self.transformer( 1253 | input_ids, 1254 | past_key_values=past_key_values, 1255 | attention_mask=attention_mask, 1256 | token_type_ids=token_type_ids, 1257 | position_ids=position_ids, 1258 | head_mask=head_mask, 1259 | inputs_embeds=inputs_embeds, 1260 | use_cache=use_cache, 1261 | output_attentions=output_attentions, 1262 | output_hidden_states=output_hidden_states, 1263 | return_dict=return_dict, 1264 | ) 1265 | 1266 | hidden_states = transformer_outputs[0] 1267 | 1268 | # Set device for model parallelism 1269 | if self.model_parallel: 1270 | torch.cuda.set_device(self.transformer.first_device) 1271 | hidden_states = hidden_states.to(self.lm_head.weight.device) 1272 | 1273 | lm_logits = self.lm_head(hidden_states) 1274 | mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) 1275 | 1276 | mc_loss = None 1277 | if mc_labels is not None: 1278 | loss_fct = CrossEntropyLoss() 1279 | mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) 1280 | lm_loss = None 1281 | if labels is not None: 1282 | shift_logits = lm_logits[..., :-1, :].contiguous() 1283 | shift_labels = labels[..., 1:].contiguous() 1284 | loss_fct = CrossEntropyLoss() 1285 | lm_loss = 
loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 1286 | 1287 | if not return_dict: 1288 | output = (lm_logits, mc_logits) + transformer_outputs[1:] 1289 | if mc_loss is not None: 1290 | output = (mc_loss,) + output 1291 | return ((lm_loss,) + output) if lm_loss is not None else output 1292 | 1293 | return GPT2DoubleHeadsModelOutput( 1294 | loss=lm_loss, 1295 | mc_loss=mc_loss, 1296 | logits=lm_logits, 1297 | mc_logits=mc_logits, 1298 | past_key_values=transformer_outputs.past_key_values, 1299 | hidden_states=transformer_outputs.hidden_states, 1300 | attentions=transformer_outputs.attentions, 1301 | ) 1302 | 1303 | @staticmethod 1304 | def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: 1305 | """ 1306 | This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or 1307 | [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct 1308 | beam_idx at every generation step. 1309 | """ 1310 | return tuple( 1311 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) 1312 | for layer_past in past 1313 | ) 1314 | 1315 | 1316 | @add_start_docstrings( 1317 | """ 1318 | The GPT2 Model transformer with a sequence classification head on top (linear layer). 1319 | 1320 | [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models 1321 | (e.g. GPT-1) do. 1322 | 1323 | Since it does classification on the last token, it requires to know the position of the last token. If a 1324 | `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If 1325 | no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the 1326 | padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in 1327 | each row of the batch). 1328 | """, 1329 | GPT2_START_DOCSTRING, 1330 | ) 1331 | class GPT2ForSequenceClassification(GPT2PreTrainedModel): 1332 | _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"] 1333 | 1334 | def __init__(self, config): 1335 | super().__init__(config) 1336 | self.num_labels = config.num_labels 1337 | self.transformer = GPT2Model(config) 1338 | self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) 1339 | 1340 | # Model parallel 1341 | self.model_parallel = False 1342 | self.device_map = None 1343 | 1344 | # Initialize weights and apply final processing 1345 | self.post_init() 1346 | 1347 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) 1348 | @add_code_sample_docstrings( 1349 | processor_class=_TOKENIZER_FOR_DOC, 1350 | checkpoint="microsoft/DialogRPT-updown", 1351 | output_type=SequenceClassifierOutputWithPast, 1352 | config_class=_CONFIG_FOR_DOC, 1353 | ) 1354 | def forward( 1355 | self, 1356 | input_ids=None, 1357 | past_key_values=None, 1358 | attention_mask=None, 1359 | token_type_ids=None, 1360 | position_ids=None, 1361 | head_mask=None, 1362 | inputs_embeds=None, 1363 | labels=None, 1364 | use_cache=None, 1365 | output_attentions=None, 1366 | output_hidden_states=None, 1367 | return_dict=None, 1368 | ): 1369 | r""" 1370 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1371 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 1372 | config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 1373 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 1374 | """ 1375 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1376 | 1377 | transformer_outputs = self.transformer( 1378 | input_ids, 1379 | past_key_values=past_key_values, 1380 | attention_mask=attention_mask, 1381 | token_type_ids=token_type_ids, 1382 | position_ids=position_ids, 1383 | head_mask=head_mask, 1384 | inputs_embeds=inputs_embeds, 1385 | use_cache=use_cache, 1386 | output_attentions=output_attentions, 1387 | output_hidden_states=output_hidden_states, 1388 | return_dict=return_dict, 1389 | ) 1390 | hidden_states = transformer_outputs[0] 1391 | logits = self.score(hidden_states) 1392 | 1393 | if input_ids is not None: 1394 | batch_size, sequence_length = input_ids.shape[:2] 1395 | else: 1396 | batch_size, sequence_length = inputs_embeds.shape[:2] 1397 | 1398 | assert ( 1399 | self.config.pad_token_id is not None or batch_size == 1 1400 | ), "Cannot handle batch sizes > 1 if no padding token is defined." 1401 | if self.config.pad_token_id is None: 1402 | sequence_lengths = -1 1403 | else: 1404 | if input_ids is not None: 1405 | sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 1406 | else: 1407 | sequence_lengths = -1 1408 | logger.warning( 1409 | f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " 1410 | f"unexpected if using padding tokens in conjunction with `inputs_embeds.`" 1411 | ) 1412 | 1413 | pooled_logits = logits[range(batch_size), sequence_lengths] 1414 | 1415 | loss = None 1416 | if labels is not None: 1417 | if self.config.problem_type is None: 1418 | if self.num_labels == 1: 1419 | self.config.problem_type = "regression" 1420 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 1421 | self.config.problem_type = "single_label_classification" 1422 | else: 1423 | self.config.problem_type = "multi_label_classification" 1424 | 1425 | if self.config.problem_type == "regression": 1426 | loss_fct = MSELoss() 1427 | if self.num_labels == 1: 1428 | loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) 1429 | else: 1430 | loss = loss_fct(pooled_logits, labels) 1431 | elif self.config.problem_type == "single_label_classification": 1432 | loss_fct = CrossEntropyLoss() 1433 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) 1434 | elif self.config.problem_type == "multi_label_classification": 1435 | loss_fct = BCEWithLogitsLoss() 1436 | loss = loss_fct(pooled_logits, labels) 1437 | if not return_dict: 1438 | output = (pooled_logits,) + transformer_outputs[1:] 1439 | return ((loss,) + output) if loss is not None else output 1440 | 1441 | return SequenceClassifierOutputWithPast( 1442 | loss=loss, 1443 | logits=pooled_logits, 1444 | past_key_values=transformer_outputs.past_key_values, 1445 | hidden_states=transformer_outputs.hidden_states, 1446 | attentions=transformer_outputs.attentions, 1447 | ) 1448 | 1449 | 1450 | @add_start_docstrings( 1451 | """ 1452 | GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for 1453 | Named-Entity-Recognition (NER) tasks. 
1454 | """, 1455 | GPT2_START_DOCSTRING, 1456 | ) 1457 | class GPT2ForTokenClassification(GPT2PreTrainedModel): 1458 | def __init__(self, config): 1459 | super().__init__(config) 1460 | self.num_labels = config.num_labels 1461 | 1462 | self.transformer = GPT2Model(config) 1463 | if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: 1464 | classifier_dropout = config.classifier_dropout 1465 | elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: 1466 | classifier_dropout = config.hidden_dropout 1467 | else: 1468 | classifier_dropout = 0.1 1469 | self.dropout = nn.Dropout(classifier_dropout) 1470 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 1471 | 1472 | # Model parallel 1473 | self.model_parallel = False 1474 | self.device_map = None 1475 | 1476 | # Initialize weights and apply final processing 1477 | self.post_init() 1478 | 1479 | @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) 1480 | @add_code_sample_docstrings( 1481 | processor_class=_TOKENIZER_FOR_DOC, 1482 | checkpoint="microsoft/DialogRPT-updown", 1483 | output_type=TokenClassifierOutput, 1484 | config_class=_CONFIG_FOR_DOC, 1485 | ) 1486 | def forward( 1487 | self, 1488 | input_ids=None, 1489 | past_key_values=None, 1490 | attention_mask=None, 1491 | token_type_ids=None, 1492 | position_ids=None, 1493 | head_mask=None, 1494 | inputs_embeds=None, 1495 | labels=None, 1496 | use_cache=None, 1497 | output_attentions=None, 1498 | output_hidden_states=None, 1499 | return_dict=None, 1500 | ): 1501 | r""" 1502 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1503 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 1504 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 1505 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
1506 | """ 1507 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1508 | 1509 | transformer_outputs = self.transformer( 1510 | input_ids, 1511 | past_key_values=past_key_values, 1512 | attention_mask=attention_mask, 1513 | token_type_ids=token_type_ids, 1514 | position_ids=position_ids, 1515 | head_mask=head_mask, 1516 | inputs_embeds=inputs_embeds, 1517 | use_cache=use_cache, 1518 | output_attentions=output_attentions, 1519 | output_hidden_states=output_hidden_states, 1520 | return_dict=return_dict, 1521 | ) 1522 | 1523 | hidden_states = transformer_outputs[0] 1524 | hidden_states = self.dropout(hidden_states) 1525 | logits = self.classifier(hidden_states) 1526 | 1527 | loss = None 1528 | if labels is not None: 1529 | loss_fct = CrossEntropyLoss() 1530 | # Only keep active parts of the loss 1531 | if attention_mask is not None: 1532 | active_loss = attention_mask.view(-1) == 1 1533 | active_logits = logits.view(-1, self.num_labels) 1534 | active_labels = torch.where( 1535 | active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) 1536 | ) 1537 | loss = loss_fct(active_logits, active_labels) 1538 | else: 1539 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1540 | 1541 | if not return_dict: 1542 | output = (logits,) + transformer_outputs[2:] 1543 | return ((loss,) + output) if loss is not None else output 1544 | 1545 | return TokenClassifierOutput( 1546 | loss=loss, 1547 | logits=logits, 1548 | hidden_states=transformer_outputs.hidden_states, 1549 | attentions=transformer_outputs.attentions, 1550 | ) -------------------------------------------------------------------------------- /mutransformers/models/gpt2/configuration_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # Copyright 2022 Microsoft Corporation. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
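# A minimal usage sketch of the muP-specific `attn_mult` argument accepted by the
# `GPT2Config` class defined below; the widths here are hypothetical and purely illustrative:
#
#     config = GPT2Config(n_embd=256, n_head=4)                 # attn_mult defaults to (256 / 4) ** 0.5 == 8.0
#     config = GPT2Config(n_embd=256, n_head=4, attn_mult=8.0)  # equivalent, set explicitly
#
# With `attn_mult=None`, the configuration falls back to the standard 1/sqrt(d_head) attention
# scaling; presumably a muP user would instead pass a width-independent value so that the
# attention logits scale like 1/d_head as the model is widened.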
17 | """ OpenAI GPT-2 configuration""" 18 | from collections import OrderedDict 19 | from typing import Any, List, Mapping, Optional 20 | 21 | from transformers import PreTrainedTokenizer, TensorType, is_torch_available 22 | 23 | from transformers.configuration_utils import PretrainedConfig 24 | from transformers.onnx import OnnxConfigWithPast, PatchingSpec 25 | from transformers.utils import logging 26 | 27 | 28 | logger = logging.get_logger(__name__) 29 | 30 | GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = { 31 | "gpt2": "https://huggingface.co/gpt2/resolve/main/config.json", 32 | "gpt2-medium": "https://huggingface.co/gpt2-medium/resolve/main/config.json", 33 | "gpt2-large": "https://huggingface.co/gpt2-large/resolve/main/config.json", 34 | "gpt2-xl": "https://huggingface.co/gpt2-xl/resolve/main/config.json", 35 | "distilgpt2": "https://huggingface.co/distilgpt2/resolve/main/config.json", 36 | } 37 | 38 | 39 | class GPT2Config(PretrainedConfig): 40 | """ 41 | This is the configuration class to store the configuration of a [`GPT2Model`] or a [`TFGPT2Model`]. It is used to 42 | instantiate a GPT-2 model according to the specified arguments, defining the model architecture. Instantiating a 43 | configuration with the defaults will yield a similar configuration to that of the GPT-2 44 | [small](https://huggingface.co/gpt2) architecture. 45 | 46 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 47 | documentation from [`PretrainedConfig`] for more information. 48 | 49 | 50 | Args: 51 | vocab_size (`int`, *optional*, defaults to 50257): 52 | Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the 53 | `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`]. 54 | n_positions (`int`, *optional*, defaults to 1024): 55 | The maximum sequence length that this model might ever be used with. Typically set this to something large 56 | just in case (e.g., 512 or 1024 or 2048). 57 | n_embd (`int`, *optional*, defaults to 768): 58 | Dimensionality of the embeddings and hidden states. 59 | n_layer (`int`, *optional*, defaults to 12): 60 | Number of hidden layers in the Transformer encoder. 61 | n_head (`int`, *optional*, defaults to 12): 62 | Number of attention heads for each attention layer in the Transformer encoder. 63 | n_inner (`int`, *optional*, defaults to None): 64 | Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd 65 | activation_function (`str`, *optional*, defaults to `"gelu"`): 66 | Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`. 67 | resid_pdrop (`float`, *optional*, defaults to 0.1): 68 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 69 | embd_pdrop (`int`, *optional*, defaults to 0.1): 70 | The dropout ratio for the embeddings. 71 | attn_pdrop (`float`, *optional*, defaults to 0.1): 72 | The dropout ratio for the attention. 73 | layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): 74 | The epsilon to use in the layer normalization layers. 75 | initializer_range (`float`, *optional*, defaults to 0.02): 76 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 77 | summary_type (`string`, *optional*, defaults to `"cls_index"`): 78 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and 79 | [`TFGPT2DoubleHeadsModel`]. 
80 | 81 | Has to be one of the following options: 82 | 83 | - `"last"`: Take the last token hidden state (like XLNet). 84 | - `"first"`: Take the first token hidden state (like BERT). 85 | - `"mean"`: Take the mean of all tokens hidden states. 86 | - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2). 87 | - `"attn"`: Not implemented now, use multi-head attention. 88 | summary_use_proj (`bool`, *optional*, defaults to `True`): 89 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and 90 | [`TFGPT2DoubleHeadsModel`]. 91 | 92 | Whether or not to add a projection after the vector extraction. 93 | summary_activation (`str`, *optional*): 94 | Argument used when doing sequence summary. Used for the multiple choice head in 95 | [`GPT2DoubleHeadsModel`]. 96 | 97 | Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation. 98 | summary_proj_to_labels (`bool`, *optional*, defaults to `True`): 99 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and 100 | [`TFGPT2DoubleHeadsModel`]. 101 | 102 | Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes. 103 | summary_first_dropout (`float`, *optional*, defaults to 0.1): 104 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and 105 | [`TFGPT2DoubleHeadsModel`]. 106 | 107 | The dropout ratio to be used after the projection and activation. 108 | scale_attn_weights (`bool`, *optional*, defaults to `True`): 109 | Scale attention weights by dividing by sqrt(hidden_size // num_attention_heads), the per-head dimension. 110 | use_cache (`bool`, *optional*, defaults to `True`): 111 | Whether or not the model should return the last key/values attentions (not used by all models). 112 | scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`): 113 | Whether to additionally scale attention weights by `1 / (layer_idx + 1)`. 114 | reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): 115 | Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention 116 | dot-product/softmax to float() when training with mixed precision.
117 | 118 | Example: 119 | 120 | ```python 121 | >>> from transformers import GPT2Model, GPT2Config 122 | 123 | >>> # Initializing a GPT2 configuration 124 | >>> configuration = GPT2Config() 125 | 126 | >>> # Initializing a model from the configuration 127 | >>> model = GPT2Model(configuration) 128 | 129 | >>> # Accessing the model configuration 130 | >>> configuration = model.config 131 | ```""" 132 | 133 | model_type = "gpt2" 134 | keys_to_ignore_at_inference = ["past_key_values"] 135 | attribute_map = { 136 | "hidden_size": "n_embd", 137 | "max_position_embeddings": "n_positions", 138 | "num_attention_heads": "n_head", 139 | "num_hidden_layers": "n_layer", 140 | } 141 | 142 | def __init__( 143 | self, 144 | vocab_size=50257, 145 | n_positions=1024, 146 | n_embd=768, 147 | n_layer=12, 148 | n_head=12, 149 | n_inner=None, 150 | activation_function="gelu_new", 151 | resid_pdrop=0.1, 152 | embd_pdrop=0.1, 153 | attn_pdrop=0.1, 154 | layer_norm_epsilon=1e-5, 155 | initializer_range=0.02, 156 | summary_type="cls_index", 157 | summary_use_proj=True, 158 | summary_activation=None, 159 | summary_proj_to_labels=True, 160 | summary_first_dropout=0.1, 161 | scale_attn_weights=True, 162 | use_cache=True, 163 | bos_token_id=50256, 164 | eos_token_id=50256, 165 | scale_attn_by_inverse_layer_idx=False, 166 | reorder_and_upcast_attn=False, 167 | ### muP 168 | attn_mult=None, 169 | **kwargs, 170 | ): 171 | self.vocab_size = vocab_size 172 | self.n_positions = n_positions 173 | self.n_embd = n_embd 174 | self.n_layer = n_layer 175 | self.n_head = n_head 176 | self.n_inner = n_inner 177 | self.activation_function = activation_function 178 | self.resid_pdrop = resid_pdrop 179 | self.embd_pdrop = embd_pdrop 180 | self.attn_pdrop = attn_pdrop 181 | self.layer_norm_epsilon = layer_norm_epsilon 182 | self.initializer_range = initializer_range 183 | self.summary_type = summary_type 184 | self.summary_use_proj = summary_use_proj 185 | self.summary_activation = summary_activation 186 | self.summary_first_dropout = summary_first_dropout 187 | self.summary_proj_to_labels = summary_proj_to_labels 188 | self.scale_attn_weights = scale_attn_weights 189 | self.use_cache = use_cache 190 | self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx 191 | self.reorder_and_upcast_attn = reorder_and_upcast_attn 192 | 193 | self.bos_token_id = bos_token_id 194 | self.eos_token_id = eos_token_id 195 | 196 | ### muP 197 | if attn_mult is None: 198 | # defaults back to 1/sqrt(d) attn 199 | self.attn_mult = (self.hidden_size / self.num_attention_heads)**0.5 200 | else: 201 | self.attn_mult = attn_mult 202 | 203 | super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) 204 | 205 | 206 | class GPT2OnnxConfig(OnnxConfigWithPast): 207 | def __init__( 208 | self, 209 | config: PretrainedConfig, 210 | task: str = "default", 211 | patching_specs: List[PatchingSpec] = None, 212 | use_past: bool = False, 213 | ): 214 | super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past) 215 | if not getattr(self._config, "pad_token_id", None): 216 | # TODO: how to do that better? 
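# --- Editor's note (sketch, not part of the original file) ----------------------
# The `### muP` branch in `GPT2Config.__init__` above introduces the only new
# hyperparameter in this config: with the stock values n_embd=768 and n_head=12,
# attn_mult defaults to sqrt(768 / 12) = 8.0, which (per the in-code comment)
# recovers the usual 1/sqrt(d_head) attention scaling. A minimal, hedged usage
# sketch; the import path follows this repo's layout and the numbers are
# illustrative only:
#
#   from mutransformers.models.gpt2.configuration_gpt2 import GPT2Config
#
#   base_cfg = GPT2Config(n_embd=768, n_head=12)                  # attn_mult -> 8.0
#   wide_cfg = GPT2Config(n_embd=1536, n_head=12, attn_mult=8.0)  # attn_mult held at the base value
#
# attn_mult is presumably exposed as a config field so it can be held fixed at the
# base model's value while the width is scaled, rather than growing with
# sqrt(d_head) as in standard parametrization.
# ---------------------------------------------------------------------------------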
217 | self._config.pad_token_id = 0 218 | 219 | @property 220 | def inputs(self) -> Mapping[str, Mapping[int, str]]: 221 | common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}}) 222 | if self.use_past: 223 | self.fill_with_past_key_values_(common_inputs, direction="inputs") 224 | common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"} 225 | else: 226 | common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} 227 | 228 | return common_inputs 229 | 230 | @property 231 | def num_layers(self) -> int: 232 | return self._config.n_layer 233 | 234 | @property 235 | def num_attention_heads(self) -> int: 236 | return self._config.n_head 237 | 238 | def generate_dummy_inputs( 239 | self, 240 | tokenizer: PreTrainedTokenizer, 241 | batch_size: int = -1, 242 | seq_length: int = -1, 243 | is_pair: bool = False, 244 | framework: Optional[TensorType] = None, 245 | ) -> Mapping[str, Any]: 246 | common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( 247 | tokenizer, batch_size, seq_length, is_pair, framework 248 | ) 249 | 250 | # We need to order the input in the way they appears in the forward() 251 | ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]}) 252 | 253 | # Need to add the past_keys 254 | if self.use_past: 255 | if not is_torch_available(): 256 | raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") 257 | else: 258 | import torch 259 | 260 | batch, seqlen = common_inputs["input_ids"].shape 261 | # Not using the same length for past_key_values 262 | past_key_values_length = seqlen + 2 263 | past_shape = ( 264 | batch, 265 | self.num_attention_heads, 266 | past_key_values_length, 267 | self._config.hidden_size // self.num_attention_heads, 268 | ) 269 | ordered_inputs["past_key_values"] = [ 270 | (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers) 271 | ] 272 | 273 | ordered_inputs["attention_mask"] = common_inputs["attention_mask"] 274 | if self.use_past: 275 | ordered_inputs["attention_mask"] = torch.cat( 276 | [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1 277 | ) 278 | 279 | return ordered_inputs 280 | 281 | @property 282 | def default_onnx_opset(self) -> int: 283 | return 13 -------------------------------------------------------------------------------- /mutransformers/models/roberta/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mutransformers/480287ce7b18a07a3432e8f2fbc0f0e5b71e2599/mutransformers/models/roberta/__init__.py -------------------------------------------------------------------------------- /mutransformers/models/roberta/_original_configuration_roberta.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ RoBERTa configuration""" 17 | from collections import OrderedDict 18 | from typing import Mapping 19 | 20 | from ...onnx import OnnxConfig 21 | from ...utils import logging 22 | from ..bert.configuration_bert import BertConfig 23 | 24 | 25 | logger = logging.get_logger(__name__) 26 | 27 | ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = { 28 | "roberta-base": "https://huggingface.co/roberta-base/resolve/main/config.json", 29 | "roberta-large": "https://huggingface.co/roberta-large/resolve/main/config.json", 30 | "roberta-large-mnli": "https://huggingface.co/roberta-large-mnli/resolve/main/config.json", 31 | "distilroberta-base": "https://huggingface.co/distilroberta-base/resolve/main/config.json", 32 | "roberta-base-openai-detector": "https://huggingface.co/roberta-base-openai-detector/resolve/main/config.json", 33 | "roberta-large-openai-detector": "https://huggingface.co/roberta-large-openai-detector/resolve/main/config.json", 34 | } 35 | 36 | 37 | class RobertaConfig(BertConfig): 38 | r""" 39 | This is the configuration class to store the configuration of a [`RobertaModel`] or a [`TFRobertaModel`]. It is 40 | used to instantiate a RoBERTa model according to the specified arguments, defining the model architecture. 41 | 42 | 43 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 44 | documentation from [`PretrainedConfig`] for more information. 45 | 46 | The [`RobertaConfig`] class directly inherits [`BertConfig`]. It reuses the same defaults. Please check the parent 47 | class for more information. 48 | 49 | Examples: 50 | 51 | ```python 52 | >>> from transformers import RobertaConfig, RobertaModel 53 | 54 | >>> # Initializing a RoBERTa configuration 55 | >>> configuration = RobertaConfig() 56 | 57 | >>> # Initializing a model from the configuration 58 | >>> model = RobertaModel(configuration) 59 | 60 | >>> # Accessing the model configuration 61 | >>> configuration = model.config 62 | ```""" 63 | model_type = "roberta" 64 | 65 | def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, **kwargs): 66 | """Constructs RobertaConfig.""" 67 | super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) 68 | 69 | 70 | class RobertaOnnxConfig(OnnxConfig): 71 | @property 72 | def inputs(self) -> Mapping[str, Mapping[int, str]]: 73 | return OrderedDict( 74 | [ 75 | ("input_ids", {0: "batch", 1: "sequence"}), 76 | ("attention_mask", {0: "batch", 1: "sequence"}), 77 | ] 78 | ) -------------------------------------------------------------------------------- /mutransformers/models/roberta/_original_modeling_roberta.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """PyTorch RoBERTa model.""" 17 | 18 | import math 19 | 20 | import torch 21 | import torch.utils.checkpoint 22 | from packaging import version 23 | from torch import nn 24 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 25 | 26 | from ...activations import ACT2FN, gelu 27 | from ...file_utils import ( 28 | add_code_sample_docstrings, 29 | add_start_docstrings, 30 | add_start_docstrings_to_model_forward, 31 | replace_return_docstrings, 32 | ) 33 | from ...modeling_outputs import ( 34 | BaseModelOutputWithPastAndCrossAttentions, 35 | BaseModelOutputWithPoolingAndCrossAttentions, 36 | CausalLMOutputWithCrossAttentions, 37 | MaskedLMOutput, 38 | MultipleChoiceModelOutput, 39 | QuestionAnsweringModelOutput, 40 | SequenceClassifierOutput, 41 | TokenClassifierOutput, 42 | ) 43 | from ...modeling_utils import ( 44 | PreTrainedModel, 45 | apply_chunking_to_forward, 46 | find_pruneable_heads_and_indices, 47 | prune_linear_layer, 48 | ) 49 | from ...utils import logging 50 | from .configuration_roberta import RobertaConfig 51 | 52 | 53 | logger = logging.get_logger(__name__) 54 | 55 | _CHECKPOINT_FOR_DOC = "roberta-base" 56 | _CONFIG_FOR_DOC = "RobertaConfig" 57 | _TOKENIZER_FOR_DOC = "RobertaTokenizer" 58 | 59 | ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [ 60 | "roberta-base", 61 | "roberta-large", 62 | "roberta-large-mnli", 63 | "distilroberta-base", 64 | "roberta-base-openai-detector", 65 | "roberta-large-openai-detector", 66 | # See all RoBERTa models at https://huggingface.co/models?filter=roberta 67 | ] 68 | 69 | 70 | class RobertaEmbeddings(nn.Module): 71 | """ 72 | Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. 73 | """ 74 | 75 | # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__ 76 | def __init__(self, config): 77 | super().__init__() 78 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) 79 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 80 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 81 | 82 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 83 | # any TensorFlow checkpoint file 84 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 85 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 86 | # position_ids (1, len position emb) is contiguous in memory and exported when serialized 87 | self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") 88 | self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) 89 | if version.parse(torch.__version__) > version.parse("1.6.0"): 90 | self.register_buffer( 91 | "token_type_ids", 92 | torch.zeros(self.position_ids.size(), dtype=torch.long), 93 | persistent=False, 94 | ) 95 | 96 | # End copy 97 | self.padding_idx = config.pad_token_id 98 | self.position_embeddings = nn.Embedding( 99 | config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx 100 | ) 101 | 102 | def forward( 103 | self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 104 | ): 105 | if position_ids is None: 106 | if input_ids is not None: 107 | # Create the position ids from the input token ids. Any padded tokens remain padded. 
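# --- Editor's note (illustration only, not part of the original file) ------------
# RoBERTa offsets every position id by padding_idx (pad_token_id defaults to 1 in
# RobertaConfig), so real tokens get positions 2, 3, 4, ... and padding keeps
# position padding_idx. A hedged sketch of what the helper called on the next line
# is expected to produce (token ids are made up for illustration):
#
#   input_ids    = [[0, 31414, 232, 2, 1, 1]]   # <s> ... </s> <pad> <pad>
#   position_ids = [[2,     3,   4, 5, 1, 1]]   # pads stay at padding_idx
# ----------------------------------------------------------------------------------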
108 | position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) 109 | else: 110 | position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) 111 | 112 | if input_ids is not None: 113 | input_shape = input_ids.size() 114 | else: 115 | input_shape = inputs_embeds.size()[:-1] 116 | 117 | seq_length = input_shape[1] 118 | 119 | # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs 120 | # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves 121 | # issue #5664 122 | if token_type_ids is None: 123 | if hasattr(self, "token_type_ids"): 124 | buffered_token_type_ids = self.token_type_ids[:, :seq_length] 125 | buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) 126 | token_type_ids = buffered_token_type_ids_expanded 127 | else: 128 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) 129 | 130 | if inputs_embeds is None: 131 | inputs_embeds = self.word_embeddings(input_ids) 132 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 133 | 134 | embeddings = inputs_embeds + token_type_embeddings 135 | if self.position_embedding_type == "absolute": 136 | position_embeddings = self.position_embeddings(position_ids) 137 | embeddings += position_embeddings 138 | embeddings = self.LayerNorm(embeddings) 139 | embeddings = self.dropout(embeddings) 140 | return embeddings 141 | 142 | def create_position_ids_from_inputs_embeds(self, inputs_embeds): 143 | """ 144 | We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. 145 | 146 | Args: 147 | inputs_embeds: torch.Tensor 148 | 149 | Returns: torch.Tensor 150 | """ 151 | input_shape = inputs_embeds.size()[:-1] 152 | sequence_length = input_shape[1] 153 | 154 | position_ids = torch.arange( 155 | self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device 156 | ) 157 | return position_ids.unsqueeze(0).expand(input_shape) 158 | 159 | 160 | # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Roberta 161 | class RobertaSelfAttention(nn.Module): 162 | def __init__(self, config, position_embedding_type=None): 163 | super().__init__() 164 | if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): 165 | raise ValueError( 166 | f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " 167 | f"heads ({config.num_attention_heads})" 168 | ) 169 | 170 | self.num_attention_heads = config.num_attention_heads 171 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 172 | self.all_head_size = self.num_attention_heads * self.attention_head_size 173 | 174 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 175 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 176 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 177 | 178 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 179 | self.position_embedding_type = position_embedding_type or getattr( 180 | config, "position_embedding_type", "absolute" 181 | ) 182 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": 183 | self.max_position_embeddings = config.max_position_embeddings 184 | self.distance_embedding = 
nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) 185 | 186 | self.is_decoder = config.is_decoder 187 | 188 | def transpose_for_scores(self, x): 189 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 190 | x = x.view(*new_x_shape) 191 | return x.permute(0, 2, 1, 3) 192 | 193 | def forward( 194 | self, 195 | hidden_states, 196 | attention_mask=None, 197 | head_mask=None, 198 | encoder_hidden_states=None, 199 | encoder_attention_mask=None, 200 | past_key_value=None, 201 | output_attentions=False, 202 | ): 203 | mixed_query_layer = self.query(hidden_states) 204 | 205 | # If this is instantiated as a cross-attention module, the keys 206 | # and values come from an encoder; the attention mask needs to be 207 | # such that the encoder's padding tokens are not attended to. 208 | is_cross_attention = encoder_hidden_states is not None 209 | 210 | if is_cross_attention and past_key_value is not None: 211 | # reuse k,v, cross_attentions 212 | key_layer = past_key_value[0] 213 | value_layer = past_key_value[1] 214 | attention_mask = encoder_attention_mask 215 | elif is_cross_attention: 216 | key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) 217 | value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) 218 | attention_mask = encoder_attention_mask 219 | elif past_key_value is not None: 220 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 221 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 222 | key_layer = torch.cat([past_key_value[0], key_layer], dim=2) 223 | value_layer = torch.cat([past_key_value[1], value_layer], dim=2) 224 | else: 225 | key_layer = self.transpose_for_scores(self.key(hidden_states)) 226 | value_layer = self.transpose_for_scores(self.value(hidden_states)) 227 | 228 | query_layer = self.transpose_for_scores(mixed_query_layer) 229 | 230 | if self.is_decoder: 231 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 232 | # Further calls to cross_attention layer can then reuse all cross-attention 233 | # key/value_states (first "if" case) 234 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 235 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 236 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 237 | # if encoder bi-directional self-attention `past_key_value` is always `None` 238 | past_key_value = (key_layer, value_layer) 239 | 240 | # Take the dot product between "query" and "key" to get the raw attention scores. 
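# --- Editor's note (shape sketch, not part of the original file) ------------------
# Shapes in the dot-product below, writing B = batch, H = num_attention_heads,
# S = query length, T = key length (equal to S unless past key/values or
# cross-attention are used), and D = attention_head_size:
#
#   query_layer : (B, H, S, D)
#   key_layer   : (B, H, T, D)
#   scores  = query_layer @ key_layer.transpose(-1, -2)   # (B, H, S, T)
#   scores  = scores / sqrt(D)                             # standard 1/sqrt(d_head) scaling (applied below)
#   probs   = softmax(scores + attention_mask, dim=-1)     # additive mask, then normalize
#   context = probs @ value_layer                          # (B, H, S, D)
#
# This fixed 1/sqrt(d_head) factor is the part a muP attention parametrization
# would adjust; this `_original_` file keeps the stock behaviour for comparison.
# -----------------------------------------------------------------------------------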
241 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 242 | 243 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": 244 | seq_length = hidden_states.size()[1] 245 | position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) 246 | position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) 247 | distance = position_ids_l - position_ids_r 248 | positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) 249 | positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility 250 | 251 | if self.position_embedding_type == "relative_key": 252 | relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 253 | attention_scores = attention_scores + relative_position_scores 254 | elif self.position_embedding_type == "relative_key_query": 255 | relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 256 | relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) 257 | attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key 258 | 259 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 260 | if attention_mask is not None: 261 | # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function) 262 | attention_scores = attention_scores + attention_mask 263 | 264 | # Normalize the attention scores to probabilities. 265 | attention_probs = nn.functional.softmax(attention_scores, dim=-1) 266 | 267 | # This is actually dropping out entire tokens to attend to, which might 268 | # seem a bit unusual, but is taken from the original Transformer paper. 
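# --- Editor's note (not part of the original file) --------------------------------
# "Dropping entire tokens" means nn.Dropout zeroes individual entries of the
# (B, H, S, T) probability tensor (and rescales the survivors by 1/(1 - p)), so
# during training a given query position simply ignores some key positions in the
# weighted sum over values computed below.
# -----------------------------------------------------------------------------------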
269 | attention_probs = self.dropout(attention_probs) 270 | 271 | # Mask heads if we want to 272 | if head_mask is not None: 273 | attention_probs = attention_probs * head_mask 274 | 275 | context_layer = torch.matmul(attention_probs, value_layer) 276 | 277 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 278 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 279 | context_layer = context_layer.view(*new_context_layer_shape) 280 | 281 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) 282 | 283 | if self.is_decoder: 284 | outputs = outputs + (past_key_value,) 285 | return outputs 286 | 287 | 288 | # Copied from transformers.models.bert.modeling_bert.BertSelfOutput 289 | class RobertaSelfOutput(nn.Module): 290 | def __init__(self, config): 291 | super().__init__() 292 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 293 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 294 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 295 | 296 | def forward(self, hidden_states, input_tensor): 297 | hidden_states = self.dense(hidden_states) 298 | hidden_states = self.dropout(hidden_states) 299 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 300 | return hidden_states 301 | 302 | 303 | # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta 304 | class RobertaAttention(nn.Module): 305 | def __init__(self, config, position_embedding_type=None): 306 | super().__init__() 307 | self.self = RobertaSelfAttention(config, position_embedding_type=position_embedding_type) 308 | self.output = RobertaSelfOutput(config) 309 | self.pruned_heads = set() 310 | 311 | def prune_heads(self, heads): 312 | if len(heads) == 0: 313 | return 314 | heads, index = find_pruneable_heads_and_indices( 315 | heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads 316 | ) 317 | 318 | # Prune linear layers 319 | self.self.query = prune_linear_layer(self.self.query, index) 320 | self.self.key = prune_linear_layer(self.self.key, index) 321 | self.self.value = prune_linear_layer(self.self.value, index) 322 | self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) 323 | 324 | # Update hyper params and store pruned heads 325 | self.self.num_attention_heads = self.self.num_attention_heads - len(heads) 326 | self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads 327 | self.pruned_heads = self.pruned_heads.union(heads) 328 | 329 | def forward( 330 | self, 331 | hidden_states, 332 | attention_mask=None, 333 | head_mask=None, 334 | encoder_hidden_states=None, 335 | encoder_attention_mask=None, 336 | past_key_value=None, 337 | output_attentions=False, 338 | ): 339 | self_outputs = self.self( 340 | hidden_states, 341 | attention_mask, 342 | head_mask, 343 | encoder_hidden_states, 344 | encoder_attention_mask, 345 | past_key_value, 346 | output_attentions, 347 | ) 348 | attention_output = self.output(self_outputs[0], hidden_states) 349 | outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them 350 | return outputs 351 | 352 | 353 | # Copied from transformers.models.bert.modeling_bert.BertIntermediate 354 | class RobertaIntermediate(nn.Module): 355 | def __init__(self, config): 356 | super().__init__() 357 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 358 | if isinstance(config.hidden_act, str): 359 | self.intermediate_act_fn = 
ACT2FN[config.hidden_act] 360 | else: 361 | self.intermediate_act_fn = config.hidden_act 362 | 363 | def forward(self, hidden_states): 364 | hidden_states = self.dense(hidden_states) 365 | hidden_states = self.intermediate_act_fn(hidden_states) 366 | return hidden_states 367 | 368 | 369 | # Copied from transformers.models.bert.modeling_bert.BertOutput 370 | class RobertaOutput(nn.Module): 371 | def __init__(self, config): 372 | super().__init__() 373 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 374 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 375 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 376 | 377 | def forward(self, hidden_states, input_tensor): 378 | hidden_states = self.dense(hidden_states) 379 | hidden_states = self.dropout(hidden_states) 380 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 381 | return hidden_states 382 | 383 | 384 | # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Roberta 385 | class RobertaLayer(nn.Module): 386 | def __init__(self, config): 387 | super().__init__() 388 | self.chunk_size_feed_forward = config.chunk_size_feed_forward 389 | self.seq_len_dim = 1 390 | self.attention = RobertaAttention(config) 391 | self.is_decoder = config.is_decoder 392 | self.add_cross_attention = config.add_cross_attention 393 | if self.add_cross_attention: 394 | if not self.is_decoder: 395 | raise ValueError(f"{self} should be used as a decoder model if cross attention is added") 396 | self.crossattention = RobertaAttention(config, position_embedding_type="absolute") 397 | self.intermediate = RobertaIntermediate(config) 398 | self.output = RobertaOutput(config) 399 | 400 | def forward( 401 | self, 402 | hidden_states, 403 | attention_mask=None, 404 | head_mask=None, 405 | encoder_hidden_states=None, 406 | encoder_attention_mask=None, 407 | past_key_value=None, 408 | output_attentions=False, 409 | ): 410 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 411 | self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None 412 | self_attention_outputs = self.attention( 413 | hidden_states, 414 | attention_mask, 415 | head_mask, 416 | output_attentions=output_attentions, 417 | past_key_value=self_attn_past_key_value, 418 | ) 419 | attention_output = self_attention_outputs[0] 420 | 421 | # if decoder, the last output is tuple of self-attn cache 422 | if self.is_decoder: 423 | outputs = self_attention_outputs[1:-1] 424 | present_key_value = self_attention_outputs[-1] 425 | else: 426 | outputs = self_attention_outputs[1:] # add self attentions if we output attention weights 427 | 428 | cross_attn_present_key_value = None 429 | if self.is_decoder and encoder_hidden_states is not None: 430 | if not hasattr(self, "crossattention"): 431 | raise ValueError( 432 | f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" 433 | ) 434 | 435 | # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple 436 | cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None 437 | cross_attention_outputs = self.crossattention( 438 | attention_output, 439 | attention_mask, 440 | head_mask, 441 | encoder_hidden_states, 442 | encoder_attention_mask, 443 | cross_attn_past_key_value, 444 | output_attentions, 445 | ) 446 | attention_output = cross_attention_outputs[0] 447 | outputs = outputs + 
cross_attention_outputs[1:-1] # add cross attentions if we output attention weights 448 | 449 | # add cross-attn cache to positions 3,4 of present_key_value tuple 450 | cross_attn_present_key_value = cross_attention_outputs[-1] 451 | present_key_value = present_key_value + cross_attn_present_key_value 452 | 453 | layer_output = apply_chunking_to_forward( 454 | self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output 455 | ) 456 | outputs = (layer_output,) + outputs 457 | 458 | # if decoder, return the attn key/values as the last output 459 | if self.is_decoder: 460 | outputs = outputs + (present_key_value,) 461 | 462 | return outputs 463 | 464 | def feed_forward_chunk(self, attention_output): 465 | intermediate_output = self.intermediate(attention_output) 466 | layer_output = self.output(intermediate_output, attention_output) 467 | return layer_output 468 | 469 | 470 | # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Roberta 471 | class RobertaEncoder(nn.Module): 472 | def __init__(self, config): 473 | super().__init__() 474 | self.config = config 475 | self.layer = nn.ModuleList([RobertaLayer(config) for _ in range(config.num_hidden_layers)]) 476 | self.gradient_checkpointing = False 477 | 478 | def forward( 479 | self, 480 | hidden_states, 481 | attention_mask=None, 482 | head_mask=None, 483 | encoder_hidden_states=None, 484 | encoder_attention_mask=None, 485 | past_key_values=None, 486 | use_cache=None, 487 | output_attentions=False, 488 | output_hidden_states=False, 489 | return_dict=True, 490 | ): 491 | all_hidden_states = () if output_hidden_states else None 492 | all_self_attentions = () if output_attentions else None 493 | all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None 494 | 495 | next_decoder_cache = () if use_cache else None 496 | for i, layer_module in enumerate(self.layer): 497 | if output_hidden_states: 498 | all_hidden_states = all_hidden_states + (hidden_states,) 499 | 500 | layer_head_mask = head_mask[i] if head_mask is not None else None 501 | past_key_value = past_key_values[i] if past_key_values is not None else None 502 | 503 | if self.gradient_checkpointing and self.training: 504 | 505 | if use_cache: 506 | logger.warning( 507 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
508 | ) 509 | use_cache = False 510 | 511 | def create_custom_forward(module): 512 | def custom_forward(*inputs): 513 | return module(*inputs, past_key_value, output_attentions) 514 | 515 | return custom_forward 516 | 517 | layer_outputs = torch.utils.checkpoint.checkpoint( 518 | create_custom_forward(layer_module), 519 | hidden_states, 520 | attention_mask, 521 | layer_head_mask, 522 | encoder_hidden_states, 523 | encoder_attention_mask, 524 | ) 525 | else: 526 | layer_outputs = layer_module( 527 | hidden_states, 528 | attention_mask, 529 | layer_head_mask, 530 | encoder_hidden_states, 531 | encoder_attention_mask, 532 | past_key_value, 533 | output_attentions, 534 | ) 535 | 536 | hidden_states = layer_outputs[0] 537 | if use_cache: 538 | next_decoder_cache += (layer_outputs[-1],) 539 | if output_attentions: 540 | all_self_attentions = all_self_attentions + (layer_outputs[1],) 541 | if self.config.add_cross_attention: 542 | all_cross_attentions = all_cross_attentions + (layer_outputs[2],) 543 | 544 | if output_hidden_states: 545 | all_hidden_states = all_hidden_states + (hidden_states,) 546 | 547 | if not return_dict: 548 | return tuple( 549 | v 550 | for v in [ 551 | hidden_states, 552 | next_decoder_cache, 553 | all_hidden_states, 554 | all_self_attentions, 555 | all_cross_attentions, 556 | ] 557 | if v is not None 558 | ) 559 | return BaseModelOutputWithPastAndCrossAttentions( 560 | last_hidden_state=hidden_states, 561 | past_key_values=next_decoder_cache, 562 | hidden_states=all_hidden_states, 563 | attentions=all_self_attentions, 564 | cross_attentions=all_cross_attentions, 565 | ) 566 | 567 | 568 | # Copied from transformers.models.bert.modeling_bert.BertPooler 569 | class RobertaPooler(nn.Module): 570 | def __init__(self, config): 571 | super().__init__() 572 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 573 | self.activation = nn.Tanh() 574 | 575 | def forward(self, hidden_states): 576 | # We "pool" the model by simply taking the hidden state corresponding 577 | # to the first token. 578 | first_token_tensor = hidden_states[:, 0] 579 | pooled_output = self.dense(first_token_tensor) 580 | pooled_output = self.activation(pooled_output) 581 | return pooled_output 582 | 583 | 584 | class RobertaPreTrainedModel(PreTrainedModel): 585 | """ 586 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 587 | models. 
588 | """ 589 | 590 | config_class = RobertaConfig 591 | base_model_prefix = "roberta" 592 | supports_gradient_checkpointing = True 593 | 594 | # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights 595 | def _init_weights(self, module): 596 | """Initialize the weights""" 597 | if isinstance(module, nn.Linear): 598 | # Slightly different from the TF version which uses truncated_normal for initialization 599 | # cf https://github.com/pytorch/pytorch/pull/5617 600 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 601 | if module.bias is not None: 602 | module.bias.data.zero_() 603 | elif isinstance(module, nn.Embedding): 604 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 605 | if module.padding_idx is not None: 606 | module.weight.data[module.padding_idx].zero_() 607 | elif isinstance(module, nn.LayerNorm): 608 | module.bias.data.zero_() 609 | module.weight.data.fill_(1.0) 610 | 611 | def _set_gradient_checkpointing(self, module, value=False): 612 | if isinstance(module, RobertaEncoder): 613 | module.gradient_checkpointing = value 614 | 615 | def update_keys_to_ignore(self, config, del_keys_to_ignore): 616 | """Remove some keys from ignore list""" 617 | if not config.tie_word_embeddings: 618 | # must make a new list, or the class variable gets modified! 619 | self._keys_to_ignore_on_save = [k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore] 620 | self._keys_to_ignore_on_load_missing = [ 621 | k for k in self._keys_to_ignore_on_load_missing if k not in del_keys_to_ignore 622 | ] 623 | 624 | 625 | ROBERTA_START_DOCSTRING = r""" 626 | 627 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 628 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 629 | etc.) 630 | 631 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 632 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 633 | and behavior. 634 | 635 | Parameters: 636 | config ([`RobertaConfig`]): Model configuration class with all the parameters of the 637 | model. Initializing with a config file does not load the weights associated with the model, only the 638 | configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 639 | """ 640 | 641 | ROBERTA_INPUTS_DOCSTRING = r""" 642 | Args: 643 | input_ids (`torch.LongTensor` of shape `({0})`): 644 | Indices of input sequence tokens in the vocabulary. 645 | 646 | Indices can be obtained using [`RobertaTokenizer`]. See [`PreTrainedTokenizer.encode`] and 647 | [`PreTrainedTokenizer.__call__`] for details. 648 | 649 | [What are input IDs?](../glossary#input-ids) 650 | attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): 651 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 652 | 653 | - 1 for tokens that are **not masked**, 654 | - 0 for tokens that are **masked**. 655 | 656 | [What are attention masks?](../glossary#attention-mask) 657 | token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): 658 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 659 | 1]`: 660 | 661 | - 0 corresponds to a *sentence A* token, 662 | - 1 corresponds to a *sentence B* token. 
663 | 664 | [What are token type IDs?](../glossary#token-type-ids) 665 | position_ids (`torch.LongTensor` of shape `({0})`, *optional*): 666 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 667 | config.max_position_embeddings - 1]`. 668 | 669 | [What are position IDs?](../glossary#position-ids) 670 | head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): 671 | Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: 672 | 673 | - 1 indicates the head is **not masked**, 674 | - 0 indicates the head is **masked**. 675 | 676 | inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): 677 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 678 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 679 | model's internal embedding lookup matrix. 680 | output_attentions (`bool`, *optional*): 681 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 682 | tensors for more detail. 683 | output_hidden_states (`bool`, *optional*): 684 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 685 | more detail. 686 | return_dict (`bool`, *optional*): 687 | Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. 688 | """ 689 | 690 | 691 | @add_start_docstrings( 692 | "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.", 693 | ROBERTA_START_DOCSTRING, 694 | ) 695 | class RobertaModel(RobertaPreTrainedModel): 696 | """ 697 | 698 | The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of 699 | cross-attention is added between the self-attention layers, following the architecture described in *Attention is 700 | all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz 701 | Kaiser and Illia Polosukhin. 702 | 703 | To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set 704 | to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and 705 | `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. 706 | 707 | .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762 708 | 709 | """ 710 | 711 | _keys_to_ignore_on_load_missing = [r"position_ids"] 712 | 713 | # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta 714 | def __init__(self, config, add_pooling_layer=True): 715 | super().__init__(config) 716 | self.config = config 717 | 718 | self.embeddings = RobertaEmbeddings(config) 719 | self.encoder = RobertaEncoder(config) 720 | 721 | self.pooler = RobertaPooler(config) if add_pooling_layer else None 722 | 723 | # Initialize weights and apply final processing 724 | self.post_init() 725 | 726 | def get_input_embeddings(self): 727 | return self.embeddings.word_embeddings 728 | 729 | def set_input_embeddings(self, value): 730 | self.embeddings.word_embeddings = value 731 | 732 | def _prune_heads(self, heads_to_prune): 733 | """ 734 | Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base 735 | class PreTrainedModel 736 | """ 737 | for layer, heads in heads_to_prune.items(): 738 | self.encoder.layer[layer].attention.prune_heads(heads) 739 | 740 | @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 741 | @add_code_sample_docstrings( 742 | processor_class=_TOKENIZER_FOR_DOC, 743 | checkpoint=_CHECKPOINT_FOR_DOC, 744 | output_type=BaseModelOutputWithPoolingAndCrossAttentions, 745 | config_class=_CONFIG_FOR_DOC, 746 | ) 747 | # Copied from transformers.models.bert.modeling_bert.BertModel.forward 748 | def forward( 749 | self, 750 | input_ids=None, 751 | attention_mask=None, 752 | token_type_ids=None, 753 | position_ids=None, 754 | head_mask=None, 755 | inputs_embeds=None, 756 | encoder_hidden_states=None, 757 | encoder_attention_mask=None, 758 | past_key_values=None, 759 | use_cache=None, 760 | output_attentions=None, 761 | output_hidden_states=None, 762 | return_dict=None, 763 | ): 764 | r""" 765 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 766 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if 767 | the model is configured as a decoder. 768 | encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): 769 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in 770 | the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: 771 | 772 | - 1 for tokens that are **not masked**, 773 | - 0 for tokens that are **masked**. 774 | past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 775 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 776 | 777 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that 778 | don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all 779 | `decoder_input_ids` of shape `(batch_size, sequence_length)`. 780 | use_cache (`bool`, *optional*): 781 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 782 | `past_key_values`). 
783 | """ 784 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 785 | output_hidden_states = ( 786 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 787 | ) 788 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 789 | 790 | if self.config.is_decoder: 791 | use_cache = use_cache if use_cache is not None else self.config.use_cache 792 | else: 793 | use_cache = False 794 | 795 | if input_ids is not None and inputs_embeds is not None: 796 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 797 | elif input_ids is not None: 798 | input_shape = input_ids.size() 799 | elif inputs_embeds is not None: 800 | input_shape = inputs_embeds.size()[:-1] 801 | else: 802 | raise ValueError("You have to specify either input_ids or inputs_embeds") 803 | 804 | batch_size, seq_length = input_shape 805 | device = input_ids.device if input_ids is not None else inputs_embeds.device 806 | 807 | # past_key_values_length 808 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 809 | 810 | if attention_mask is None: 811 | attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) 812 | 813 | if token_type_ids is None: 814 | if hasattr(self.embeddings, "token_type_ids"): 815 | buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] 816 | buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) 817 | token_type_ids = buffered_token_type_ids_expanded 818 | else: 819 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) 820 | 821 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 822 | # ourselves in which case we just need to make it broadcastable to all heads. 
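# --- Editor's note (illustration only, not part of the original file) --------------
# get_extended_attention_mask (called on the next line) turns the (batch, seq_len)
# 0/1 mask into an additive bias that broadcasts over heads and query positions,
# roughly:
#
#   extended = attention_mask[:, None, None, :]      # (batch, 1, 1, seq_len)
#   extended = (1.0 - extended) * large_negative     # 0 where attended, very negative where masked
#
# Adding this to the raw attention scores drives masked keys to ~0 probability after
# softmax. The exact constant and the causal handling for decoders live in the base
# PreTrainedModel utilities; the lines above are only a sketch.
# -------------------------------------------------------------------------------------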
823 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) 824 | 825 | # If a 2D or 3D attention mask is provided for the cross-attention 826 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 827 | if self.config.is_decoder and encoder_hidden_states is not None: 828 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 829 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 830 | if encoder_attention_mask is None: 831 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 832 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) 833 | else: 834 | encoder_extended_attention_mask = None 835 | 836 | # Prepare head mask if needed 837 | # 1.0 in head_mask indicate we keep the head 838 | # attention_probs has shape bsz x n_heads x N x N 839 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] 840 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] 841 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) 842 | 843 | embedding_output = self.embeddings( 844 | input_ids=input_ids, 845 | position_ids=position_ids, 846 | token_type_ids=token_type_ids, 847 | inputs_embeds=inputs_embeds, 848 | past_key_values_length=past_key_values_length, 849 | ) 850 | encoder_outputs = self.encoder( 851 | embedding_output, 852 | attention_mask=extended_attention_mask, 853 | head_mask=head_mask, 854 | encoder_hidden_states=encoder_hidden_states, 855 | encoder_attention_mask=encoder_extended_attention_mask, 856 | past_key_values=past_key_values, 857 | use_cache=use_cache, 858 | output_attentions=output_attentions, 859 | output_hidden_states=output_hidden_states, 860 | return_dict=return_dict, 861 | ) 862 | sequence_output = encoder_outputs[0] 863 | pooled_output = self.pooler(sequence_output) if self.pooler is not None else None 864 | 865 | if not return_dict: 866 | return (sequence_output, pooled_output) + encoder_outputs[1:] 867 | 868 | return BaseModelOutputWithPoolingAndCrossAttentions( 869 | last_hidden_state=sequence_output, 870 | pooler_output=pooled_output, 871 | past_key_values=encoder_outputs.past_key_values, 872 | hidden_states=encoder_outputs.hidden_states, 873 | attentions=encoder_outputs.attentions, 874 | cross_attentions=encoder_outputs.cross_attentions, 875 | ) 876 | 877 | 878 | @add_start_docstrings( 879 | """RoBERTa Model with a `language modeling` head on top for CLM fine-tuning.""", ROBERTA_START_DOCSTRING 880 | ) 881 | class RobertaForCausalLM(RobertaPreTrainedModel): 882 | _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"] 883 | _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"] 884 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 885 | 886 | def __init__(self, config): 887 | super().__init__(config) 888 | 889 | if not config.is_decoder: 890 | logger.warning("If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True.`") 891 | 892 | self.roberta = RobertaModel(config, add_pooling_layer=False) 893 | self.lm_head = RobertaLMHead(config) 894 | 895 | # The LM head weights require special treatment only when they are tied with the word embeddings 896 | self.update_keys_to_ignore(config, ["lm_head.decoder.weight"]) 897 | 898 | # Initialize weights and apply final processing 899 | self.post_init() 900 | 901 | def 
get_output_embeddings(self): 902 | return self.lm_head.decoder 903 | 904 | def set_output_embeddings(self, new_embeddings): 905 | self.lm_head.decoder = new_embeddings 906 | 907 | @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 908 | @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) 909 | def forward( 910 | self, 911 | input_ids=None, 912 | attention_mask=None, 913 | token_type_ids=None, 914 | position_ids=None, 915 | head_mask=None, 916 | inputs_embeds=None, 917 | encoder_hidden_states=None, 918 | encoder_attention_mask=None, 919 | labels=None, 920 | past_key_values=None, 921 | use_cache=None, 922 | output_attentions=None, 923 | output_hidden_states=None, 924 | return_dict=None, 925 | ): 926 | r""" 927 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 928 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if 929 | the model is configured as a decoder. 930 | encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): 931 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in 932 | the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: 933 | 934 | - 1 for tokens that are **not masked**, 935 | - 0 for tokens that are **masked**. 936 | 937 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 938 | Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in 939 | `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are 940 | ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` 941 | past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 942 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 943 | 944 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that 945 | don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all 946 | `decoder_input_ids` of shape `(batch_size, sequence_length)`. 947 | use_cache (`bool`, *optional*): 948 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 949 | `past_key_values`). 
950 | 951 | Returns: 952 | 953 | Example: 954 | 955 | ```python 956 | >>> from transformers import RobertaTokenizer, RobertaForCausalLM, RobertaConfig 957 | >>> import torch 958 | 959 | >>> tokenizer = RobertaTokenizer.from_pretrained("roberta-base") 960 | >>> config = RobertaConfig.from_pretrained("roberta-base") 961 | >>> config.is_decoder = True 962 | >>> model = RobertaForCausalLM.from_pretrained("roberta-base", config=config) 963 | 964 | >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") 965 | >>> outputs = model(**inputs) 966 | 967 | >>> prediction_logits = outputs.logits 968 | ```""" 969 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 970 | if labels is not None: 971 | use_cache = False 972 | 973 | outputs = self.roberta( 974 | input_ids, 975 | attention_mask=attention_mask, 976 | token_type_ids=token_type_ids, 977 | position_ids=position_ids, 978 | head_mask=head_mask, 979 | inputs_embeds=inputs_embeds, 980 | encoder_hidden_states=encoder_hidden_states, 981 | encoder_attention_mask=encoder_attention_mask, 982 | past_key_values=past_key_values, 983 | use_cache=use_cache, 984 | output_attentions=output_attentions, 985 | output_hidden_states=output_hidden_states, 986 | return_dict=return_dict, 987 | ) 988 | 989 | sequence_output = outputs[0] 990 | prediction_scores = self.lm_head(sequence_output) 991 | 992 | lm_loss = None 993 | if labels is not None: 994 | # we are doing next-token prediction; shift prediction scores and input ids by one 995 | shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() 996 | labels = labels[:, 1:].contiguous() 997 | loss_fct = CrossEntropyLoss() 998 | lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 999 | 1000 | if not return_dict: 1001 | output = (prediction_scores,) + outputs[2:] 1002 | return ((lm_loss,) + output) if lm_loss is not None else output 1003 | 1004 | return CausalLMOutputWithCrossAttentions( 1005 | loss=lm_loss, 1006 | logits=prediction_scores, 1007 | past_key_values=outputs.past_key_values, 1008 | hidden_states=outputs.hidden_states, 1009 | attentions=outputs.attentions, 1010 | cross_attentions=outputs.cross_attentions, 1011 | ) 1012 | 1013 | def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): 1014 | input_shape = input_ids.shape 1015 | # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly 1016 | if attention_mask is None: 1017 | attention_mask = input_ids.new_ones(input_shape) 1018 | 1019 | # cut decoder_input_ids if past is used 1020 | if past is not None: 1021 | input_ids = input_ids[:, -1:] 1022 | 1023 | return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past} 1024 | 1025 | def _reorder_cache(self, past, beam_idx): 1026 | reordered_past = () 1027 | for layer_past in past: 1028 | reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) 1029 | return reordered_past 1030 | 1031 | 1032 | @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top.""", ROBERTA_START_DOCSTRING) 1033 | class RobertaForMaskedLM(RobertaPreTrainedModel): 1034 | _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"] 1035 | _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"] 1036 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 1037 | 1038 | def __init__(self, config): 1039 | 
super().__init__(config) 1040 | 1041 | if config.is_decoder: 1042 | logger.warning( 1043 | "If you want to use `RobertaForMaskedLM` make sure `config.is_decoder=False` for " 1044 | "bi-directional self-attention." 1045 | ) 1046 | 1047 | self.roberta = RobertaModel(config, add_pooling_layer=False) 1048 | self.lm_head = RobertaLMHead(config) 1049 | 1050 | # The LM head weights require special treatment only when they are tied with the word embeddings 1051 | self.update_keys_to_ignore(config, ["lm_head.decoder.weight"]) 1052 | 1053 | # Initialize weights and apply final processing 1054 | self.post_init() 1055 | 1056 | def get_output_embeddings(self): 1057 | return self.lm_head.decoder 1058 | 1059 | def set_output_embeddings(self, new_embeddings): 1060 | self.lm_head.decoder = new_embeddings 1061 | 1062 | @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1063 | @add_code_sample_docstrings( 1064 | processor_class=_TOKENIZER_FOR_DOC, 1065 | checkpoint=_CHECKPOINT_FOR_DOC, 1066 | output_type=MaskedLMOutput, 1067 | config_class=_CONFIG_FOR_DOC, 1068 | mask="", 1069 | ) 1070 | def forward( 1071 | self, 1072 | input_ids=None, 1073 | attention_mask=None, 1074 | token_type_ids=None, 1075 | position_ids=None, 1076 | head_mask=None, 1077 | inputs_embeds=None, 1078 | encoder_hidden_states=None, 1079 | encoder_attention_mask=None, 1080 | labels=None, 1081 | output_attentions=None, 1082 | output_hidden_states=None, 1083 | return_dict=None, 1084 | ): 1085 | r""" 1086 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1087 | Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., 1088 | config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the 1089 | loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` 1090 | kwargs (`Dict[str, any]`, optional, defaults to *{}*): 1091 | Used to hide legacy arguments that have been deprecated. 
1092 | """ 1093 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1094 | 1095 | outputs = self.roberta( 1096 | input_ids, 1097 | attention_mask=attention_mask, 1098 | token_type_ids=token_type_ids, 1099 | position_ids=position_ids, 1100 | head_mask=head_mask, 1101 | inputs_embeds=inputs_embeds, 1102 | encoder_hidden_states=encoder_hidden_states, 1103 | encoder_attention_mask=encoder_attention_mask, 1104 | output_attentions=output_attentions, 1105 | output_hidden_states=output_hidden_states, 1106 | return_dict=return_dict, 1107 | ) 1108 | sequence_output = outputs[0] 1109 | prediction_scores = self.lm_head(sequence_output) 1110 | 1111 | masked_lm_loss = None 1112 | if labels is not None: 1113 | loss_fct = CrossEntropyLoss() 1114 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 1115 | 1116 | if not return_dict: 1117 | output = (prediction_scores,) + outputs[2:] 1118 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 1119 | 1120 | return MaskedLMOutput( 1121 | loss=masked_lm_loss, 1122 | logits=prediction_scores, 1123 | hidden_states=outputs.hidden_states, 1124 | attentions=outputs.attentions, 1125 | ) 1126 | 1127 | 1128 | class RobertaLMHead(nn.Module): 1129 | """Roberta Head for masked language modeling.""" 1130 | 1131 | def __init__(self, config): 1132 | super().__init__() 1133 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 1134 | self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 1135 | 1136 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size) 1137 | self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 1138 | self.decoder.bias = self.bias 1139 | 1140 | def forward(self, features, **kwargs): 1141 | x = self.dense(features) 1142 | x = gelu(x) 1143 | x = self.layer_norm(x) 1144 | 1145 | # project back to size of vocabulary with bias 1146 | x = self.decoder(x) 1147 | 1148 | return x 1149 | 1150 | def _tie_weights(self): 1151 | # To tie those two weights if they get disconnected (on TPU or when the bias is resized) 1152 | self.bias = self.decoder.bias 1153 | 1154 | 1155 | @add_start_docstrings( 1156 | """ 1157 | RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the 1158 | pooled output) e.g. for GLUE tasks. 
1159 | """, 1160 | ROBERTA_START_DOCSTRING, 1161 | ) 1162 | class RobertaForSequenceClassification(RobertaPreTrainedModel): 1163 | _keys_to_ignore_on_load_missing = [r"position_ids"] 1164 | 1165 | def __init__(self, config): 1166 | super().__init__(config) 1167 | self.num_labels = config.num_labels 1168 | self.config = config 1169 | 1170 | self.roberta = RobertaModel(config, add_pooling_layer=False) 1171 | self.classifier = RobertaClassificationHead(config) 1172 | 1173 | # Initialize weights and apply final processing 1174 | self.post_init() 1175 | 1176 | @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1177 | @add_code_sample_docstrings( 1178 | processor_class=_TOKENIZER_FOR_DOC, 1179 | checkpoint=_CHECKPOINT_FOR_DOC, 1180 | output_type=SequenceClassifierOutput, 1181 | config_class=_CONFIG_FOR_DOC, 1182 | ) 1183 | def forward( 1184 | self, 1185 | input_ids=None, 1186 | attention_mask=None, 1187 | token_type_ids=None, 1188 | position_ids=None, 1189 | head_mask=None, 1190 | inputs_embeds=None, 1191 | labels=None, 1192 | output_attentions=None, 1193 | output_hidden_states=None, 1194 | return_dict=None, 1195 | ): 1196 | r""" 1197 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1198 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 1199 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 1200 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 1201 | """ 1202 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1203 | 1204 | outputs = self.roberta( 1205 | input_ids, 1206 | attention_mask=attention_mask, 1207 | token_type_ids=token_type_ids, 1208 | position_ids=position_ids, 1209 | head_mask=head_mask, 1210 | inputs_embeds=inputs_embeds, 1211 | output_attentions=output_attentions, 1212 | output_hidden_states=output_hidden_states, 1213 | return_dict=return_dict, 1214 | ) 1215 | sequence_output = outputs[0] 1216 | logits = self.classifier(sequence_output) 1217 | 1218 | loss = None 1219 | if labels is not None: 1220 | if self.config.problem_type is None: 1221 | if self.num_labels == 1: 1222 | self.config.problem_type = "regression" 1223 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 1224 | self.config.problem_type = "single_label_classification" 1225 | else: 1226 | self.config.problem_type = "multi_label_classification" 1227 | 1228 | if self.config.problem_type == "regression": 1229 | loss_fct = MSELoss() 1230 | if self.num_labels == 1: 1231 | loss = loss_fct(logits.squeeze(), labels.squeeze()) 1232 | else: 1233 | loss = loss_fct(logits, labels) 1234 | elif self.config.problem_type == "single_label_classification": 1235 | loss_fct = CrossEntropyLoss() 1236 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1237 | elif self.config.problem_type == "multi_label_classification": 1238 | loss_fct = BCEWithLogitsLoss() 1239 | loss = loss_fct(logits, labels) 1240 | 1241 | if not return_dict: 1242 | output = (logits,) + outputs[2:] 1243 | return ((loss,) + output) if loss is not None else output 1244 | 1245 | return SequenceClassifierOutput( 1246 | loss=loss, 1247 | logits=logits, 1248 | hidden_states=outputs.hidden_states, 1249 | attentions=outputs.attentions, 1250 | ) 1251 | 1252 | 1253 | @add_start_docstrings( 1254 | """ 1255 | Roberta Model with a multiple choice classification head on top (a linear layer on 
top of the pooled output and a 1256 | softmax) e.g. for RocStories/SWAG tasks. 1257 | """, 1258 | ROBERTA_START_DOCSTRING, 1259 | ) 1260 | class RobertaForMultipleChoice(RobertaPreTrainedModel): 1261 | _keys_to_ignore_on_load_missing = [r"position_ids"] 1262 | 1263 | def __init__(self, config): 1264 | super().__init__(config) 1265 | 1266 | self.roberta = RobertaModel(config) 1267 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1268 | self.classifier = nn.Linear(config.hidden_size, 1) 1269 | 1270 | # Initialize weights and apply final processing 1271 | self.post_init() 1272 | 1273 | @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) 1274 | @add_code_sample_docstrings( 1275 | processor_class=_TOKENIZER_FOR_DOC, 1276 | checkpoint=_CHECKPOINT_FOR_DOC, 1277 | output_type=MultipleChoiceModelOutput, 1278 | config_class=_CONFIG_FOR_DOC, 1279 | ) 1280 | def forward( 1281 | self, 1282 | input_ids=None, 1283 | token_type_ids=None, 1284 | attention_mask=None, 1285 | labels=None, 1286 | position_ids=None, 1287 | head_mask=None, 1288 | inputs_embeds=None, 1289 | output_attentions=None, 1290 | output_hidden_states=None, 1291 | return_dict=None, 1292 | ): 1293 | r""" 1294 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1295 | Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., 1296 | num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See 1297 | `input_ids` above) 1298 | """ 1299 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1300 | num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] 1301 | 1302 | flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None 1303 | flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None 1304 | flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None 1305 | flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None 1306 | flat_inputs_embeds = ( 1307 | inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) 1308 | if inputs_embeds is not None 1309 | else None 1310 | ) 1311 | 1312 | outputs = self.roberta( 1313 | flat_input_ids, 1314 | position_ids=flat_position_ids, 1315 | token_type_ids=flat_token_type_ids, 1316 | attention_mask=flat_attention_mask, 1317 | head_mask=head_mask, 1318 | inputs_embeds=flat_inputs_embeds, 1319 | output_attentions=output_attentions, 1320 | output_hidden_states=output_hidden_states, 1321 | return_dict=return_dict, 1322 | ) 1323 | pooled_output = outputs[1] 1324 | 1325 | pooled_output = self.dropout(pooled_output) 1326 | logits = self.classifier(pooled_output) 1327 | reshaped_logits = logits.view(-1, num_choices) 1328 | 1329 | loss = None 1330 | if labels is not None: 1331 | loss_fct = CrossEntropyLoss() 1332 | loss = loss_fct(reshaped_logits, labels) 1333 | 1334 | if not return_dict: 1335 | output = (reshaped_logits,) + outputs[2:] 1336 | return ((loss,) + output) if loss is not None else output 1337 | 1338 | return MultipleChoiceModelOutput( 1339 | loss=loss, 1340 | logits=reshaped_logits, 1341 | hidden_states=outputs.hidden_states, 1342 | attentions=outputs.attentions, 1343 | ) 1344 | 1345 | 1346 | @add_start_docstrings( 1347 | """ 1348 | Roberta Model with a token classification head on top (a linear 
layer on top of the hidden-states output) e.g. for 1349 | Named-Entity-Recognition (NER) tasks. 1350 | """, 1351 | ROBERTA_START_DOCSTRING, 1352 | ) 1353 | class RobertaForTokenClassification(RobertaPreTrainedModel): 1354 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 1355 | _keys_to_ignore_on_load_missing = [r"position_ids"] 1356 | 1357 | def __init__(self, config): 1358 | super().__init__(config) 1359 | self.num_labels = config.num_labels 1360 | 1361 | self.roberta = RobertaModel(config, add_pooling_layer=False) 1362 | classifier_dropout = ( 1363 | config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob 1364 | ) 1365 | self.dropout = nn.Dropout(classifier_dropout) 1366 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 1367 | 1368 | # Initialize weights and apply final processing 1369 | self.post_init() 1370 | 1371 | @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1372 | @add_code_sample_docstrings( 1373 | processor_class=_TOKENIZER_FOR_DOC, 1374 | checkpoint=_CHECKPOINT_FOR_DOC, 1375 | output_type=TokenClassifierOutput, 1376 | config_class=_CONFIG_FOR_DOC, 1377 | ) 1378 | def forward( 1379 | self, 1380 | input_ids=None, 1381 | attention_mask=None, 1382 | token_type_ids=None, 1383 | position_ids=None, 1384 | head_mask=None, 1385 | inputs_embeds=None, 1386 | labels=None, 1387 | output_attentions=None, 1388 | output_hidden_states=None, 1389 | return_dict=None, 1390 | ): 1391 | r""" 1392 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1393 | Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 1394 | """ 1395 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1396 | 1397 | outputs = self.roberta( 1398 | input_ids, 1399 | attention_mask=attention_mask, 1400 | token_type_ids=token_type_ids, 1401 | position_ids=position_ids, 1402 | head_mask=head_mask, 1403 | inputs_embeds=inputs_embeds, 1404 | output_attentions=output_attentions, 1405 | output_hidden_states=output_hidden_states, 1406 | return_dict=return_dict, 1407 | ) 1408 | 1409 | sequence_output = outputs[0] 1410 | 1411 | sequence_output = self.dropout(sequence_output) 1412 | logits = self.classifier(sequence_output) 1413 | 1414 | loss = None 1415 | if labels is not None: 1416 | loss_fct = CrossEntropyLoss() 1417 | # Only keep active parts of the loss 1418 | if attention_mask is not None: 1419 | active_loss = attention_mask.view(-1) == 1 1420 | active_logits = logits.view(-1, self.num_labels) 1421 | active_labels = torch.where( 1422 | active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) 1423 | ) 1424 | loss = loss_fct(active_logits, active_labels) 1425 | else: 1426 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1427 | 1428 | if not return_dict: 1429 | output = (logits,) + outputs[2:] 1430 | return ((loss,) + output) if loss is not None else output 1431 | 1432 | return TokenClassifierOutput( 1433 | loss=loss, 1434 | logits=logits, 1435 | hidden_states=outputs.hidden_states, 1436 | attentions=outputs.attentions, 1437 | ) 1438 | 1439 | 1440 | class RobertaClassificationHead(nn.Module): 1441 | """Head for sentence-level classification tasks.""" 1442 | 1443 | def __init__(self, config): 1444 | super().__init__() 1445 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 1446 | classifier_dropout = ( 1447 | config.classifier_dropout if 
config.classifier_dropout is not None else config.hidden_dropout_prob 1448 | ) 1449 | self.dropout = nn.Dropout(classifier_dropout) 1450 | self.out_proj = nn.Linear(config.hidden_size, config.num_labels) 1451 | 1452 | def forward(self, features, **kwargs): 1453 | x = features[:, 0, :] # take <s> token (equiv. to [CLS]) 1454 | x = self.dropout(x) 1455 | x = self.dense(x) 1456 | x = torch.tanh(x) 1457 | x = self.dropout(x) 1458 | x = self.out_proj(x) 1459 | return x 1460 | 1461 | 1462 | @add_start_docstrings( 1463 | """ 1464 | Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear 1465 | layer on top of the hidden-states output to compute `span start logits` and `span end logits`). 1466 | """, 1467 | ROBERTA_START_DOCSTRING, 1468 | ) 1469 | class RobertaForQuestionAnswering(RobertaPreTrainedModel): 1470 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 1471 | _keys_to_ignore_on_load_missing = [r"position_ids"] 1472 | 1473 | def __init__(self, config): 1474 | super().__init__(config) 1475 | self.num_labels = config.num_labels 1476 | 1477 | self.roberta = RobertaModel(config, add_pooling_layer=False) 1478 | self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) 1479 | 1480 | # Initialize weights and apply final processing 1481 | self.post_init() 1482 | 1483 | @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) 1484 | @add_code_sample_docstrings( 1485 | processor_class=_TOKENIZER_FOR_DOC, 1486 | checkpoint=_CHECKPOINT_FOR_DOC, 1487 | output_type=QuestionAnsweringModelOutput, 1488 | config_class=_CONFIG_FOR_DOC, 1489 | ) 1490 | def forward( 1491 | self, 1492 | input_ids=None, 1493 | attention_mask=None, 1494 | token_type_ids=None, 1495 | position_ids=None, 1496 | head_mask=None, 1497 | inputs_embeds=None, 1498 | start_positions=None, 1499 | end_positions=None, 1500 | output_attentions=None, 1501 | output_hidden_states=None, 1502 | return_dict=None, 1503 | ): 1504 | r""" 1505 | start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1506 | Labels for position (index) of the start of the labelled span for computing the token classification loss. 1507 | Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence 1508 | are not taken into account for computing the loss. 1509 | end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1510 | Labels for position (index) of the end of the labelled span for computing the token classification loss. 1511 | Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence 1512 | are not taken into account for computing the loss. 
1513 | """ 1514 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1515 | 1516 | outputs = self.roberta( 1517 | input_ids, 1518 | attention_mask=attention_mask, 1519 | token_type_ids=token_type_ids, 1520 | position_ids=position_ids, 1521 | head_mask=head_mask, 1522 | inputs_embeds=inputs_embeds, 1523 | output_attentions=output_attentions, 1524 | output_hidden_states=output_hidden_states, 1525 | return_dict=return_dict, 1526 | ) 1527 | 1528 | sequence_output = outputs[0] 1529 | 1530 | logits = self.qa_outputs(sequence_output) 1531 | start_logits, end_logits = logits.split(1, dim=-1) 1532 | start_logits = start_logits.squeeze(-1).contiguous() 1533 | end_logits = end_logits.squeeze(-1).contiguous() 1534 | 1535 | total_loss = None 1536 | if start_positions is not None and end_positions is not None: 1537 | # If we are on multi-GPU, split add a dimension 1538 | if len(start_positions.size()) > 1: 1539 | start_positions = start_positions.squeeze(-1) 1540 | if len(end_positions.size()) > 1: 1541 | end_positions = end_positions.squeeze(-1) 1542 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 1543 | ignored_index = start_logits.size(1) 1544 | start_positions = start_positions.clamp(0, ignored_index) 1545 | end_positions = end_positions.clamp(0, ignored_index) 1546 | 1547 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 1548 | start_loss = loss_fct(start_logits, start_positions) 1549 | end_loss = loss_fct(end_logits, end_positions) 1550 | total_loss = (start_loss + end_loss) / 2 1551 | 1552 | if not return_dict: 1553 | output = (start_logits, end_logits) + outputs[2:] 1554 | return ((total_loss,) + output) if total_loss is not None else output 1555 | 1556 | return QuestionAnsweringModelOutput( 1557 | loss=total_loss, 1558 | start_logits=start_logits, 1559 | end_logits=end_logits, 1560 | hidden_states=outputs.hidden_states, 1561 | attentions=outputs.attentions, 1562 | ) 1563 | 1564 | 1565 | def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): 1566 | """ 1567 | Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols 1568 | are ignored. This is modified from fairseq's `utils.make_positions`. 1569 | 1570 | Args: 1571 | x: torch.Tensor x: 1572 | 1573 | Returns: torch.Tensor 1574 | """ 1575 | # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. 1576 | mask = input_ids.ne(padding_idx).int() 1577 | incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask 1578 | return incremental_indices.long() + padding_idx -------------------------------------------------------------------------------- /mutransformers/models/roberta/configuration_roberta.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # Copyright 2022 Microsoft Corporation. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 
8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | """ RoBERTa configuration""" 18 | from collections import OrderedDict 19 | from typing import Mapping 20 | 21 | from transformers.onnx import OnnxConfig 22 | from transformers.utils import logging 23 | from ..bert.configuration_bert import BertConfig 24 | 25 | 26 | logger = logging.get_logger(__name__) 27 | 28 | ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = { 29 | "roberta-base": "https://huggingface.co/roberta-base/resolve/main/config.json", 30 | "roberta-large": "https://huggingface.co/roberta-large/resolve/main/config.json", 31 | "roberta-large-mnli": "https://huggingface.co/roberta-large-mnli/resolve/main/config.json", 32 | "distilroberta-base": "https://huggingface.co/distilroberta-base/resolve/main/config.json", 33 | "roberta-base-openai-detector": "https://huggingface.co/roberta-base-openai-detector/resolve/main/config.json", 34 | "roberta-large-openai-detector": "https://huggingface.co/roberta-large-openai-detector/resolve/main/config.json", 35 | } 36 | 37 | 38 | class RobertaConfig(BertConfig): 39 | r""" 40 | This is the configuration class to store the configuration of a [`RobertaModel`] or a [`TFRobertaModel`]. It is 41 | used to instantiate a RoBERTa model according to the specified arguments, defining the model architecture. 42 | 43 | 44 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 45 | documentation from [`PretrainedConfig`] for more information. 46 | 47 | The [`RobertaConfig`] class directly inherits [`BertConfig`]. It reuses the same defaults. Please check the parent 48 | class for more information. 
49 | 50 | Examples: 51 | 52 | ```python 53 | >>> from transformers import RobertaConfig, RobertaModel 54 | 55 | >>> # Initializing a RoBERTa configuration 56 | >>> configuration = RobertaConfig() 57 | 58 | >>> # Initializing a model from the configuration 59 | >>> model = RobertaModel(configuration) 60 | 61 | >>> # Accessing the model configuration 62 | >>> configuration = model.config 63 | ```""" 64 | model_type = "roberta" 65 | 66 | def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, **kwargs): 67 | """Constructs RobertaConfig.""" 68 | super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) 69 | 70 | 71 | class RobertaOnnxConfig(OnnxConfig): 72 | @property 73 | def inputs(self) -> Mapping[str, Mapping[int, str]]: 74 | return OrderedDict( 75 | [ 76 | ("input_ids", {0: "batch", 1: "sequence"}), 77 | ("attention_mask", {0: "batch", 1: "sequence"}), 78 | ] 79 | ) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.18.5 2 | pandas>=1.1.2 3 | torch>=1.6.0 4 | torchvision>=0.7.0 5 | seaborn>=0.11.2 6 | transformers>=4.16.2 7 | pyyaml 8 | mup 9 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r", encoding="utf-8") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="mutransformers", 8 | version="0.1.0", 9 | author="Greg Yang, Edward J Hu", 10 | author_email="gregyang@microsoft.com, edward.hu@umontreal.ca", 11 | description="some Huggingface transformers reparametrized in muP", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/microsoft/mutransformers", 15 | packages=setuptools.find_packages(), 16 | classifiers=[ 17 | "Programming Language :: Python :: 3", 18 | "License :: OSI Approved :: MIT License", 19 | "Operating System :: OS Independent", 20 | ], 21 | python_requires='>=3.6', 22 | ) -------------------------------------------------------------------------------- /tests/coordcheck/bert_mup_dhead_coord_check.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mutransformers/480287ce7b18a07a3432e8f2fbc0f0e5b71e2599/tests/coordcheck/bert_mup_dhead_coord_check.png -------------------------------------------------------------------------------- /tests/coordcheck/bert_mup_nhead_coord_check.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mutransformers/480287ce7b18a07a3432e8f2fbc0f0e5b71e2599/tests/coordcheck/bert_mup_nhead_coord_check.png -------------------------------------------------------------------------------- /tests/coordcheck/bert_sp_dhead_coord_check.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mutransformers/480287ce7b18a07a3432e8f2fbc0f0e5b71e2599/tests/coordcheck/bert_sp_dhead_coord_check.png -------------------------------------------------------------------------------- /tests/coordcheck/bert_sp_nhead_coord_check.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/microsoft/mutransformers/480287ce7b18a07a3432e8f2fbc0f0e5b71e2599/tests/coordcheck/bert_sp_nhead_coord_check.png -------------------------------------------------------------------------------- /tests/coordcheck/coordcheck.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Microsoft Corporation. 2 | 3 | from functools import partial 4 | from itertools import cycle 5 | 6 | import numpy as np 7 | import torch 8 | 9 | import seaborn as sns 10 | from mup.coord_check import get_coord_data, plot_coord_data 11 | from mup import set_base_shapes, make_base_shapes 12 | from transformers import BertTokenizer, GPT2Tokenizer 13 | 14 | from mutransformers import BertConfig, BertForMaskedLM, RobertaConfig, RobertaForMaskedLM, GPT2Config, GPT2LMHeadModel 15 | 16 | sns.set() 17 | 18 | def get_dataloader(arch): 19 | if arch in ('bert', 'roberta'): 20 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 21 | input_ids = tokenizer("The capital of France is [MASK].", return_tensors="pt")['input_ids'] 22 | labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"] 23 | dataloader = cycle([dict(input_ids=input_ids, labels=labels)]) 24 | elif arch == 'gpt2': 25 | tokenizer = GPT2Tokenizer.from_pretrained('gpt2') 26 | text = "The capital of France is Paris." 27 | encoded_input = tokenizer(text, return_tensors='pt') 28 | encoded_input['labels'] = encoded_input['input_ids'] 29 | dataloader = cycle([encoded_input]) 30 | return dataloader 31 | 32 | def make_bsh(arch, filename=None): 33 | if arch == 'roberta': 34 | base_config = RobertaConfig( 35 | vocab_size=52_000, 36 | max_position_embeddings=514, 37 | hidden_size=256, 38 | intermediate_size=256, 39 | num_attention_heads=4, 40 | num_hidden_layers=2, 41 | type_vocab_size=1, 42 | ) 43 | delta_config = RobertaConfig( 44 | vocab_size=52_000, 45 | max_position_embeddings=514, 46 | hidden_size=200, 47 | intermediate_size=200, 48 | num_attention_heads=5, 49 | num_hidden_layers=2, 50 | type_vocab_size=1, 51 | ) 52 | base_model = RobertaForMaskedLM(config=base_config) 53 | delta_model = RobertaForMaskedLM(config=delta_config) 54 | elif arch == 'bert': 55 | base_config = BertConfig( 56 | vocab_size=52_000, 57 | max_position_embeddings=514, 58 | hidden_size=256, 59 | intermediate_size=256, 60 | num_attention_heads=4, 61 | num_hidden_layers=2, 62 | type_vocab_size=1, 63 | ) 64 | delta_config = BertConfig( 65 | vocab_size=52_000, 66 | max_position_embeddings=514, 67 | hidden_size=200, 68 | intermediate_size=200, 69 | num_attention_heads=5, 70 | num_hidden_layers=2, 71 | type_vocab_size=1, 72 | ) 73 | base_model = BertForMaskedLM(config=base_config) 74 | delta_model = BertForMaskedLM(config=delta_config) 75 | elif arch == 'gpt2': 76 | base_config = GPT2Config( 77 | n_head=4, 78 | activation_function='relu', 79 | n_embd=256, 80 | n_layer=2, 81 | # num_labels=1, 82 | ) 83 | delta_config = GPT2Config( 84 | n_head=5, 85 | activation_function='relu', 86 | n_embd=200, 87 | n_layer=2, 88 | # num_labels=1, 89 | ) 90 | base_model = GPT2LMHeadModel(config=base_config) 91 | delta_model = GPT2LMHeadModel(config=delta_config) 92 | else: 93 | raise NotImplementedError() 94 | base_shapes = make_base_shapes(base_model, delta_model, savefile=filename) 95 | return base_shapes 96 | 97 | def get_lazy_model(arch, width, base_shape=None, mup=True, readout_zero_init=True, query_zero_init=True, vary_nhead=False): 98 | width = int(width) 99 | nhead = 4 100 | if vary_nhead: 101 
| nhead = int(4 * width / 252) 102 | def f(): 103 | if arch == 'roberta': 104 | config = RobertaConfig( 105 | vocab_size=52_000, 106 | max_position_embeddings=514, 107 | hidden_size=width, 108 | intermediate_size=width, 109 | num_attention_heads=nhead, 110 | num_hidden_layers=2, 111 | type_vocab_size=1, 112 | attn_mult=8 if mup else None, 113 | classifier_dropout=0 114 | ) 115 | model = RobertaForMaskedLM(config=config) 116 | elif arch == 'bert': 117 | config = BertConfig( 118 | vocab_size=52_000, 119 | max_position_embeddings=514, 120 | hidden_size=width, 121 | intermediate_size=width, 122 | num_attention_heads=nhead, 123 | num_hidden_layers=2, 124 | type_vocab_size=1, 125 | attn_mult=8 if mup else None, 126 | classifier_dropout=0 127 | ) 128 | model = BertForMaskedLM(config=config) 129 | elif arch == 'gpt2': 130 | config = GPT2Config( 131 | n_head=nhead, 132 | activation_function='relu', 133 | n_embd=width, 134 | n_layer=2, 135 | attn_mult=8 if mup else None, 136 | # num_labels=1, 137 | # resid_pdrop=0, 138 | # embd_pdrop=0, 139 | # attn_pdrop=0, 140 | ) 141 | model = GPT2LMHeadModel(config=config) 142 | if mup: 143 | set_base_shapes(model, base_shape) 144 | else: 145 | set_base_shapes(model, None) 146 | 147 | model.apply( 148 | partial(model._init_weights, 149 | readout_zero_init=readout_zero_init, 150 | query_zero_init=query_zero_init, 151 | )) 152 | return model 153 | return f 154 | 155 | def plot_coord_check(arch, mup=True, vary_nhead=False, y='l1', widths=None, optimizer='adam', 156 | nseeds=1, nsteps=4, loglog=False, logbase=2, legend=None, 157 | **get_coord_data_kw): 158 | if widths is None: 159 | widths = 2**np.arange(6, 11) 160 | base_shape = make_bsh(arch) 161 | models = {width: get_lazy_model(arch, width, base_shape=base_shape, mup=mup, vary_nhead=vary_nhead) for width in widths} 162 | dataloader = get_dataloader(arch) 163 | df = get_coord_data(models, dataloader, mup=mup, optimizer=optimizer, 164 | nseeds=nseeds, dict_in_out=True, 165 | nsteps=nsteps, **get_coord_data_kw) 166 | 167 | prm = 'mup' if mup else 'sp' 168 | width = 'nhead' if vary_nhead else 'dhead' 169 | return plot_coord_data(df, legend=legend, loglog=loglog, logbase=logbase, y=y, 170 | save_to=f'{arch}_{prm}_{width}_coord_check.png', suptitle=f'{prm} {arch} {width}', 171 | face_color='xkcd:light grey' if not mup else None) -------------------------------------------------------------------------------- /tests/coordcheck/gpt2_mup_dhead_coord_check.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mutransformers/480287ce7b18a07a3432e8f2fbc0f0e5b71e2599/tests/coordcheck/gpt2_mup_dhead_coord_check.png -------------------------------------------------------------------------------- /tests/coordcheck/gpt2_mup_nhead_coord_check.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mutransformers/480287ce7b18a07a3432e8f2fbc0f0e5b71e2599/tests/coordcheck/gpt2_mup_nhead_coord_check.png -------------------------------------------------------------------------------- /tests/coordcheck/gpt2_sp_dhead_coord_check.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mutransformers/480287ce7b18a07a3432e8f2fbc0f0e5b71e2599/tests/coordcheck/gpt2_sp_dhead_coord_check.png -------------------------------------------------------------------------------- /tests/coordcheck/gpt2_sp_nhead_coord_check.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mutransformers/480287ce7b18a07a3432e8f2fbc0f0e5b71e2599/tests/coordcheck/gpt2_sp_nhead_coord_check.png -------------------------------------------------------------------------------- /tests/coordcheck/roberta_mup_dhead_coord_check.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mutransformers/480287ce7b18a07a3432e8f2fbc0f0e5b71e2599/tests/coordcheck/roberta_mup_dhead_coord_check.png -------------------------------------------------------------------------------- /tests/coordcheck/roberta_mup_nhead_coord_check.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mutransformers/480287ce7b18a07a3432e8f2fbc0f0e5b71e2599/tests/coordcheck/roberta_mup_nhead_coord_check.png -------------------------------------------------------------------------------- /tests/coordcheck/roberta_sp_dhead_coord_check.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mutransformers/480287ce7b18a07a3432e8f2fbc0f0e5b71e2599/tests/coordcheck/roberta_sp_dhead_coord_check.png -------------------------------------------------------------------------------- /tests/coordcheck/roberta_sp_nhead_coord_check.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/mutransformers/480287ce7b18a07a3432e8f2fbc0f0e5b71e2599/tests/coordcheck/roberta_sp_nhead_coord_check.png --------------------------------------------------------------------------------
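 
For reference, the sketch below is not a file in this repository; it condenses the pattern used in `tests/coordcheck/coordcheck.py` into one linear script: record base shapes from a small base/delta model pair, attach them to a model at the target width, re-initialize the weights, and train with a muP optimizer from the `mup` package. The helper name `small_bert_config`, the target width, the learning rate, and the one-step "training loop" are illustrative choices rather than values prescribed by the repository; `attn_mult=8` and the `_init_weights` call are copied from the coord-check script.

```python
# Minimal usage sketch, assuming the pattern of tests/coordcheck/coordcheck.py.
# Hyperparameters (widths, lr) are illustrative; `small_bert_config` is a
# hypothetical helper, not part of the repository.
from functools import partial

from mup import MuAdam, make_base_shapes, set_base_shapes
from transformers import BertTokenizer

from mutransformers import BertConfig, BertForMaskedLM


def small_bert_config(hidden_size, num_attention_heads):
    # Mirrors the configs built in coordcheck.make_bsh / get_lazy_model.
    return BertConfig(
        vocab_size=52_000,
        max_position_embeddings=514,
        hidden_size=hidden_size,
        intermediate_size=hidden_size,
        num_attention_heads=num_attention_heads,
        num_hidden_layers=2,
        type_vocab_size=1,
        attn_mult=8,  # muP attention multiplier added by this repo's BertConfig
    )


# 1. Record base shapes from a base model and a slightly different "delta" model.
base_model = BertForMaskedLM(config=small_bert_config(256, 4))
delta_model = BertForMaskedLM(config=small_bert_config(200, 5))
base_shapes = make_base_shapes(base_model, delta_model, savefile="bert_base_shapes.bsh")

# 2. Build the model at the width you actually want and attach the base shapes.
model = BertForMaskedLM(config=small_bert_config(1024, 16))
set_base_shapes(model, base_shapes)

# 3. Re-initialize the weights so the muP initialization scaling takes effect.
model.apply(partial(model._init_weights, readout_zero_init=True, query_zero_init=True))

# 4. Train with a muP optimizer so learning rates scale correctly with width.
optimizer = MuAdam(model.parameters(), lr=1e-3)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
batch = tokenizer("The capital of France is [MASK].", return_tensors="pt")
batch["labels"] = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]

model.train()
loss = model(**batch).loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
```

Note the ordering: the base shapes are attached before the weights are re-initialized and before the optimizer is constructed, matching the order used in `coordcheck.get_lazy_model`.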
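Similarly, the coord-check PNGs stored in `tests/coordcheck/` can be regenerated by calling `plot_coord_check` from `coordcheck.py` directly; a hypothetical driver loop, assumed to be run from inside `tests/coordcheck/`, might look like this:

```python
# Hypothetical driver, not part of the repository: regenerates the twelve
# coord-check plots by sweeping architecture, parametrization, and the way
# width is scaled, using plot_coord_check from coordcheck.py.
from coordcheck import plot_coord_check

for arch in ("bert", "roberta", "gpt2"):
    for mup in (True, False):             # muP vs. standard parametrization (SP)
        for vary_nhead in (False, True):  # scale d_head vs. number of heads
            plot_coord_check(arch, mup=mup, vary_nhead=vary_nhead)
```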