├── LICENSE
├── README.md
├── models
│   ├── __init__.py
│   ├── bert.py
│   ├── bert_utils.py
│   ├── gaussian.py
│   ├── resnet.py
│   └── transformer.py
├── requirements.txt
├── runs
│   ├── learned
│   │   └── run.sh
│   ├── many-heads
│   │   └── run.sh
│   ├── quadratic-generalized-pruned
│   │   └── run.sh
│   ├── quadratic-generalized
│   │   └── run.sh
│   ├── quadratic
│   │   └── run.sh
│   └── resnet
│       └── run.sh
├── train.py
└── utils
    ├── __init__.py
    ├── accumulators.py
    ├── config.py
    ├── data.py
    ├── learning_rate.py
    ├── logging.py
    └── plotting.py
/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner.
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Self-Attention and Convolution 2 | 3 | The code accompanies the paper [On the Relationship between Self-Attention and Convolutional Layers](https://openreview.net/pdf?id=HJlnC1rKPB) by [Jean-Baptiste Cordonnier](http://jbcordonnier.com/), [Andreas Loukas](https://andreasloukas.blog/) and [Martin Jaggi](https://m8j.net/) that appeared in ICLR 2020. 
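The core construction of the paper — implemented in this repo's `models/bert.py` (see `GaussianSelfAttention` and the quadratic/Gaussian positional attention used in `runs/quadratic`) — is that a self-attention head whose score is a quadratic function of the relative pixel shift concentrates its attention mass on one fixed offset, so a set of such heads reads exactly the pixels a convolutional kernel would. The snippet below is a minimal, self-contained sketch of that idea for a 3×3 kernel; the sizes and variable names are illustrative, and it is not code from this repository.

```python
import torch

# Attention score from query pixel q to key pixel k for head h:
#     score_h(q, k) = -||(k - q) - delta_h||^2 / (2 * sigma^2)
# i.e. a Gaussian centred at a head-specific shift delta_h. For a small sigma
# the softmax is nearly one-hot at k = q + delta_h, so the 9 heads below
# gather exactly the 9 pixels a 3x3 convolution would read.

W = H = 6          # toy feature-map size
sigma = 0.3        # sharp Gaussian -> near one-hot attention
shifts = torch.tensor([[di, dj] for di in (-1, 0, 1) for dj in (-1, 0, 1)],
                      dtype=torch.float)              # 9 heads = one 3x3 kernel

# (W, H, 2) grid of pixel coordinates, then all relative positions k - q
coords = torch.stack([torch.arange(W).view(-1, 1).expand(W, H),
                      torch.arange(H).view(1, -1).expand(W, H)], dim=-1).float()
rel = coords.view(1, 1, W, H, 2) - coords.view(W, H, 1, 1, 2)    # (W, H, W, H, 2)

# (W, H, num_heads, W, H) attention scores and probabilities
scores = -((rel.unsqueeze(2) - shifts.view(1, 1, -1, 1, 1, 2)) ** 2).sum(-1) / (2 * sigma ** 2)
probs = scores.view(W, H, len(shifts), -1).softmax(dim=-1).view(W, H, len(shifts), W, H)

# Head 0 has shift (-1, -1): from query pixel (3, 3) it attends almost
# entirely to pixel (2, 2), i.e. flat index 2 * H + 2 = 14.
print(probs[3, 3, 0].argmax().item())   # 14
print(probs[3, 3, 0].max().item())      # close to 1.0
```

Combining the heads' outputs with head-specific value matrices, as in the paper's construction, then reproduces the corresponding convolutional layer.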
4 | 5 | ### Abstract 6 | 7 | Recent trends of incorporating attention mechanisms in vision have led researchers to reconsider the supremacy of convolutional layers as a primary building block. Beyond helping CNNs to handle long-range dependencies, Ramachandran et al. (2019) showed that attention can completely replace convolution and achieve state-of-the-art performance on vision tasks. This raises the question: do learned attention layers operate similarly to convolutional layers? This work provides evidence that attention layers can perform convolution and, indeed, they often learn to do so in practice. Specifically, we prove that a multi-head self-attention layer with sufficient number of heads is at least as powerful as any convolutional layer. Our numerical experiments then show that the phenomenon also occurs in practice, corroborating our analysis. Our code is publicly available. 8 | 9 | ### Interact with Attention 10 | 11 | Check out our [interactive website](https://epfml.github.io/attention-cnn/). 12 | 13 | ### Reproduce 14 | 15 | To run our code on an Ubuntu machine with a GPU, install the Python packages in a fresh Anaconda environment: 16 | 17 | ```bash 18 | conda install pytorch torchvision cudatoolkit=10.0 -c pytorch 19 | pip install -r requirements.txt 20 | ``` 21 | 22 | All experiments presented in the paper are reproducible by running the scripts in `runs/`, for example: 23 | 24 | ``` 25 | bash runs/quadratic/run.sh 26 | ``` 27 | 28 | ### Reference 29 | 30 | If you use this code, please cite the following [paper](https://openreview.net/pdf?id=HJlnC1rKPB): 31 | 32 | ``` 33 | @inproceedings{ 34 | Cordonnier2020On, 35 | title={On the Relationship between Self-Attention and Convolutional Layers}, 36 | author={Jean-Baptiste Cordonnier and Andreas Loukas and Martin Jaggi}, 37 | booktitle={International Conference on Learning Representations}, 38 | year={2020}, 39 | url={https://openreview.net/forum?id=HJlnC1rKPB} 40 | } 41 | ``` 42 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # from .vgg import * 2 | # from .dpn import * 3 | # from .lenet import * 4 | # from .senet import * 5 | # from .pnasnet import * 6 | # from .densenet import * 7 | # from .googlenet import * 8 | # from .shufflenet import * 9 | from .resnet import * 10 | # from .resnext import * 11 | # from .preact_resnet import * 12 | # from .mobilenet import * 13 | # from .mobilenetv2 import * 14 | from .transformer import * 15 | -------------------------------------------------------------------------------- /models/bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License.
16 | """PyTorch BERT model.""" 17 | 18 | from __future__ import absolute_import, division, print_function, unicode_literals 19 | 20 | import copy 21 | import json 22 | import logging 23 | import math 24 | import os 25 | import shutil 26 | import tarfile 27 | import tempfile 28 | import sys 29 | from io import open 30 | import numbers 31 | 32 | import torch 33 | from torch import nn 34 | from torch.nn import CrossEntropyLoss 35 | from torch.nn import functional as F 36 | 37 | from .bert_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME 38 | from .gaussian import gaussian_kernel_2d 39 | 40 | MAX_WIDTH_HEIGHT = 64 41 | 42 | logger = logging.getLogger(__name__) 43 | 44 | PRETRAINED_MODEL_ARCHIVE_MAP = { 45 | "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", 46 | "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", 47 | "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", 48 | "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", 49 | "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", 50 | "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", 51 | "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", 52 | "bert-base-german-cased": "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased.tar.gz", 53 | "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking.tar.gz", 54 | "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking.tar.gz", 55 | } 56 | BERT_CONFIG_NAME = "bert_config.json" 57 | TF_WEIGHTS_NAME = "model.ckpt" 58 | 59 | 60 | 61 | def prune_linear_layer(layer, index, dim=0): 62 | """ Prune a linear layer (a model parameters) to keep only entries in index. 63 | Return the pruned layer as a new layer with requires_grad=True. 64 | Used to remove heads. 65 | """ 66 | device = layer.weight.device 67 | index = index.to(device) 68 | W = layer.weight.index_select(dim, index).clone().detach() 69 | if layer.bias is not None: 70 | if dim == 1: 71 | b = layer.bias.clone().detach() 72 | else: 73 | b = layer.bias[index].clone().detach() 74 | new_size = list(layer.weight.size()) 75 | new_size[dim] = len(index) 76 | new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None) 77 | new_layer.weight.requires_grad = False 78 | new_layer.weight.copy_(W.contiguous()) 79 | new_layer.weight.requires_grad = True 80 | if layer.bias is not None: 81 | new_layer.bias.requires_grad = False 82 | new_layer.bias.copy_(b.contiguous()) 83 | new_layer.bias.requires_grad = True 84 | return new_layer.to(device) 85 | 86 | 87 | def load_tf_weights_in_bert(model, tf_checkpoint_path): 88 | """ Load tf checkpoints in a pytorch model 89 | """ 90 | try: 91 | import re 92 | import numpy as np 93 | import tensorflow as tf 94 | except ImportError: 95 | print( 96 | "Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " 97 | "https://www.tensorflow.org/install/ for installation instructions." 
98 | ) 99 | raise 100 | tf_path = os.path.abspath(tf_checkpoint_path) 101 | print("Converting TensorFlow checkpoint from {}".format(tf_path)) 102 | # Load weights from TF model 103 | init_vars = tf.train.list_variables(tf_path) 104 | names = [] 105 | arrays = [] 106 | for name, shape in init_vars: 107 | print("Loading TF weight {} with shape {}".format(name, shape)) 108 | array = tf.train.load_variable(tf_path, name) 109 | names.append(name) 110 | arrays.append(array) 111 | 112 | for name, array in zip(names, arrays): 113 | name = name.split("/") 114 | # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v 115 | # which are not required for using pretrained model 116 | if any(n in ["adam_v", "adam_m", "global_step"] for n in name): 117 | print("Skipping {}".format("/".join(name))) 118 | continue 119 | pointer = model 120 | for m_name in name: 121 | if re.fullmatch(r"[A-Za-z]+_\d+", m_name): 122 | l = re.split(r"_(\d+)", m_name) 123 | else: 124 | l = [m_name] 125 | if l[0] == "kernel" or l[0] == "gamma": 126 | pointer = getattr(pointer, "weight") 127 | elif l[0] == "output_bias" or l[0] == "beta": 128 | pointer = getattr(pointer, "bias") 129 | elif l[0] == "output_weights": 130 | pointer = getattr(pointer, "weight") 131 | elif l[0] == "squad": 132 | pointer = getattr(pointer, "classifier") 133 | else: 134 | try: 135 | pointer = getattr(pointer, l[0]) 136 | except AttributeError: 137 | print("Skipping {}".format("/".join(name))) 138 | continue 139 | if len(l) >= 2: 140 | num = int(l[1]) 141 | pointer = pointer[num] 142 | if m_name[-11:] == "_embeddings": 143 | pointer = getattr(pointer, "weight") 144 | elif m_name == "kernel": 145 | array = np.transpose(array) 146 | try: 147 | assert pointer.shape == array.shape 148 | except AssertionError as e: 149 | e.args += (pointer.shape, array.shape) 150 | raise 151 | print("Initialize PyTorch weight {}".format(name)) 152 | pointer.data = torch.from_numpy(array) 153 | return model 154 | 155 | 156 | def gelu(x): 157 | """Implementation of the gelu activation function. 158 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 159 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 160 | Also see https://arxiv.org/abs/1606.08415 161 | """ 162 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 163 | 164 | 165 | def swish(x): 166 | return x * torch.sigmoid(x) 167 | 168 | 169 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} 170 | 171 | 172 | class BertConfig(object): 173 | """Configuration class to store the configuration of a `BertModel`. 
174 | """ 175 | 176 | def __init__( 177 | self, 178 | vocab_size_or_config_json_file, 179 | hidden_size=None, 180 | position_encoding_size=None, 181 | num_hidden_layers=None, 182 | num_attention_heads=None, 183 | intermediate_size=None, 184 | hidden_act=None, 185 | hidden_dropout_prob=None, 186 | attention_probs_dropout_prob=None, 187 | max_position_embeddings=None, 188 | type_vocab_size=None, 189 | initializer_range=None, 190 | layer_norm_eps=None, 191 | use_learned_2d_encoding=None, 192 | share_position_encoding=None, 193 | use_attention_data=None, 194 | query_positional_score=None, 195 | use_gaussian_attention=None, 196 | add_positional_encoding_to_input=None, 197 | positional_encoding=None, 198 | max_positional_encoding=None, 199 | attention_gaussian_blur_trick=None, 200 | attention_isotropic_gaussian=None, 201 | gaussian_init_sigma_std=None, 202 | gaussian_init_mu_std=None, 203 | ): 204 | """Constructs BertConfig. 205 | 206 | Args: 207 | vocab_size_or_config_json_file: Vocabulary size of `input_ids` in `BertModel`. 208 | hidden_size: Size of the encoder layers and the pooler layer. 209 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 210 | num_attention_heads: Number of attention heads for each attention layer in 211 | the Transformer encoder. 212 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 213 | layer in the Transformer encoder. 214 | hidden_act: The non-linear activation function (function or string) in the 215 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 216 | hidden_dropout_prob: The dropout probability for all fully connected 217 | layers in the embeddings, encoder, and pooler. 218 | attention_probs_dropout_prob: The dropout ratio for the attention 219 | probabilities. 220 | max_position_embeddings: The maximum sequence length that this model might 221 | ever be used with. Typically set this to something large just in case 222 | (e.g., 512 or 1024 or 2048). 223 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 224 | `BertModel`. 225 | initializer_range: The stddev of the truncated_normal_initializer for 226 | initializing all weight matrices. 227 | layer_norm_eps: The epsilon used by LayerNorm.
228 | """ 229 | if isinstance(vocab_size_or_config_json_file, str) or ( 230 | sys.version_info[0] == 2 and isinstance(vocab_size_or_config_json_file, unicode) 231 | ): 232 | with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader: 233 | json_config = json.loads(reader.read()) 234 | for key, value in json_config.items(): 235 | self.__dict__[key] = value 236 | elif isinstance(vocab_size_or_config_json_file, int): 237 | self.vocab_size = vocab_size_or_config_json_file 238 | self.hidden_size = hidden_size 239 | self.position_encoding_size = position_encoding_size 240 | self.num_hidden_layers = num_hidden_layers 241 | self.num_attention_heads = num_attention_heads 242 | self.hidden_act = hidden_act 243 | self.intermediate_size = intermediate_size 244 | self.hidden_dropout_prob = hidden_dropout_prob 245 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 246 | self.max_position_embeddings = max_position_embeddings 247 | self.type_vocab_size = type_vocab_size 248 | self.initializer_range = initializer_range 249 | self.layer_norm_eps = layer_norm_eps 250 | self.use_learned_2d_encoding = use_learned_2d_encoding 251 | self.use_gaussian_attention = use_gaussian_attention 252 | self.positional_encoding = positional_encoding 253 | self.max_positional_encoding = max_positional_encoding 254 | self.attention_gaussian_blur_trick = attention_gaussian_blur_trick 255 | self.attention_isotropic_gaussian = attention_isotropic_gaussian 256 | self.gaussian_init_sigma_std = gaussian_init_sigma_std 257 | self.gaussian_init_mu_std = gaussian_init_mu_std 258 | else: 259 | raise ValueError( 260 | "First argument must be either a vocabulary size (int)" 261 | "or the path to a pretrained model config file (str)" 262 | ) 263 | 264 | @classmethod 265 | def from_dict(cls, json_object): 266 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 267 | config = BertConfig(vocab_size_or_config_json_file=-1) 268 | for key, value in json_object.items(): 269 | config.__dict__[key] = value 270 | return config 271 | 272 | @classmethod 273 | def from_json_file(cls, json_file): 274 | """Constructs a `BertConfig` from a json file of parameters.""" 275 | with open(json_file, "r", encoding="utf-8") as reader: 276 | text = reader.read() 277 | return cls.from_dict(json.loads(text)) 278 | 279 | def __repr__(self): 280 | return str(self.to_json_string()) 281 | 282 | def to_dict(self): 283 | """Serializes this instance to a Python dictionary.""" 284 | output = copy.deepcopy(self.__dict__) 285 | return output 286 | 287 | def to_json_string(self): 288 | """Serializes this instance to a JSON string.""" 289 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 290 | 291 | def to_json_file(self, json_file_path): 292 | """ Save this instance to a json file.""" 293 | with open(json_file_path, "w", encoding="utf-8") as writer: 294 | writer.write(self.to_json_string()) 295 | 296 | 297 | try: 298 | from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm 299 | except ImportError: 300 | logger.info( 301 | "Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex ." 302 | ) 303 | 304 | class BertLayerNorm(nn.Module): 305 | def __init__(self, hidden_size, eps=1e-12): 306 | """Construct a layernorm module in the TF style (epsilon inside the square root). 
307 | """ 308 | super(BertLayerNorm, self).__init__() 309 | self.weight = nn.Parameter(torch.ones(hidden_size)) 310 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 311 | self.variance_epsilon = eps 312 | 313 | def forward(self, x): 314 | u = x.mean(-1, keepdim=True) 315 | s = (x - u).pow(2).mean(-1, keepdim=True) 316 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 317 | return self.weight * x + self.bias 318 | 319 | 320 | class BertEmbeddings(nn.Module): 321 | """Construct the embeddings from word, position and token_type embeddings. 322 | """ 323 | 324 | def __init__(self, config): 325 | super(BertEmbeddings, self).__init__() 326 | self.add_positional_encoding = config.add_positional_encoding_to_input 327 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) 328 | if self.add_positional_encoding: 329 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 330 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 331 | 332 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 333 | # any TensorFlow checkpoint file 334 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 335 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 336 | 337 | def forward(self, input_ids, token_type_ids=None): 338 | seq_length = input_ids.size(1) 339 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 340 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 341 | if token_type_ids is None: 342 | token_type_ids = torch.zeros_like(input_ids) 343 | 344 | words_embeddings = self.word_embeddings(input_ids) 345 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 346 | 347 | embeddings = words_embeddings + token_type_embeddings 348 | 349 | if self.add_positional_encoding: 350 | position_embeddings = self.position_embeddings(position_ids) 351 | embeddings += position_embeddings 352 | 353 | embeddings = self.LayerNorm(embeddings) 354 | embeddings = self.dropout(embeddings) 355 | return embeddings 356 | 357 | 358 | class Learned2DRelativeSelfAttention(nn.Module): 359 | def __init__(self, config, output_attentions=False, keep_multihead_output=False): 360 | super().__init__() 361 | self.output_attentions = output_attentions 362 | self.num_attention_heads = config.num_attention_heads 363 | self.use_attention_data = config.use_attention_data 364 | self.query_positional_score = config.query_positional_score 365 | self.hidden_size = config.hidden_size 366 | self.all_head_size = config.hidden_size * self.num_attention_heads 367 | 368 | max_position_embeddings = config.max_position_embeddings 369 | 370 | position_embedding_size = config.hidden_size 371 | if self.query_positional_score: 372 | position_embedding_size = config.hidden_size // 2 373 | if config.position_encoding_size != -1: 374 | position_embedding_size = config.position_encoding_size 375 | 376 | self.row_embeddings = nn.Embedding(2 * max_position_embeddings - 1, position_embedding_size) 377 | self.col_embeddings = nn.Embedding(2 * max_position_embeddings - 1, position_embedding_size) 378 | 379 | if not self.query_positional_score: 380 | self.head_keys_row = nn.Linear(position_embedding_size, self.num_attention_heads, bias=False) 381 | self.head_keys_col = nn.Linear(position_embedding_size, self.num_attention_heads, bias=False) 382 | 383 | # need query linear transformation 384 | if self.use_attention_data or 
self.query_positional_score: 385 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 386 | 387 | # need key linear transformation 388 | if self.use_attention_data: 389 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 390 | 391 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 392 | self.value = nn.Linear(self.all_head_size, config.hidden_size) 393 | 394 | deltas = torch.arange(max_position_embeddings).view(1, -1) - torch.arange(max_position_embeddings).view(-1, 1) 395 | # shift the delta to [0, 2 * max_position_embeddings - 1] 396 | relative_indices = deltas + max_position_embeddings - 1 397 | 398 | self.register_buffer("relative_indices", relative_indices) 399 | 400 | def forward(self, hidden_states, attention_mask, head_mask=None): 401 | assert len(hidden_states.shape) == 4 402 | b, w, h, c = hidden_states.shape 403 | 404 | # -- B, W, H, num_heads, W, H 405 | attention_scores, attention_scores_per_type = self.compute_attention_scores(hidden_states) 406 | shape = attention_scores.shape 407 | attention_probs = nn.Softmax(dim=-1)(attention_scores.view(*shape[:-2], -1)).view(shape) 408 | # expand batch dim if 1 409 | if shape[0] != b: 410 | attention_probs = attention_probs.expand(b, *shape[1:]) 411 | 412 | attention_probs = self.dropout(attention_probs) 413 | 414 | input_values = torch.einsum('bijhkl,bkld->bijhd', attention_probs, hidden_states) 415 | input_values = input_values.contiguous().view(b, w, h, -1) 416 | output_value = self.value(input_values) 417 | 418 | if self.output_attentions: 419 | attention_scores_per_type["attention_scores"] = attention_scores 420 | attention_scores_per_type["attention_probs"] = attention_probs 421 | return attention_scores_per_type, output_value 422 | else: 423 | return output_value 424 | 425 | def compute_attention_scores(self, hidden_states): 426 | """Compute the positional attention for an image of size width x height 427 | Returns: tensor of attention scores (1 or batch, width, height, num_head, width, height) 428 | 429 | Attention scores: 430 | * Position only 431 | Options: use_attention_data=False, query_positional_score=False 432 | w_q^T * r 433 | where w_q is a learned vector per head 434 | * Query and positional encoding (without query key attention scores), 435 | same as q * r in (Ramachandran et al., 2019) 436 | Options: use_attention_data=False, query_positional_score=True 437 | X * W_Q * r 438 | * With data 439 | same as q*k + q*r in (Ramachandran et al., 2019) 440 | Options: use_attention_data=True, query_positional_score=True 441 | X * W_Q * W_K^T * X^T + X * W_Q * r 442 | * Last option use_attention_data=True, query_positional_score=False was not used 443 | """ 444 | batch_size, height, width, hidden_dim = hidden_states.shape 445 | 446 | # compute query data if needed 447 | if self.use_attention_data or self.query_positional_score: 448 | q = self.query(hidden_states) 449 | q = q.view(batch_size, width, height, self.num_attention_heads, self.hidden_size) 450 | 451 | # compute key data if needed 452 | if self.use_attention_data: 453 | k = self.key(hidden_states) 454 | k = k.view(batch_size, width, height, self.num_attention_heads, self.hidden_size) 455 | 456 | # Compute attention scores based on position 457 | # Probably not optimal way to order computation 458 | relative_indices = self.relative_indices[:width,:width].reshape(-1) 459 | row_embeddings = self.row_embeddings(relative_indices) 460 | 461 | relative_indices = self.relative_indices[:height,:height].reshape(-1) 462 | col_embeddings = 
self.col_embeddings(relative_indices) 463 | 464 | # keep attention scores/prob for plotting 465 | attention_scores_per_type = {} 466 | sqrt_normalizer = math.sqrt(self.hidden_size) 467 | 468 | if not self.query_positional_score: 469 | # Caveat: sqrt rescaling is not used in this case 470 | row_scores = self.head_keys_row(row_embeddings).view(1, width, 1, width, self.num_attention_heads) 471 | col_scores = self.head_keys_col(col_embeddings).view(height, 1, height, 1, self.num_attention_heads) 472 | # -- H, W, H, W, num_attention_heads 473 | attention_scores = row_scores + col_scores 474 | # -- H, W, num_attention_heads, H, W 475 | attention_scores = attention_scores.permute(0, 1, 4, 2, 3) 476 | # -- 1, H, W, num_attention_heads, H, W 477 | attention_scores = attention_scores.unsqueeze(0) 478 | 479 | attention_scores_per_type["w_q^Tr"] = attention_scores 480 | 481 | else: # query_positional_score 482 | # B, W, H, num_attention_heads, D // 2 483 | q_row = q[:, :, :, :, :self.hidden_size // 2] 484 | q_col = q[:, :, :, :, self.hidden_size // 2:] 485 | 486 | row_scores = torch.einsum("bijhd,ikd->bijhk", q_row, row_embeddings.view(width, width, -1)) 487 | col_scores = torch.einsum("bijhd,jld->bijhl", q_col, col_embeddings.view(height, height, -1)) 488 | 489 | # -- B, H, W, num_attention_heads, H, W 490 | attention_scores = row_scores.unsqueeze(-1) + col_scores.unsqueeze(-2) 491 | attention_scores = attention_scores / sqrt_normalizer 492 | 493 | # save 494 | attention_scores_per_type["q^Tr"] = attention_scores 495 | 496 | # Compute attention scores based on data 497 | if self.use_attention_data: 498 | attention_content_scores = torch.einsum("bijhd,bklhd->bijhkl", q, k) 499 | attention_content_scores = attention_content_scores / sqrt_normalizer 500 | attention_scores = attention_scores + attention_content_scores 501 | 502 | # save 503 | attention_scores_per_type["q^Tk"] = attention_content_scores 504 | 505 | return attention_scores, attention_scores_per_type 506 | 507 | def get_attention_probs(self, width, height): 508 | """LEGACY 509 | Compute the positional attention for an image of size width x height 510 | Returns: tensor of attention probabilities (width, height, num_head, width, height) 511 | """ 512 | relative_indices = self.relative_indices[:width,:width].reshape(-1) 513 | row_scores = self.head_keys_row(self.row_embeddings(relative_indices)).view(1, width, 1, width, self.num_attention_heads) 514 | 515 | relative_indices = self.relative_indices[:height,:height].reshape(-1) 516 | col_scores = self.head_keys_col(self.col_embeddings(relative_indices)).view(height, 1, height, 1, self.num_attention_heads) 517 | 518 | # -- H, W, H, W, num_attention_heads 519 | attention_scores = row_scores + col_scores 520 | # -- H, W, num_attention_heads, H, W 521 | attention_scores = attention_scores.permute(0, 1, 4, 2, 3) 522 | 523 | # -- H, W, num_attention_heads, H, W 524 | flatten_shape = [height, width, self.num_attention_heads, height * width] 525 | unflatten_shape = [height, width, self.num_attention_heads, height, width] 526 | attention_probs = nn.Softmax(dim=-1)(attention_scores.view(*flatten_shape)).view(*unflatten_shape) 527 | 528 | return attention_probs 529 | 530 | 531 | class GaussianSelfAttention(nn.Module): 532 | def __init__(self, config, output_attentions=False, keep_multihead_output=False): 533 | super().__init__() 534 | self.attention_gaussian_blur_trick = config.attention_gaussian_blur_trick 535 | self.attention_isotropic_gaussian = config.attention_isotropic_gaussian 536 | 
self.gaussian_init_mu_std = config.gaussian_init_mu_std 537 | self.gaussian_init_sigma_std = config.gaussian_init_sigma_std 538 | self.config = config 539 | 540 | self.num_attention_heads = config.num_attention_heads 541 | self.attention_head_size = config.hidden_size 542 | # assert config.hidden_size % config.num_attention_heads == 0, "num_attention_heads should divide hidden_size" 543 | # self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 544 | self.all_head_size = self.num_attention_heads * config.hidden_size 545 | self.output_attentions = output_attentions 546 | 547 | # CAREFUL: if change something here, change also in reset_heads (TODO remove code duplication) 548 | # shift of the each gaussian per head 549 | self.attention_centers = nn.Parameter( 550 | torch.zeros(self.num_attention_heads, 2).normal_(0.0, config.gaussian_init_mu_std) 551 | ) 552 | 553 | if config.attention_isotropic_gaussian: 554 | # only one scalar (inverse standard deviation) 555 | # initialized to 1 + noise 556 | attention_spreads = 1 + torch.zeros(self.num_attention_heads).normal_(0, config.gaussian_init_sigma_std) 557 | else: 558 | # Inverse standart deviation $Sigma^{-1/2}$ 559 | # 2x2 matrix or a scalar per head 560 | # initialized to noisy identity matrix 561 | attention_spreads = torch.eye(2).unsqueeze(0).repeat(self.num_attention_heads, 1, 1) 562 | attention_spreads += torch.zeros_like(attention_spreads).normal_(0, config.gaussian_init_sigma_std) 563 | 564 | self.attention_spreads = nn.Parameter(attention_spreads) 565 | 566 | self.value = nn.Linear(self.all_head_size, config.hidden_size) 567 | 568 | if not config.attention_gaussian_blur_trick: 569 | # relative encoding grid (delta_x, delta_y, delta_x**2, delta_y**2, delta_x * delta_y) 570 | MAX_WIDTH_HEIGHT = 50 571 | range_ = torch.arange(MAX_WIDTH_HEIGHT) 572 | grid = torch.cat([t.unsqueeze(-1) for t in torch.meshgrid([range_, range_])], dim=-1) 573 | relative_indices = grid.unsqueeze(0).unsqueeze(0) - grid.unsqueeze(-2).unsqueeze(-2) 574 | R = torch.cat([relative_indices, relative_indices ** 2, (relative_indices[..., 0] * relative_indices[..., 1]).unsqueeze(-1)], dim=-1) 575 | R = R.float() 576 | self.register_buffer("R", R) 577 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 578 | 579 | def get_heads_target_vectors(self): 580 | if self.attention_isotropic_gaussian: 581 | a = c = self.attention_spreads ** 2 582 | b = torch.zeros_like(self.attention_spreads) 583 | else: 584 | # $\Sigma^{-1}$ 585 | inv_covariance = torch.einsum('hij,hkj->hik', [self.attention_spreads, self.attention_spreads]) 586 | a, b, c = inv_covariance[:, 0, 0], inv_covariance[:, 0, 1], inv_covariance[:, 1, 1] 587 | 588 | mu_1, mu_2 = self.attention_centers[:, 0], self.attention_centers[:, 1] 589 | 590 | t_h = -1/2 * torch.stack([ 591 | -2*(a*mu_1 + b*mu_2), 592 | -2*(c*mu_2 + b*mu_1), 593 | a, 594 | c, 595 | 2 * b 596 | ], dim=-1) 597 | return t_h 598 | 599 | def get_attention_probs(self, width, height): 600 | """Compute the positional attention for an image of size width x height 601 | Returns: tensor of attention probabilities (width, height, num_head, width, height) 602 | """ 603 | u = self.get_heads_target_vectors() 604 | 605 | # Compute attention map for each head 606 | attention_scores = torch.einsum('ijkld,hd->ijhkl', [self.R[:width,:height,:width,:height,:], u]) 607 | # Softmax 608 | attention_probs = torch.nn.Softmax(dim=-1)(attention_scores.view(width, height, self.num_attention_heads, -1)) 609 | attention_probs = 
attention_probs.view(width, height, self.num_attention_heads, width, height) 610 | 611 | return attention_probs 612 | 613 | def reset_heads(self, heads): 614 | device = self.attention_spreads.data.device 615 | reset_heads_mask = torch.zeros(self.num_attention_heads, device=device, dtype=torch.bool) 616 | for head in heads: 617 | reset_heads_mask[head] = 1 618 | 619 | # Reinitialize mu and sigma of these heads 620 | self.attention_centers.data[reset_heads_mask].zero_().normal_(0.0, self.gaussian_init_mu_std) 621 | 622 | if self.attention_isotropic_gaussian: 623 | self.attention_spreads.ones_().normal_(0, self.gaussian_init_sigma_std) 624 | else: 625 | self.attention_spreads.zero_().normal_(0, self.gaussian_init_sigma_std) 626 | self.attention_spreads[:, 0, 0] += 1 627 | self.attention_spreads[:, 1, 1] += 1 628 | 629 | # Reinitialize value matrix for these heads 630 | mask = torch.zeros(self.num_attention_heads, self.attention_head_size, dtype=torch.bool) 631 | for head in heads: 632 | mask[head] = 1 633 | mask = mask.view(-1).contiguous() 634 | self.value.weight.data[:, mask].normal_(mean=0.0, std=self.config.initializer_range) 635 | # self.value.bias.data.zero_() 636 | 637 | 638 | def blured_attention(self, X): 639 | """Compute the weighted average according to gaussian attention without 640 | computing explicitly the attention coefficients. 641 | 642 | Args: 643 | X (tensor): shape (batch, width, height, dim) 644 | Output: 645 | shape (batch, width, height, dim x num_heads) 646 | """ 647 | num_heads = self.attention_centers.shape[0] 648 | batch, width, height, d_total = X.shape 649 | Y = X.permute(0, 3, 1, 2).contiguous() 650 | 651 | kernels = [] 652 | kernel_width = kernel_height = 7 653 | assert kernel_width % 2 == 1 and kernel_height % 2 == 1, 'kernel size should be odd' 654 | 655 | for mean, std_inv in zip(self.attention_centers, self.attention_spreads): 656 | conv_weights = gaussian_kernel_2d(mean, std_inv, size=(kernel_width, kernel_height)) 657 | conv_weights = conv_weights.view(1, 1, kernel_width, kernel_height).repeat(d_total, 1, 1, 1) 658 | kernels.append(conv_weights) 659 | 660 | weights = torch.cat(kernels) 661 | 662 | padding_width = (kernel_width - 1) // 2 663 | padding_height = (kernel_height - 1) // 2 664 | out = F.conv2d(Y, weights, groups=d_total, padding=(padding_width, padding_height)) 665 | 666 | # renormalize for padding 667 | all_one_input = torch.ones(1, d_total, width, height, device=X.device) 668 | normalizer = F.conv2d(all_one_input, weights, groups=d_total, padding=(padding_width, padding_height)) 669 | out /= normalizer 670 | 671 | return out.permute(0, 2, 3, 1).contiguous() 672 | 673 | def forward(self, hidden_states, attention_mask, head_mask=None): 674 | assert len(hidden_states.shape) == 4 675 | b, w, h, c = hidden_states.shape 676 | 677 | if not self.attention_gaussian_blur_trick: 678 | attention_probs = self.get_attention_probs(w, h) 679 | attention_probs = self.dropout(attention_probs) 680 | 681 | input_values = torch.einsum('ijhkl,bkld->bijhd', attention_probs, hidden_states) 682 | input_values = input_values.contiguous().view(b, w, h, -1) 683 | else: 684 | input_values = self.blured_attention(hidden_states) 685 | 686 | output_value = self.value(input_values) 687 | 688 | if self.output_attentions: 689 | return attention_probs, output_value 690 | else: 691 | return output_value 692 | 693 | class BertSelfAttention(nn.Module): 694 | def __init__(self, config, output_attentions=False, keep_multihead_output=False): 695 | super(BertSelfAttention, 
self).__init__() 696 | if config.hidden_size % config.num_attention_heads != 0: 697 | raise ValueError( 698 | "The hidden size (%d) is not a multiple of the number of attention " 699 | "heads (%d)" % (config.hidden_size, config.num_attention_heads) 700 | ) 701 | self.output_attentions = output_attentions 702 | self.keep_multihead_output = keep_multihead_output 703 | self.multihead_output = None 704 | 705 | self.num_attention_heads = config.num_attention_heads 706 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 707 | self.all_head_size = self.num_attention_heads * self.attention_head_size 708 | 709 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 710 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 711 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 712 | 713 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 714 | 715 | def transpose_for_scores(self, x): 716 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 717 | x = x.view(*new_x_shape) 718 | return x.permute(0, 2, 1, 3) 719 | 720 | def forward(self, hidden_states, attention_mask, head_mask=None): 721 | mixed_query_layer = self.query(hidden_states) 722 | mixed_key_layer = self.key(hidden_states) 723 | mixed_value_layer = self.value(hidden_states) 724 | 725 | query_layer = self.transpose_for_scores(mixed_query_layer) 726 | key_layer = self.transpose_for_scores(mixed_key_layer) 727 | value_layer = self.transpose_for_scores(mixed_value_layer) 728 | 729 | # Take the dot product between "query" and "key" to get the raw attention scores. 730 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 731 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 732 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 733 | attention_scores = attention_scores + attention_mask 734 | 735 | # Normalize the attention scores to probabilities. 736 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 737 | 738 | # This is actually dropping out entire tokens to attend to, which might 739 | # seem a bit unusual, but is taken from the original Transformer paper. 
740 | attention_probs = self.dropout(attention_probs) 741 | 742 | # Mask heads if we want to 743 | if head_mask is not None: 744 | attention_probs = attention_probs * head_mask 745 | 746 | context_layer = torch.matmul(attention_probs, value_layer) 747 | if self.keep_multihead_output: 748 | self.multihead_output = context_layer 749 | self.multihead_output.retain_grad() 750 | 751 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 752 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 753 | context_layer = context_layer.view(*new_context_layer_shape) 754 | if self.output_attentions: 755 | return attention_probs, context_layer 756 | return context_layer 757 | 758 | 759 | class BertSelfOutput(nn.Module): 760 | def __init__(self, config): 761 | super(BertSelfOutput, self).__init__() 762 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 763 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 764 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 765 | 766 | def forward(self, hidden_states, input_tensor): 767 | hidden_states = self.dense(hidden_states) 768 | hidden_states = self.dropout(hidden_states) 769 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 770 | return hidden_states 771 | 772 | 773 | class BertAttention(nn.Module): 774 | def __init__(self, config, output_attentions=False, keep_multihead_output=False): 775 | super(BertAttention, self).__init__() 776 | self.output_attentions = output_attentions 777 | self.flatten_image = not config.use_gaussian_attention and not config.use_learned_2d_encoding 778 | self.use_gaussian_attention = config.use_gaussian_attention 779 | self.config = config 780 | 781 | assert not config.use_gaussian_attention or not config.use_learned_2d_encoding # TODO change to enum args 782 | 783 | if config.use_gaussian_attention: 784 | attention_cls = GaussianSelfAttention 785 | elif config.use_learned_2d_encoding: 786 | attention_cls = Learned2DRelativeSelfAttention 787 | else: 788 | attention_cls = BertSelfAttention 789 | 790 | self.self = attention_cls(config, output_attentions=output_attentions, keep_multihead_output=keep_multihead_output) 791 | 792 | self.output = BertSelfOutput(config) 793 | 794 | def prune_heads(self, heads): 795 | mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size) 796 | for head in heads: 797 | mask[head] = 0 798 | mask = mask.view(-1).contiguous().eq(1) 799 | index = torch.arange(len(mask))[mask].long() 800 | 801 | # Prune linear layers 802 | if not self.use_gaussian_attention: 803 | self.self.query = prune_linear_layer(self.self.query, index) 804 | self.self.key = prune_linear_layer(self.self.key, index) 805 | self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) 806 | 807 | if self.use_gaussian_attention: 808 | device = self.self.attention_spreads.data.device 809 | keep_heads = torch.ones(self.self.num_attention_heads, device=device, dtype=torch.bool) 810 | for head in heads: 811 | keep_heads[head] = 0 812 | self.self.attention_spreads.data = self.self.attention_spreads.data[keep_heads].contiguous() 813 | self.self.attention_centers.data = self.self.attention_centers.data[keep_heads].contiguous() 814 | 815 | dim = 0 if not self.use_gaussian_attention else 1 816 | self.self.value = prune_linear_layer(self.self.value, index, dim=dim) 817 | # Update hyper params 818 | self.self.num_attention_heads = self.self.num_attention_heads - len(heads) 819 | self.self.all_head_size = 
self.self.attention_head_size * self.self.num_attention_heads 820 | 821 | def reset_heads(self, heads): 822 | """Only for Gaussian Attention""" 823 | assert self.use_gaussian_attention 824 | self.self.reset_heads(heads) 825 | 826 | def forward(self, input_tensor, attention_mask, head_mask=None): 827 | is_image = len(input_tensor.shape) == 4 828 | if is_image and self.flatten_image: 829 | batch, width, height, d = input_tensor.shape 830 | input_tensor = input_tensor.view([batch, -1, d]) 831 | 832 | self_output = self.self(input_tensor, attention_mask, head_mask) 833 | if self.output_attentions: 834 | attentions, self_output = self_output 835 | attention_output = self.output(self_output, input_tensor) 836 | 837 | if is_image and self.flatten_image: 838 | attention_output = attention_output.view([batch, width, height, -1]) 839 | 840 | if self.output_attentions: 841 | return attentions, attention_output 842 | return attention_output 843 | 844 | 845 | class BertIntermediate(nn.Module): 846 | def __init__(self, config): 847 | super(BertIntermediate, self).__init__() 848 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 849 | if isinstance(config.hidden_act, str) or ( 850 | sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode) 851 | ): 852 | self.intermediate_act_fn = ACT2FN[config.hidden_act] 853 | else: 854 | self.intermediate_act_fn = config.hidden_act 855 | 856 | def forward(self, hidden_states): 857 | hidden_states = self.dense(hidden_states) 858 | hidden_states = self.intermediate_act_fn(hidden_states) 859 | return hidden_states 860 | 861 | 862 | class BertOutput(nn.Module): 863 | def __init__(self, config): 864 | super(BertOutput, self).__init__() 865 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 866 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 867 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 868 | 869 | def forward(self, hidden_states, input_tensor): 870 | hidden_states = self.dense(hidden_states) 871 | hidden_states = self.dropout(hidden_states) 872 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 873 | return hidden_states 874 | 875 | 876 | class BertLayer(nn.Module): 877 | def __init__(self, config, output_attentions=False, keep_multihead_output=False): 878 | super(BertLayer, self).__init__() 879 | self.output_attentions = output_attentions 880 | self.attention = BertAttention( 881 | config, output_attentions=output_attentions, keep_multihead_output=keep_multihead_output 882 | ) 883 | self.intermediate = BertIntermediate(config) 884 | self.output = BertOutput(config) 885 | 886 | def forward(self, hidden_states, attention_mask, head_mask=None): 887 | attention_output = self.attention(hidden_states, attention_mask, head_mask) 888 | if self.output_attentions: 889 | attentions, attention_output = attention_output 890 | intermediate_output = self.intermediate(attention_output) 891 | layer_output = self.output(intermediate_output, attention_output) 892 | if self.output_attentions: 893 | return attentions, layer_output 894 | return layer_output 895 | 896 | 897 | class BertEncoder(nn.Module): 898 | def __init__(self, config, output_attentions=False, keep_multihead_output=False): 899 | super(BertEncoder, self).__init__() 900 | self.output_attentions = output_attentions 901 | layer_constructor = lambda: BertLayer( 902 | config, output_attentions=output_attentions, keep_multihead_output=keep_multihead_output 903 | ) 904 | self.layer = nn.ModuleList([layer_constructor() for _ in 
range(config.num_hidden_layers)]) 905 | 906 | if config.use_learned_2d_encoding and config.share_position_encoding: 907 | for layer in self.layer[1:]: 908 | self.layer[0].attention.self.row_embeddings = layer.attention.self.row_embeddings 909 | self.layer[0].attention.self.col_embeddings = layer.attention.self.col_embeddings 910 | 911 | 912 | def forward( 913 | self, hidden_states, attention_mask, output_all_encoded_layers=True, head_mask=None 914 | ): 915 | all_encoder_layers = [] 916 | all_attentions = [] 917 | for i, layer_module in enumerate(self.layer): 918 | hidden_states = layer_module( 919 | hidden_states, attention_mask, head_mask[i] if head_mask is not None else None 920 | ) 921 | if self.output_attentions: 922 | attentions, hidden_states = hidden_states 923 | all_attentions.append(attentions) 924 | if output_all_encoded_layers: 925 | all_encoder_layers.append(hidden_states) 926 | if not output_all_encoded_layers: 927 | all_encoder_layers.append(hidden_states) 928 | if self.output_attentions: 929 | return all_attentions, all_encoder_layers 930 | return all_encoder_layers 931 | 932 | 933 | class BertPooler(nn.Module): 934 | def __init__(self, config): 935 | super(BertPooler, self).__init__() 936 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 937 | self.activation = nn.Tanh() 938 | 939 | def forward(self, hidden_states): 940 | # We "pool" the model by simply taking the hidden state corresponding 941 | # to the first token. 942 | first_token_tensor = hidden_states[:, 0] 943 | pooled_output = self.dense(first_token_tensor) 944 | pooled_output = self.activation(pooled_output) 945 | return pooled_output 946 | 947 | 948 | class BertPredictionHeadTransform(nn.Module): 949 | def __init__(self, config): 950 | super(BertPredictionHeadTransform, self).__init__() 951 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 952 | if isinstance(config.hidden_act, str) or ( 953 | sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode) 954 | ): 955 | self.transform_act_fn = ACT2FN[config.hidden_act] 956 | else: 957 | self.transform_act_fn = config.hidden_act 958 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 959 | 960 | def forward(self, hidden_states): 961 | hidden_states = self.dense(hidden_states) 962 | hidden_states = self.transform_act_fn(hidden_states) 963 | hidden_states = self.LayerNorm(hidden_states) 964 | return hidden_states 965 | 966 | 967 | class BertLMPredictionHead(nn.Module): 968 | def __init__(self, config, bert_model_embedding_weights): 969 | super(BertLMPredictionHead, self).__init__() 970 | self.transform = BertPredictionHeadTransform(config) 971 | 972 | # The output weights are the same as the input embeddings, but there is 973 | # an output-only bias for each token. 
974 | self.decoder = nn.Linear( 975 | bert_model_embedding_weights.size(1), bert_model_embedding_weights.size(0), bias=False 976 | ) 977 | self.decoder.weight = bert_model_embedding_weights 978 | self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0))) 979 | 980 | def forward(self, hidden_states): 981 | hidden_states = self.transform(hidden_states) 982 | hidden_states = self.decoder(hidden_states) + self.bias 983 | return hidden_states 984 | 985 | 986 | class BertOnlyMLMHead(nn.Module): 987 | def __init__(self, config, bert_model_embedding_weights): 988 | super(BertOnlyMLMHead, self).__init__() 989 | self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) 990 | 991 | def forward(self, sequence_output): 992 | prediction_scores = self.predictions(sequence_output) 993 | return prediction_scores 994 | 995 | 996 | class BertOnlyNSPHead(nn.Module): 997 | def __init__(self, config): 998 | super(BertOnlyNSPHead, self).__init__() 999 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 1000 | 1001 | def forward(self, pooled_output): 1002 | seq_relationship_score = self.seq_relationship(pooled_output) 1003 | return seq_relationship_score 1004 | 1005 | 1006 | class BertPreTrainingHeads(nn.Module): 1007 | def __init__(self, config, bert_model_embedding_weights): 1008 | super(BertPreTrainingHeads, self).__init__() 1009 | self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) 1010 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 1011 | 1012 | def forward(self, sequence_output, pooled_output): 1013 | prediction_scores = self.predictions(sequence_output) 1014 | seq_relationship_score = self.seq_relationship(pooled_output) 1015 | return prediction_scores, seq_relationship_score 1016 | 1017 | 1018 | class BertPreTrainedModel(nn.Module): 1019 | """ An abstract class to handle weights initialization and 1020 | a simple interface for dowloading and loading pretrained models. 1021 | """ 1022 | 1023 | def __init__(self, config, *inputs, **kwargs): 1024 | super(BertPreTrainedModel, self).__init__() 1025 | if not isinstance(config, BertConfig): 1026 | raise ValueError( 1027 | "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " 1028 | "To create a model from a Google pretrained model use " 1029 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( 1030 | self.__class__.__name__, self.__class__.__name__ 1031 | ) 1032 | ) 1033 | self.config = config 1034 | 1035 | def init_bert_weights(self, module): 1036 | """ Initialize the weights. 1037 | """ 1038 | if isinstance(module, (nn.Linear, nn.Embedding)): 1039 | # Slightly different from the TF version which uses truncated_normal for initialization 1040 | # cf https://github.com/pytorch/pytorch/pull/5617 1041 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 1042 | elif isinstance(module, BertLayerNorm): 1043 | module.bias.data.zero_() 1044 | module.weight.data.fill_(1.0) 1045 | if isinstance(module, nn.Linear) and module.bias is not None: 1046 | module.bias.data.zero_() 1047 | 1048 | @classmethod 1049 | def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): 1050 | """ 1051 | Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict. 1052 | Download and cache the pre-trained model file if needed. 1053 | 1054 | Params: 1055 | pretrained_model_name_or_path: either: 1056 | - a str with the name of a pre-trained model to load selected in the list of: 1057 | . 
`bert-base-uncased` 1058 | . `bert-large-uncased` 1059 | . `bert-base-cased` 1060 | . `bert-large-cased` 1061 | . `bert-base-multilingual-uncased` 1062 | . `bert-base-multilingual-cased` 1063 | . `bert-base-chinese` 1064 | . `bert-base-german-cased` 1065 | . `bert-large-uncased-whole-word-masking` 1066 | . `bert-large-cased-whole-word-masking` 1067 | - a path or url to a pretrained model archive containing: 1068 | . `bert_config.json` a configuration file for the model 1069 | . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance 1070 | - a path or url to a pretrained model archive containing: 1071 | . `bert_config.json` a configuration file for the model 1072 | . `model.chkpt` a TensorFlow checkpoint 1073 | from_tf: should we load the weights from a locally saved TensorFlow checkpoint 1074 | cache_dir: an optional path to a folder in which the pre-trained models will be cached. 1075 | state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models 1076 | *inputs, **kwargs: additional input for the specific Bert class 1077 | (ex: num_labels for BertForSequenceClassification) 1078 | """ 1079 | state_dict = kwargs.get("state_dict", None) 1080 | kwargs.pop("state_dict", None) 1081 | cache_dir = kwargs.get("cache_dir", None) 1082 | kwargs.pop("cache_dir", None) 1083 | from_tf = kwargs.get("from_tf", False) 1084 | kwargs.pop("from_tf", None) 1085 | 1086 | if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: 1087 | archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path] 1088 | else: 1089 | archive_file = pretrained_model_name_or_path 1090 | # redirect to the cache, if necessary 1091 | try: 1092 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) 1093 | except EnvironmentError: 1094 | if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: 1095 | logger.error( 1096 | "Couldn't reach server at '{}' to download pretrained weights.".format( 1097 | archive_file 1098 | ) 1099 | ) 1100 | else: 1101 | logger.error( 1102 | "Model name '{}' was not found in model name list ({}). " 1103 | "We assumed '{}' was a path or url but couldn't find any file " 1104 | "associated to this path or url.".format( 1105 | pretrained_model_name_or_path, 1106 | ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), 1107 | archive_file, 1108 | ) 1109 | ) 1110 | return None 1111 | if resolved_archive_file == archive_file: 1112 | logger.info("loading archive file {}".format(archive_file)) 1113 | else: 1114 | logger.info( 1115 | "loading archive file {} from cache at {}".format( 1116 | archive_file, resolved_archive_file 1117 | ) 1118 | ) 1119 | tempdir = None 1120 | if os.path.isdir(resolved_archive_file) or from_tf: 1121 | serialization_dir = resolved_archive_file 1122 | else: 1123 | # Extract archive to temp dir 1124 | tempdir = tempfile.mkdtemp() 1125 | logger.info( 1126 | "extracting archive file {} to temp dir {}".format(resolved_archive_file, tempdir) 1127 | ) 1128 | with tarfile.open(resolved_archive_file, "r:gz") as archive: 1129 | archive.extractall(tempdir) 1130 | serialization_dir = tempdir 1131 | # Load config 1132 | config_file = os.path.join(serialization_dir, CONFIG_NAME) 1133 | if not os.path.exists(config_file): 1134 | # Backward compatibility with old naming format 1135 | config_file = os.path.join(serialization_dir, BERT_CONFIG_NAME) 1136 | config = BertConfig.from_json_file(config_file) 1137 | logger.info("Model config {}".format(config)) 1138 | # Instantiate model. 
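# `cls` is the concrete subclass that `from_pretrained` was called on; the extra
# *inputs/**kwargs are forwarded to its constructor (e.g. `num_labels` for
# BertForSequenceClassification, as noted in the docstring above).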
1139 | model = cls(config, *inputs, **kwargs) 1140 | if state_dict is None and not from_tf: 1141 | weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) 1142 | state_dict = torch.load(weights_path, map_location="cpu") 1143 | if tempdir: 1144 | # Clean up temp dir 1145 | shutil.rmtree(tempdir) 1146 | if from_tf: 1147 | # Directly load from a TensorFlow checkpoint 1148 | weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME) 1149 | return load_tf_weights_in_bert(model, weights_path) 1150 | # Load from a PyTorch state_dict 1151 | old_keys = [] 1152 | new_keys = [] 1153 | for key in state_dict.keys(): 1154 | new_key = None 1155 | if "gamma" in key: 1156 | new_key = key.replace("gamma", "weight") 1157 | if "beta" in key: 1158 | new_key = key.replace("beta", "bias") 1159 | if new_key: 1160 | old_keys.append(key) 1161 | new_keys.append(new_key) 1162 | for old_key, new_key in zip(old_keys, new_keys): 1163 | state_dict[new_key] = state_dict.pop(old_key) 1164 | 1165 | missing_keys = [] 1166 | unexpected_keys = [] 1167 | error_msgs = [] 1168 | # copy state_dict so _load_from_state_dict can modify it 1169 | metadata = getattr(state_dict, "_metadata", None) 1170 | state_dict = state_dict.copy() 1171 | if metadata is not None: 1172 | state_dict._metadata = metadata 1173 | 1174 | def load(module, prefix=""): 1175 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 1176 | module._load_from_state_dict( 1177 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs 1178 | ) 1179 | for name, child in module._modules.items(): 1180 | if child is not None: 1181 | load(child, prefix + name + ".") 1182 | 1183 | start_prefix = "" 1184 | if not hasattr(model, "bert") and any(s.startswith("bert.") for s in state_dict.keys()): 1185 | start_prefix = "bert." 1186 | load(model, prefix=start_prefix) 1187 | if len(missing_keys) > 0: 1188 | logger.info( 1189 | "Weights of {} not initialized from pretrained model: {}".format( 1190 | model.__class__.__name__, missing_keys 1191 | ) 1192 | ) 1193 | if len(unexpected_keys) > 0: 1194 | logger.info( 1195 | "Weights from pretrained model not used in {}: {}".format( 1196 | model.__class__.__name__, unexpected_keys 1197 | ) 1198 | ) 1199 | if len(error_msgs) > 0: 1200 | raise RuntimeError( 1201 | "Error(s) in loading state_dict for {}:\n\t{}".format( 1202 | model.__class__.__name__, "\n\t".join(error_msgs) 1203 | ) 1204 | ) 1205 | return model 1206 | 1207 | 1208 | class BertModel(BertPreTrainedModel): 1209 | """BERT model ("Bidirectional Embedding Representations from a Transformer"). 1210 | 1211 | Params: 1212 | `config`: a BertConfig class instance with the configuration to build a new model 1213 | `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False 1214 | `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient. 1215 | This can be used to compute head importance metrics. Default: False 1216 | 1217 | Inputs: 1218 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 1219 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 1220 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1221 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 1222 | types indices selected in [0, 1]. 
Type 0 corresponds to a `sentence A` and type 1 corresponds to 1223 | a `sentence B` token (see BERT paper for more details). 1224 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 1225 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1226 | input sequence length in the current batch. It's the mask that we typically use for attention when 1227 | a batch has varying length sentences. 1228 | `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. 1229 | `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. 1230 | It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked. 1231 | 1232 | 1233 | Outputs: Tuple of (encoded_layers, pooled_output) 1234 | `encoded_layers`: controled by `output_all_encoded_layers` argument: 1235 | - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end 1236 | of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each 1237 | encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], 1238 | - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding 1239 | to the last attention block of shape [batch_size, sequence_length, hidden_size], 1240 | `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a 1241 | classifier pretrained on top of the hidden state associated to the first character of the 1242 | input (`CLS`) to train on the Next-Sentence task (see BERT's paper). 1243 | 1244 | Example usage: 1245 | ```python 1246 | # Already been converted into WordPiece token ids 1247 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 1248 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 1249 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 1250 | 1251 | config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1252 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1253 | 1254 | model = modeling.BertModel(config=config) 1255 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) 1256 | ``` 1257 | """ 1258 | 1259 | def __init__(self, config, output_attentions=False, keep_multihead_output=False): 1260 | super(BertModel, self).__init__(config) 1261 | self.output_attentions = output_attentions 1262 | self.embeddings = BertEmbeddings(config) 1263 | self.encoder = BertEncoder( 1264 | config, output_attentions=output_attentions, keep_multihead_output=keep_multihead_output 1265 | ) 1266 | self.pooler = BertPooler(config) 1267 | self.apply(self.init_bert_weights) 1268 | 1269 | def prune_heads(self, heads_to_prune): 1270 | """ Prunes heads of the model. 1271 | heads_to_prune: dict of {layer_num: list of heads to prune in this layer} 1272 | """ 1273 | for layer, heads in heads_to_prune.items(): 1274 | self.encoder.layer[layer].attention.prune_heads(heads) 1275 | 1276 | def get_multihead_outputs(self): 1277 | """ Gather all multi-head outputs. 
1278 | Return: list (layers) of multihead module outputs with gradients
1279 | """
1280 | return [layer.attention.self.multihead_output for layer in self.encoder.layer]
1281 |
1282 | def forward(
1283 | self,
1284 | input_ids,
1285 | token_type_ids=None,
1286 | attention_mask=None,
1287 | output_all_encoded_layers=True,
1288 | head_mask=None,
1289 | ):
1290 | if attention_mask is None:
1291 | attention_mask = torch.ones_like(input_ids)
1292 | if token_type_ids is None:
1293 | token_type_ids = torch.zeros_like(input_ids)
1294 |
1295 | # We create a 3D attention mask from a 2D tensor mask.
1296 | # Sizes are [batch_size, 1, 1, to_seq_length]
1297 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
1298 | # This attention mask is simpler than the triangular masking of causal attention
1299 | # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
1300 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
1301 |
1302 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
1303 | # masked positions, this operation will create a tensor which is 0.0 for
1304 | # positions we want to attend and -10000.0 for masked positions.
1305 | # Since we are adding it to the raw scores before the softmax, this is
1306 | # effectively the same as removing these entirely.
1307 | extended_attention_mask = extended_attention_mask.to(
1308 | dtype=next(self.parameters()).dtype
1309 | ) # fp16 compatibility
1310 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
1311 |
1312 | # Prepare head mask if needed
1313 | # 1.0 in head_mask indicates we mask the head
1314 | # attention_probs has shape bsz x n_heads x N x N
1315 | # head_mask has shape num_hidden_layers x batch x n_heads x N x N
1316 | if head_mask is not None:
1317 | if head_mask.dim() == 1:
1318 | head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
1319 | head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
1320 | elif head_mask.dim() == 2:
1321 | head_mask = (
1322 | head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
1323 | ) # We can specify head_mask for each layer
1324 | head_mask = head_mask.to(
1325 | dtype=next(self.parameters()).dtype
1326 | ) # switch to float if needed + fp16 compatibility
1327 | head_mask = 1.0 - head_mask
1328 | else:
1329 | head_mask = [None] * self.config.num_hidden_layers
1330 |
1331 | embedding_output = self.embeddings(input_ids, token_type_ids)
1332 | encoded_layers = self.encoder(
1333 | embedding_output,
1334 | extended_attention_mask,
1335 | output_all_encoded_layers=output_all_encoded_layers,
1336 | head_mask=head_mask,
1337 | )
1338 |
1339 | if self.output_attentions:
1340 | all_attentions, encoded_layers = encoded_layers
1341 |
1342 | sequence_output = encoded_layers[-1]
1343 | pooled_output = self.pooler(sequence_output)
1344 | if not output_all_encoded_layers:
1345 | encoded_layers = encoded_layers[-1]
1346 | if self.output_attentions:
1347 | return all_attentions, encoded_layers, pooled_output
1348 | return encoded_layers, pooled_output
1349 |
1350 |
1351 | class BertForPreTraining(BertPreTrainedModel):
1352 | """BERT model with pre-training heads.
1353 | This module comprises the BERT model followed by the two pre-training heads:
1354 | - the masked language modeling head, and
1355 | - the next sentence classification head.
1356 | 1357 | Params: 1358 | `config`: a BertConfig class instance with the configuration to build a new model 1359 | `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False 1360 | `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient. 1361 | This can be used to compute head importance metrics. Default: False 1362 | 1363 | Inputs: 1364 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 1365 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 1366 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1367 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 1368 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 1369 | a `sentence B` token (see BERT paper for more details). 1370 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 1371 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1372 | input sequence length in the current batch. It's the mask that we typically use for attention when 1373 | a batch has varying length sentences. 1374 | `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] 1375 | with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss 1376 | is only computed for the labels set in [0, ..., vocab_size] 1377 | `next_sentence_label`: optional next sentence classification loss: torch.LongTensor of shape [batch_size] 1378 | with indices selected in [0, 1]. 1379 | 0 => next sentence is the continuation, 1 => next sentence is a random sentence. 1380 | `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. 1381 | It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked. 1382 | 1383 | Outputs: 1384 | if `masked_lm_labels` and `next_sentence_label` are not `None`: 1385 | Outputs the total_loss which is the sum of the masked language modeling loss and the next 1386 | sentence classification loss. 1387 | if `masked_lm_labels` or `next_sentence_label` is `None`: 1388 | Outputs a tuple comprising 1389 | - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and 1390 | - the next sentence classification logits of shape [batch_size, 2]. 
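A minimal sketch of a training-style call (illustrative values only; it assumes the
`model`, `input_ids`, `input_mask` and `token_type_ids` from the example below, plus
label tensors with the shapes described above):

```python
masked_lm_labels = torch.LongTensor([[-1, 17, -1], [-1, -1, 42]])  # -1 positions are ignored
next_sentence_label = torch.LongTensor([0, 1])
total_loss = model(input_ids, token_type_ids, input_mask,
                   masked_lm_labels=masked_lm_labels,
                   next_sentence_label=next_sentence_label)
```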
1391 | 1392 | Example usage: 1393 | ```python 1394 | # Already been converted into WordPiece token ids 1395 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 1396 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 1397 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 1398 | 1399 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1400 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1401 | 1402 | model = BertForPreTraining(config) 1403 | masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask) 1404 | ``` 1405 | """ 1406 | 1407 | def __init__(self, config, output_attentions=False, keep_multihead_output=False): 1408 | super(BertForPreTraining, self).__init__(config) 1409 | self.output_attentions = output_attentions 1410 | self.bert = BertModel( 1411 | config, output_attentions=output_attentions, keep_multihead_output=keep_multihead_output 1412 | ) 1413 | self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight) 1414 | self.apply(self.init_bert_weights) 1415 | 1416 | def forward( 1417 | self, 1418 | input_ids, 1419 | token_type_ids=None, 1420 | attention_mask=None, 1421 | masked_lm_labels=None, 1422 | next_sentence_label=None, 1423 | head_mask=None, 1424 | ): 1425 | outputs = self.bert( 1426 | input_ids, 1427 | token_type_ids, 1428 | attention_mask, 1429 | output_all_encoded_layers=False, 1430 | head_mask=head_mask, 1431 | ) 1432 | if self.output_attentions: 1433 | all_attentions, sequence_output, pooled_output = outputs 1434 | else: 1435 | sequence_output, pooled_output = outputs 1436 | prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) 1437 | 1438 | if masked_lm_labels is not None and next_sentence_label is not None: 1439 | loss_fct = CrossEntropyLoss(ignore_index=-1) 1440 | masked_lm_loss = loss_fct( 1441 | prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1) 1442 | ) 1443 | next_sentence_loss = loss_fct( 1444 | seq_relationship_score.view(-1, 2), next_sentence_label.view(-1) 1445 | ) 1446 | total_loss = masked_lm_loss + next_sentence_loss 1447 | return total_loss 1448 | elif self.output_attentions: 1449 | return all_attentions, prediction_scores, seq_relationship_score 1450 | return prediction_scores, seq_relationship_score 1451 | 1452 | 1453 | class BertForMaskedLM(BertPreTrainedModel): 1454 | """BERT model with the masked language modeling head. 1455 | This module comprises the BERT model followed by the masked language modeling head. 1456 | 1457 | Params: 1458 | `config`: a BertConfig class instance with the configuration to build a new model 1459 | `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False 1460 | `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient. 1461 | This can be used to compute head importance metrics. Default: False 1462 | 1463 | Inputs: 1464 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 1465 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 1466 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1467 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 1468 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 1469 | a `sentence B` token (see BERT paper for more details). 
1470 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 1471 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1472 | input sequence length in the current batch. It's the mask that we typically use for attention when 1473 | a batch has varying length sentences. 1474 | `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] 1475 | with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss 1476 | is only computed for the labels set in [0, ..., vocab_size] 1477 | `head_mask`: an optional torch.LongTensor of shape [num_heads] with indices 1478 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1479 | input sequence length in the current batch. It's the mask that we typically use for attention when 1480 | a batch has varying length sentences. 1481 | `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. 1482 | It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked. 1483 | 1484 | Outputs: 1485 | if `masked_lm_labels` is not `None`: 1486 | Outputs the masked language modeling loss. 1487 | if `masked_lm_labels` is `None`: 1488 | Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size]. 1489 | 1490 | Example usage: 1491 | ```python 1492 | # Already been converted into WordPiece token ids 1493 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 1494 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 1495 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 1496 | 1497 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1498 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1499 | 1500 | model = BertForMaskedLM(config) 1501 | masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask) 1502 | ``` 1503 | """ 1504 | 1505 | def __init__(self, config, output_attentions=False, keep_multihead_output=False): 1506 | super(BertForMaskedLM, self).__init__(config) 1507 | self.output_attentions = output_attentions 1508 | self.bert = BertModel( 1509 | config, output_attentions=output_attentions, keep_multihead_output=keep_multihead_output 1510 | ) 1511 | self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight) 1512 | self.apply(self.init_bert_weights) 1513 | 1514 | def forward( 1515 | self, 1516 | input_ids, 1517 | token_type_ids=None, 1518 | attention_mask=None, 1519 | masked_lm_labels=None, 1520 | head_mask=None, 1521 | ): 1522 | outputs = self.bert( 1523 | input_ids, 1524 | token_type_ids, 1525 | attention_mask, 1526 | output_all_encoded_layers=False, 1527 | head_mask=head_mask, 1528 | ) 1529 | if self.output_attentions: 1530 | all_attentions, sequence_output, _ = outputs 1531 | else: 1532 | sequence_output, _ = outputs 1533 | prediction_scores = self.cls(sequence_output) 1534 | 1535 | if masked_lm_labels is not None: 1536 | loss_fct = CrossEntropyLoss(ignore_index=-1) 1537 | masked_lm_loss = loss_fct( 1538 | prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1) 1539 | ) 1540 | return masked_lm_loss 1541 | elif self.output_attentions: 1542 | return all_attentions, prediction_scores 1543 | return prediction_scores 1544 | 1545 | 1546 | class BertForNextSentencePrediction(BertPreTrainedModel): 
1547 | """BERT model with next sentence prediction head. 1548 | This module comprises the BERT model followed by the next sentence classification head. 1549 | 1550 | Params: 1551 | `config`: a BertConfig class instance with the configuration to build a new model 1552 | `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False 1553 | `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient. 1554 | This can be used to compute head importance metrics. Default: False 1555 | 1556 | Inputs: 1557 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 1558 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 1559 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1560 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 1561 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 1562 | a `sentence B` token (see BERT paper for more details). 1563 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 1564 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1565 | input sequence length in the current batch. It's the mask that we typically use for attention when 1566 | a batch has varying length sentences. 1567 | `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] 1568 | with indices selected in [0, 1]. 1569 | 0 => next sentence is the continuation, 1 => next sentence is a random sentence. 1570 | `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. 1571 | It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked. 1572 | 1573 | Outputs: 1574 | if `next_sentence_label` is not `None`: 1575 | Outputs the total_loss which is the sum of the masked language modeling loss and the next 1576 | sentence classification loss. 1577 | if `next_sentence_label` is `None`: 1578 | Outputs the next sentence classification logits of shape [batch_size, 2]. 
1579 | 1580 | Example usage: 1581 | ```python 1582 | # Already been converted into WordPiece token ids 1583 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 1584 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 1585 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 1586 | 1587 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1588 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1589 | 1590 | model = BertForNextSentencePrediction(config) 1591 | seq_relationship_logits = model(input_ids, token_type_ids, input_mask) 1592 | ``` 1593 | """ 1594 | 1595 | def __init__(self, config, output_attentions=False, keep_multihead_output=False): 1596 | super(BertForNextSentencePrediction, self).__init__(config) 1597 | self.output_attentions = output_attentions 1598 | self.bert = BertModel( 1599 | config, output_attentions=output_attentions, keep_multihead_output=keep_multihead_output 1600 | ) 1601 | self.cls = BertOnlyNSPHead(config) 1602 | self.apply(self.init_bert_weights) 1603 | 1604 | def forward( 1605 | self, 1606 | input_ids, 1607 | token_type_ids=None, 1608 | attention_mask=None, 1609 | next_sentence_label=None, 1610 | head_mask=None, 1611 | ): 1612 | outputs = self.bert( 1613 | input_ids, 1614 | token_type_ids, 1615 | attention_mask, 1616 | output_all_encoded_layers=False, 1617 | head_mask=head_mask, 1618 | ) 1619 | if self.output_attentions: 1620 | all_attentions, _, pooled_output = outputs 1621 | else: 1622 | _, pooled_output = outputs 1623 | seq_relationship_score = self.cls(pooled_output) 1624 | 1625 | if next_sentence_label is not None: 1626 | loss_fct = CrossEntropyLoss(ignore_index=-1) 1627 | next_sentence_loss = loss_fct( 1628 | seq_relationship_score.view(-1, 2), next_sentence_label.view(-1) 1629 | ) 1630 | return next_sentence_loss 1631 | elif self.output_attentions: 1632 | return all_attentions, seq_relationship_score 1633 | return seq_relationship_score 1634 | 1635 | 1636 | class BertForSequenceClassification(BertPreTrainedModel): 1637 | """BERT model for classification. 1638 | This module is composed of the BERT model with a linear layer on top of 1639 | the pooled output. 1640 | 1641 | Params: 1642 | `config`: a BertConfig class instance with the configuration to build a new model 1643 | `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False 1644 | `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient. 1645 | This can be used to compute head importance metrics. Default: False 1646 | `num_labels`: the number of classes for the classifier. Default = 2. 1647 | 1648 | Inputs: 1649 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 1650 | with the word token indices in the vocabulary. Items in the batch should begin with the special "CLS" token. (see the tokens preprocessing logic in the scripts 1651 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1652 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 1653 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 1654 | a `sentence B` token (see BERT paper for more details). 1655 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 1656 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1657 | input sequence length in the current batch. 
It's the mask that we typically use for attention when 1658 | a batch has varying length sentences. 1659 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] 1660 | with indices selected in [0, ..., num_labels]. 1661 | `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. 1662 | It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked. 1663 | 1664 | Outputs: 1665 | if `labels` is not `None`: 1666 | Outputs the CrossEntropy classification loss of the output with the labels. 1667 | if `labels` is `None`: 1668 | Outputs the classification logits of shape [batch_size, num_labels]. 1669 | 1670 | Example usage: 1671 | ```python 1672 | # Already been converted into WordPiece token ids 1673 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 1674 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 1675 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 1676 | 1677 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1678 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1679 | 1680 | num_labels = 2 1681 | 1682 | model = BertForSequenceClassification(config, num_labels) 1683 | logits = model(input_ids, token_type_ids, input_mask) 1684 | ``` 1685 | """ 1686 | 1687 | def __init__(self, config, num_labels=2, output_attentions=False, keep_multihead_output=False): 1688 | super(BertForSequenceClassification, self).__init__(config) 1689 | self.output_attentions = output_attentions 1690 | self.num_labels = num_labels 1691 | self.bert = BertModel( 1692 | config, output_attentions=output_attentions, keep_multihead_output=keep_multihead_output 1693 | ) 1694 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1695 | self.classifier = nn.Linear(config.hidden_size, num_labels) 1696 | self.apply(self.init_bert_weights) 1697 | 1698 | def forward( 1699 | self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None 1700 | ): 1701 | outputs = self.bert( 1702 | input_ids, 1703 | token_type_ids, 1704 | attention_mask, 1705 | output_all_encoded_layers=False, 1706 | head_mask=head_mask, 1707 | ) 1708 | if self.output_attentions: 1709 | all_attentions, _, pooled_output = outputs 1710 | else: 1711 | _, pooled_output = outputs 1712 | pooled_output = self.dropout(pooled_output) 1713 | logits = self.classifier(pooled_output) 1714 | 1715 | if labels is not None: 1716 | loss_fct = CrossEntropyLoss() 1717 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1718 | return loss 1719 | elif self.output_attentions: 1720 | return all_attentions, logits 1721 | return logits 1722 | 1723 | 1724 | class BertForMultipleChoice(BertPreTrainedModel): 1725 | """BERT model for multiple choice tasks. 1726 | This module is composed of the BERT model with a linear layer on top of 1727 | the pooled output. 1728 | 1729 | Params: 1730 | `config`: a BertConfig class instance with the configuration to build a new model 1731 | `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False 1732 | `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient. 1733 | This can be used to compute head importance metrics. Default: False 1734 | `num_choices`: the number of classes for the classifier. Default = 2. 
1735 | 1736 | Inputs: 1737 | `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] 1738 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 1739 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1740 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] 1741 | with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` 1742 | and type 1 corresponds to a `sentence B` token (see BERT paper for more details). 1743 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices 1744 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1745 | input sequence length in the current batch. It's the mask that we typically use for attention when 1746 | a batch has varying length sentences. 1747 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] 1748 | with indices selected in [0, ..., num_choices]. 1749 | `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. 1750 | It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked. 1751 | 1752 | Outputs: 1753 | if `labels` is not `None`: 1754 | Outputs the CrossEntropy classification loss of the output with the labels. 1755 | if `labels` is `None`: 1756 | Outputs the classification logits of shape [batch_size, num_labels]. 1757 | 1758 | Example usage: 1759 | ```python 1760 | # Already been converted into WordPiece token ids 1761 | input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]]) 1762 | input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]]) 1763 | token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]]) 1764 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1765 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1766 | 1767 | num_choices = 2 1768 | 1769 | model = BertForMultipleChoice(config, num_choices) 1770 | logits = model(input_ids, token_type_ids, input_mask) 1771 | ``` 1772 | """ 1773 | 1774 | def __init__(self, config, num_choices=2, output_attentions=False, keep_multihead_output=False): 1775 | super(BertForMultipleChoice, self).__init__(config) 1776 | self.output_attentions = output_attentions 1777 | self.num_choices = num_choices 1778 | self.bert = BertModel( 1779 | config, output_attentions=output_attentions, keep_multihead_output=keep_multihead_output 1780 | ) 1781 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1782 | self.classifier = nn.Linear(config.hidden_size, 1) 1783 | self.apply(self.init_bert_weights) 1784 | 1785 | def forward( 1786 | self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None 1787 | ): 1788 | flat_input_ids = input_ids.view(-1, input_ids.size(-1)) 1789 | flat_token_type_ids = ( 1790 | token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None 1791 | ) 1792 | flat_attention_mask = ( 1793 | attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None 1794 | ) 1795 | outputs = self.bert( 1796 | flat_input_ids, 1797 | flat_token_type_ids, 1798 | flat_attention_mask, 1799 | output_all_encoded_layers=False, 1800 | head_mask=head_mask, 1801 | ) 1802 | if self.output_attentions: 1803 | 
all_attentions, _, pooled_output = outputs 1804 | else: 1805 | _, pooled_output = outputs 1806 | pooled_output = self.dropout(pooled_output) 1807 | logits = self.classifier(pooled_output) 1808 | reshaped_logits = logits.view(-1, self.num_choices) 1809 | 1810 | if labels is not None: 1811 | loss_fct = CrossEntropyLoss() 1812 | loss = loss_fct(reshaped_logits, labels) 1813 | return loss 1814 | elif self.output_attentions: 1815 | return all_attentions, reshaped_logits 1816 | return reshaped_logits 1817 | 1818 | 1819 | class BertForTokenClassification(BertPreTrainedModel): 1820 | """BERT model for token-level classification. 1821 | This module is composed of the BERT model with a linear layer on top of 1822 | the full hidden state of the last layer. 1823 | 1824 | Params: 1825 | `config`: a BertConfig class instance with the configuration to build a new model 1826 | `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False 1827 | `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient. 1828 | This can be used to compute head importance metrics. Default: False 1829 | `num_labels`: the number of classes for the classifier. Default = 2. 1830 | 1831 | Inputs: 1832 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 1833 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 1834 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1835 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 1836 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 1837 | a `sentence B` token (see BERT paper for more details). 1838 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 1839 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1840 | input sequence length in the current batch. It's the mask that we typically use for attention when 1841 | a batch has varying length sentences. 1842 | `labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length] 1843 | with indices selected in [0, ..., num_labels]. 1844 | `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. 1845 | It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked. 1846 | 1847 | Outputs: 1848 | if `labels` is not `None`: 1849 | Outputs the CrossEntropy classification loss of the output with the labels. 1850 | if `labels` is `None`: 1851 | Outputs the classification logits of shape [batch_size, sequence_length, num_labels]. 
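A minimal sketch of the labelled call (illustrative values only; it reuses `model`,
`input_ids`, `input_mask` and `token_type_ids` from the example below):

```python
labels = torch.LongTensor([[0, 1, 0], [1, 0, 0]])  # one label id per token
loss = model(input_ids, token_type_ids, input_mask, labels=labels)
# when `input_mask` is given, positions with mask 0 are excluded from the loss
```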
1852 | 1853 | Example usage: 1854 | ```python 1855 | # Already been converted into WordPiece token ids 1856 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 1857 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 1858 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 1859 | 1860 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1861 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1862 | 1863 | num_labels = 2 1864 | 1865 | model = BertForTokenClassification(config, num_labels) 1866 | logits = model(input_ids, token_type_ids, input_mask) 1867 | ``` 1868 | """ 1869 | 1870 | def __init__(self, config, num_labels=2, output_attentions=False, keep_multihead_output=False): 1871 | super(BertForTokenClassification, self).__init__(config) 1872 | self.output_attentions = output_attentions 1873 | self.num_labels = num_labels 1874 | self.bert = BertModel( 1875 | config, output_attentions=output_attentions, keep_multihead_output=keep_multihead_output 1876 | ) 1877 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 1878 | self.classifier = nn.Linear(config.hidden_size, num_labels) 1879 | self.apply(self.init_bert_weights) 1880 | 1881 | def forward( 1882 | self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None 1883 | ): 1884 | outputs = self.bert( 1885 | input_ids, 1886 | token_type_ids, 1887 | attention_mask, 1888 | output_all_encoded_layers=False, 1889 | head_mask=head_mask, 1890 | ) 1891 | if self.output_attentions: 1892 | all_attentions, sequence_output, _ = outputs 1893 | else: 1894 | sequence_output, _ = outputs 1895 | sequence_output = self.dropout(sequence_output) 1896 | logits = self.classifier(sequence_output) 1897 | 1898 | if labels is not None: 1899 | loss_fct = CrossEntropyLoss() 1900 | # Only keep active parts of the loss 1901 | if attention_mask is not None: 1902 | active_loss = attention_mask.view(-1) == 1 1903 | active_logits = logits.view(-1, self.num_labels)[active_loss] 1904 | active_labels = labels.view(-1)[active_loss] 1905 | loss = loss_fct(active_logits, active_labels) 1906 | else: 1907 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 1908 | return loss 1909 | elif self.output_attentions: 1910 | return all_attentions, logits 1911 | return logits 1912 | 1913 | 1914 | class BertForQuestionAnswering(BertPreTrainedModel): 1915 | """BERT model for Question Answering (span extraction). 1916 | This module is composed of the BERT model with a linear layer on top of 1917 | the sequence output that computes start_logits and end_logits 1918 | 1919 | Params: 1920 | `config`: a BertConfig class instance with the configuration to build a new model 1921 | `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False 1922 | `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient. 1923 | This can be used to compute head importance metrics. Default: False 1924 | 1925 | Inputs: 1926 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 1927 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 1928 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 1929 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 1930 | types indices selected in [0, 1]. 
Type 0 corresponds to a `sentence A` and type 1 corresponds to 1931 | a `sentence B` token (see BERT paper for more details). 1932 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 1933 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 1934 | input sequence length in the current batch. It's the mask that we typically use for attention when 1935 | a batch has varying length sentences. 1936 | `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size]. 1937 | Positions are clamped to the length of the sequence and position outside of the sequence are not taken 1938 | into account for computing the loss. 1939 | `end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size]. 1940 | Positions are clamped to the length of the sequence and position outside of the sequence are not taken 1941 | into account for computing the loss. 1942 | `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1. 1943 | It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked. 1944 | 1945 | Outputs: 1946 | if `start_positions` and `end_positions` are not `None`: 1947 | Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions. 1948 | if `start_positions` or `end_positions` is `None`: 1949 | Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end 1950 | position tokens of shape [batch_size, sequence_length]. 1951 | 1952 | Example usage: 1953 | ```python 1954 | # Already been converted into WordPiece token ids 1955 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 1956 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 1957 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 1958 | 1959 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 1960 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 1961 | 1962 | model = BertForQuestionAnswering(config) 1963 | start_logits, end_logits = model(input_ids, token_type_ids, input_mask) 1964 | ``` 1965 | """ 1966 | 1967 | def __init__(self, config, output_attentions=False, keep_multihead_output=False): 1968 | super(BertForQuestionAnswering, self).__init__(config) 1969 | self.output_attentions = output_attentions 1970 | self.bert = BertModel( 1971 | config, output_attentions=output_attentions, keep_multihead_output=keep_multihead_output 1972 | ) 1973 | self.qa_outputs = nn.Linear(config.hidden_size, 2) 1974 | self.apply(self.init_bert_weights) 1975 | 1976 | def forward( 1977 | self, 1978 | input_ids, 1979 | token_type_ids=None, 1980 | attention_mask=None, 1981 | start_positions=None, 1982 | end_positions=None, 1983 | head_mask=None, 1984 | ): 1985 | outputs = self.bert( 1986 | input_ids, 1987 | token_type_ids, 1988 | attention_mask, 1989 | output_all_encoded_layers=False, 1990 | head_mask=head_mask, 1991 | ) 1992 | if self.output_attentions: 1993 | all_attentions, sequence_output, _ = outputs 1994 | else: 1995 | sequence_output, _ = outputs 1996 | logits = self.qa_outputs(sequence_output) 1997 | start_logits, end_logits = logits.split(1, dim=-1) 1998 | start_logits = start_logits.squeeze(-1) 1999 | end_logits = end_logits.squeeze(-1) 2000 | 2001 | if start_positions is not None and end_positions is not None: 2002 | # If we 
are on multi-GPU, split add a dimension 2003 | if len(start_positions.size()) > 1: 2004 | start_positions = start_positions.squeeze(-1) 2005 | if len(end_positions.size()) > 1: 2006 | end_positions = end_positions.squeeze(-1) 2007 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 2008 | ignored_index = start_logits.size(1) 2009 | start_positions.clamp_(0, ignored_index) 2010 | end_positions.clamp_(0, ignored_index) 2011 | 2012 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 2013 | start_loss = loss_fct(start_logits, start_positions) 2014 | end_loss = loss_fct(end_logits, end_positions) 2015 | total_loss = (start_loss + end_loss) / 2 2016 | return total_loss 2017 | elif self.output_attentions: 2018 | return all_attentions, start_logits, end_logits 2019 | return start_logits, end_logits 2020 | -------------------------------------------------------------------------------- /models/bert_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | from __future__ import (absolute_import, division, print_function, unicode_literals) 7 | 8 | import sys 9 | import json 10 | import logging 11 | import os 12 | import shutil 13 | import tempfile 14 | import fnmatch 15 | from functools import wraps 16 | from hashlib import sha256 17 | import sys 18 | from io import open 19 | 20 | import boto3 21 | import requests 22 | from botocore.exceptions import ClientError 23 | from tqdm import tqdm 24 | 25 | try: 26 | from torch.hub import _get_torch_home 27 | torch_cache_home = _get_torch_home() 28 | except ImportError: 29 | torch_cache_home = os.path.expanduser( 30 | os.getenv('TORCH_HOME', os.path.join( 31 | os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) 32 | default_cache_path = os.path.join(torch_cache_home, 'pytorch_pretrained_bert') 33 | 34 | try: 35 | from urllib.parse import urlparse 36 | except ImportError: 37 | from urlparse import urlparse 38 | 39 | try: 40 | from pathlib import Path 41 | PYTORCH_PRETRAINED_BERT_CACHE = Path( 42 | os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)) 43 | except (AttributeError, ImportError): 44 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 45 | default_cache_path) 46 | 47 | CONFIG_NAME = "config.json" 48 | WEIGHTS_NAME = "pytorch_model.bin" 49 | 50 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 51 | 52 | 53 | def url_to_filename(url, etag=None): 54 | """ 55 | Convert `url` into a hashed filename in a repeatable way. 56 | If `etag` is specified, append its hash to the url's, delimited 57 | by a period. 58 | """ 59 | url_bytes = url.encode('utf-8') 60 | url_hash = sha256(url_bytes) 61 | filename = url_hash.hexdigest() 62 | 63 | if etag: 64 | etag_bytes = etag.encode('utf-8') 65 | etag_hash = sha256(etag_bytes) 66 | filename += '.' + etag_hash.hexdigest() 67 | 68 | return filename 69 | 70 | 71 | def filename_to_url(filename, cache_dir=None): 72 | """ 73 | Return the url and etag (which may be ``None``) stored for `filename`. 74 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 
75 | """ 76 | if cache_dir is None: 77 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 78 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 79 | cache_dir = str(cache_dir) 80 | 81 | cache_path = os.path.join(cache_dir, filename) 82 | if not os.path.exists(cache_path): 83 | raise EnvironmentError("file {} not found".format(cache_path)) 84 | 85 | meta_path = cache_path + '.json' 86 | if not os.path.exists(meta_path): 87 | raise EnvironmentError("file {} not found".format(meta_path)) 88 | 89 | with open(meta_path, encoding="utf-8") as meta_file: 90 | metadata = json.load(meta_file) 91 | url = metadata['url'] 92 | etag = metadata['etag'] 93 | 94 | return url, etag 95 | 96 | 97 | def cached_path(url_or_filename, cache_dir=None): 98 | """ 99 | Given something that might be a URL (or might be a local path), 100 | determine which. If it's a URL, download the file and cache it, and 101 | return the path to the cached file. If it's already a local path, 102 | make sure the file exists and then return the path. 103 | """ 104 | if cache_dir is None: 105 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 106 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 107 | url_or_filename = str(url_or_filename) 108 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 109 | cache_dir = str(cache_dir) 110 | 111 | parsed = urlparse(url_or_filename) 112 | 113 | if parsed.scheme in ('http', 'https', 's3'): 114 | # URL, so get it from the cache (downloading if necessary) 115 | return get_from_cache(url_or_filename, cache_dir) 116 | elif os.path.exists(url_or_filename): 117 | # File, and it exists. 118 | return url_or_filename 119 | elif parsed.scheme == '': 120 | # File, but it doesn't exist. 121 | raise EnvironmentError("file {} not found".format(url_or_filename)) 122 | else: 123 | # Something unknown 124 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 125 | 126 | 127 | def split_s3_path(url): 128 | """Split a full s3 path into the bucket name and path.""" 129 | parsed = urlparse(url) 130 | if not parsed.netloc or not parsed.path: 131 | raise ValueError("bad s3 path {}".format(url)) 132 | bucket_name = parsed.netloc 133 | s3_path = parsed.path 134 | # Remove '/' at beginning of path. 135 | if s3_path.startswith("/"): 136 | s3_path = s3_path[1:] 137 | return bucket_name, s3_path 138 | 139 | 140 | def s3_request(func): 141 | """ 142 | Wrapper function for s3 requests in order to create more helpful error 143 | messages. 
144 | """ 145 | 146 | @wraps(func) 147 | def wrapper(url, *args, **kwargs): 148 | try: 149 | return func(url, *args, **kwargs) 150 | except ClientError as exc: 151 | if int(exc.response["Error"]["Code"]) == 404: 152 | raise EnvironmentError("file {} not found".format(url)) 153 | else: 154 | raise 155 | 156 | return wrapper 157 | 158 | 159 | @s3_request 160 | def s3_etag(url): 161 | """Check ETag on S3 object.""" 162 | s3_resource = boto3.resource("s3") 163 | bucket_name, s3_path = split_s3_path(url) 164 | s3_object = s3_resource.Object(bucket_name, s3_path) 165 | return s3_object.e_tag 166 | 167 | 168 | @s3_request 169 | def s3_get(url, temp_file): 170 | """Pull a file directly from S3.""" 171 | s3_resource = boto3.resource("s3") 172 | bucket_name, s3_path = split_s3_path(url) 173 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 174 | 175 | 176 | def http_get(url, temp_file): 177 | req = requests.get(url, stream=True) 178 | content_length = req.headers.get('Content-Length') 179 | total = int(content_length) if content_length is not None else None 180 | progress = tqdm(unit="B", total=total) 181 | for chunk in req.iter_content(chunk_size=1024): 182 | if chunk: # filter out keep-alive new chunks 183 | progress.update(len(chunk)) 184 | temp_file.write(chunk) 185 | progress.close() 186 | 187 | 188 | def get_from_cache(url, cache_dir=None): 189 | """ 190 | Given a URL, look for the corresponding dataset in the local cache. 191 | If it's not there, download it. Then return the path to the cached file. 192 | """ 193 | if cache_dir is None: 194 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 195 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 196 | cache_dir = str(cache_dir) 197 | 198 | if not os.path.exists(cache_dir): 199 | os.makedirs(cache_dir) 200 | 201 | # Get eTag to add to filename, if it exists. 202 | if url.startswith("s3://"): 203 | etag = s3_etag(url) 204 | else: 205 | try: 206 | response = requests.head(url, allow_redirects=True) 207 | if response.status_code != 200: 208 | etag = None 209 | else: 210 | etag = response.headers.get("ETag") 211 | except EnvironmentError: 212 | etag = None 213 | 214 | if sys.version_info[0] == 2 and etag is not None: 215 | etag = etag.decode('utf-8') 216 | filename = url_to_filename(url, etag) 217 | 218 | # get cache path to put the file 219 | cache_path = os.path.join(cache_dir, filename) 220 | 221 | # If we don't have a connection (etag is None) and can't identify the file 222 | # try to get the last downloaded one 223 | if not os.path.exists(cache_path) and etag is None: 224 | matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') 225 | matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) 226 | if matching_files: 227 | cache_path = os.path.join(cache_dir, matching_files[-1]) 228 | 229 | if not os.path.exists(cache_path): 230 | # Download to temporary file, then copy to cache dir once finished. 231 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
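# NamedTemporaryFile is deleted as soon as it is closed, so the flush, the copy
# into the cache and the metadata dump below all happen while the handle is open.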
232 | with tempfile.NamedTemporaryFile() as temp_file: 233 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 234 | 235 | # GET file object 236 | if url.startswith("s3://"): 237 | s3_get(url, temp_file) 238 | else: 239 | http_get(url, temp_file) 240 | 241 | # we are copying the file before closing it, so flush to avoid truncation 242 | temp_file.flush() 243 | # shutil.copyfileobj() starts at the current position, so go to the start 244 | temp_file.seek(0) 245 | 246 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 247 | with open(cache_path, 'wb') as cache_file: 248 | shutil.copyfileobj(temp_file, cache_file) 249 | 250 | logger.info("creating metadata file for %s", cache_path) 251 | meta = {'url': url, 'etag': etag} 252 | meta_path = cache_path + '.json' 253 | with open(meta_path, 'w') as meta_file: 254 | output_string = json.dumps(meta) 255 | if sys.version_info[0] == 2 and isinstance(output_string, str): 256 | output_string = unicode(output_string, 'utf-8') # The beauty of python 2 257 | meta_file.write(output_string) 258 | 259 | logger.info("removing temp file %s", temp_file.name) 260 | 261 | return cache_path 262 | 263 | 264 | def read_set_from_file(filename): 265 | ''' 266 | Extract a de-duped collection (set) of text from a file. 267 | Expected file format is one item per line. 268 | ''' 269 | collection = set() 270 | with open(filename, 'r', encoding='utf-8') as file_: 271 | for line in file_: 272 | collection.add(line.rstrip()) 273 | return collection 274 | 275 | 276 | def get_file_extension(path, dot=True, lower=True): 277 | ext = os.path.splitext(path)[1] 278 | ext = ext if dot else ext[1:] 279 | return ext.lower() if lower else ext 280 | -------------------------------------------------------------------------------- /models/gaussian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numbers 3 | 4 | 5 | def gaussian_kernel_2d(mean, std_inv, size): 6 | """Create a 2D gaussian kernel 7 | 8 | Args: 9 | mean: center of the gaussian filter (shift from origin) 10 | (2, ) vector 11 | std_inv: standard deviation $Sigma^{-1/2}$ 12 | can be a single number, a vector of dimension 2, or a 2x2 matrix 13 | size: size of the kernel 14 | pair of integer for width and height 15 | or single number will be used for both width and height 16 | 17 | Returns: 18 | A gaussian kernel of shape size. 
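Example (illustrative usage, with the mean given as a tensor):
    kernel = gaussian_kernel_2d(torch.zeros(2), 1.0, size=5)
    # -> a 5x5 kernel centred on the middle pixel whose entries sum to 1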
19 | """ 20 | if type(mean) is torch.Tensor: 21 | device = mean.device 22 | elif type(std_inv) is torch.Tensor: 23 | device = std_inv.device 24 | else: 25 | device = "cpu" 26 | 27 | # repeat the size for width, height if single number 28 | if isinstance(size, numbers.Number): 29 | width = height = size 30 | else: 31 | width, height = size 32 | 33 | # expand std to (2, 2) matrix 34 | if isinstance(std_inv, numbers.Number): 35 | std_inv = torch.tensor([[std_inv, 0], [0, std_inv]], device=device) 36 | elif std_inv.dim() == 0: 37 | std_inv = torch.diag(std_inv.repeat(2)) 38 | elif std_inv.dim() == 1: 39 | assert len(std_inv) == 2 40 | std_inv = torch.diag(std_inv) 41 | 42 | # Enforce PSD of covariance matrix 43 | covariance_inv = std_inv.transpose(0, 1) @ std_inv 44 | covariance_inv = covariance_inv.float() 45 | 46 | # make a grid (width, height, 2) 47 | X = torch.cat( 48 | [ 49 | t.unsqueeze(-1) 50 | for t in reversed( 51 | torch.meshgrid( 52 | [torch.arange(s, device=device) for s in [width, height]] 53 | ) 54 | ) 55 | ], 56 | dim=-1, 57 | ) 58 | X = X.float() 59 | 60 | # center the gaussian in (0, 0) and then shift to mean 61 | X -= torch.tensor([(width - 1) / 2, (height - 1) / 2], device=device).float() 62 | X -= mean.float() 63 | 64 | # does not use the normalize constant of gaussian distribution 65 | Y = torch.exp((-1 / 2) * torch.einsum("xyi,ij,xyj->xy", [X, covariance_inv, X])) 66 | 67 | # normalize 68 | # TODO could compute the correct normalization (1/2pi det ...) 69 | # and send warning if there is a significant diff 70 | # -> part of the gaussian is outside the kernel 71 | Z = Y / Y.sum() 72 | return Z 73 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | """ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 8 | """ 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from torch.autograd import Variable 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, in_planes, planes, stride=1, use_batchnorm=True): 20 | super(BasicBlock, self).__init__() 21 | self.conv1 = nn.Conv2d( 22 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 23 | ) 24 | self.bn1 = nn.BatchNorm2d(planes) 25 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 26 | self.bn2 = nn.BatchNorm2d(planes) 27 | 28 | if not use_batchnorm: 29 | self.bn1 = self.bn2 = nn.Sequential() 30 | 31 | self.shortcut = nn.Sequential() 32 | if stride != 1 or in_planes != self.expansion * planes: 33 | self.shortcut = nn.Sequential( 34 | nn.Conv2d( 35 | in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False 36 | ), 37 | nn.BatchNorm2d(self.expansion * planes) if use_batchnorm else nn.Sequential(), 38 | ) 39 | 40 | def forward(self, x): 41 | out = F.relu(self.bn1(self.conv1(x))) 42 | out = self.bn2(self.conv2(out)) 43 | out += self.shortcut(x) 44 | out = F.relu(out) 45 | return out 46 | 47 | 48 | class Bottleneck(nn.Module): 49 | expansion = 4 50 | 51 | def __init__(self, in_planes, planes, stride=1, use_batchnorm=True): 52 | super(Bottleneck, self).__init__() 53 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 54 | self.bn1 = nn.BatchNorm2d(planes) 55 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 56 | self.bn2 = nn.BatchNorm2d(planes) 57 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 58 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 59 | 60 | if not use_batchnorm: 61 | self.bn1 = self.bn2 = self.bn3 = nn.Sequential() 62 | 63 | self.shortcut = nn.Sequential() 64 | if stride != 1 or in_planes != self.expansion * planes: 65 | self.shortcut = nn.Sequential( 66 | nn.Conv2d( 67 | in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False 68 | ), 69 | nn.BatchNorm2d(self.expansion * planes) if use_batchnorm else nn.Sequential(), 70 | ) 71 | 72 | def forward(self, x): 73 | out = F.relu(self.bn1(self.conv1(x))) 74 | out = F.relu(self.bn2(self.conv2(out))) 75 | out = self.bn3(self.conv3(out)) 76 | out += self.shortcut(x) 77 | out = F.relu(out) 78 | return out 79 | 80 | 81 | class ResNet(nn.Module): 82 | def __init__(self, block, num_blocks, num_classes=10, use_batchnorm=True): 83 | super(ResNet, self).__init__() 84 | self.in_planes = 64 85 | self.use_batchnorm = use_batchnorm 86 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 87 | self.bn1 = nn.BatchNorm2d(64) if use_batchnorm else nn.Sequential() 88 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 89 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 90 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 91 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 92 | self.linear = nn.Linear(512 * block.expansion, num_classes) 93 | 94 | def _make_layer(self, block, planes, num_blocks, stride): 95 | strides = [stride] + [1] * (num_blocks - 1) 96 | layers = [] 97 | for stride in strides: 98 | layers.append(block(self.in_planes, planes, stride, self.use_batchnorm)) 99 | self.in_planes = planes * block.expansion 100 | return nn.Sequential(*layers) 101 | 102 | def forward(self, x): 103 | out = 
F.relu(self.bn1(self.conv1(x))) 104 | out = self.layer1(out) 105 | out = self.layer2(out) 106 | out = self.layer3(out) 107 | out = self.layer4(out) 108 | out = F.avg_pool2d(out, 4) 109 | out = out.view(out.size(0), -1) 110 | out = self.linear(out) 111 | return out 112 | 113 | 114 | def ResNet10(num_classes=10, use_batchnorm=True): 115 | return ResNet(BasicBlock, [1, 1, 1, 1], num_classes=num_classes, use_batchnorm=use_batchnorm) 116 | 117 | 118 | def ResNet18(num_classes=10, use_batchnorm=True): 119 | return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, use_batchnorm=use_batchnorm) 120 | 121 | 122 | def ResNet34(num_classes=10, use_batchnorm=True): 123 | return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, use_batchnorm=use_batchnorm) 124 | 125 | 126 | def ResNet50(num_classes=10, use_batchnorm=True): 127 | return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, use_batchnorm=use_batchnorm) 128 | 129 | 130 | def ResNet101(num_classes=10, use_batchnorm=True): 131 | return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, use_batchnorm=use_batchnorm) 132 | 133 | 134 | def ResNet152(num_classes=10, use_batchnorm=True): 135 | return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, use_batchnorm=use_batchnorm) 136 | 137 | 138 | def test(): 139 | net = ResNet18() 140 | y = net(Variable(torch.randn(1, 3, 32, 32))) 141 | print(y.size()) 142 | 143 | 144 | # test() 145 | -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.parameter import Parameter 5 | import random 6 | from .bert import BertEncoder, BertConfig 7 | import torchvision.models as models 8 | from torch.autograd import Variable 9 | from enum import Enum 10 | 11 | 12 | class ResBottom(nn.Module): 13 | def __init__(self, origin_model, block_num=1): 14 | super(ResBottom, self).__init__() 15 | self.seq = nn.Sequential(*list(origin_model.children())[0 : (4 + block_num)]) 16 | 17 | def forward(self, batch): 18 | return self.seq(batch) 19 | 20 | 21 | class BertImage(nn.Module): 22 | """ 23 | Wrapper for a Bert encoder 24 | """ 25 | 26 | def __init__(self, config, num_classes, output_attentions=False): 27 | super().__init__() 28 | 29 | self.output_attentions = output_attentions 30 | self.with_resnet = config["pooling_use_resnet"] 31 | self.hidden_size = config["hidden_size"] 32 | self.pooling_concatenate_size = config["pooling_concatenate_size"] 33 | assert (config["pooling_concatenate_size"] == 1) or ( 34 | not config["pooling_use_resnet"] 35 | ), "Use either resnet or pooling_concatenate_size" 36 | 37 | 38 | if self.with_resnet: 39 | res50 = models.resnet50(pretrained=True) 40 | self.extract_feature = ResBottom(res50) 41 | 42 | # compute downscale factor and channel at output of ResNet 43 | _, num_channels_in, new_width, new_height = self.extract_feature( 44 | torch.rand(1, 3, 1024, 1024) 45 | ).shape 46 | self.feature_downscale_factor = 1024 // new_width 47 | elif self.pooling_concatenate_size > 1: 48 | num_channels_in = 3 * (self.pooling_concatenate_size ** 2) 49 | else: 50 | num_channels_in = 3 51 | 52 | bert_config = BertConfig.from_dict(config) 53 | 54 | self.features_upscale = nn.Linear(num_channels_in, self.hidden_size) 55 | # self.features_downscale = nn.Linear(self.hidden_size, num_channels_in) 56 | 57 | # output all attentions, won't return them if 
self.output_attentions is False 58 | self.encoder = BertEncoder(bert_config, output_attentions=True) 59 | self.classifier = nn.Linear(self.hidden_size, num_classes) 60 | # self.pixelizer = nn.Linear(self.hidden_size, 3) 61 | self.register_buffer("attention_mask", torch.tensor(1.0)) 62 | 63 | # self.mask_embedding = Parameter(torch.zeros(self.hidden_size)) 64 | # self.cls_embedding = Parameter(torch.zeros(self.hidden_size)) 65 | # self.reset_parameters() 66 | 67 | def reset_parameters(self): 68 | # self.mask_embedding.data.normal_(mean=0.0, std=0.01) 69 | # self.cls_embedding.data.normal_(mean=0.0, std=0.01) # TODO do not hard-code the std 70 | # self.positional_encoding.reset_parameters() 71 | pass 72 | 73 | def random_masking(self, batch_images, batch_mask, device): 74 | """ 75 | With probability 10% we keep the image unchanged; 76 | with probability 10% we fill the masked region with samples from a normal distribution; 77 | with probability 80% we set the masked region to 0. 78 | :param batch_images: image to be masked 79 | :param batch_mask: mask region 80 | :param device: device on which the noise tensor is created 81 | :return: masked image 82 | """ 83 | return batch_images 84 | # TODO masking below is currently disabled by the early return above 85 | temp = random.random() 86 | if temp > 0.1: 87 | batch_images = batch_images * batch_mask.unsqueeze(1).float() 88 | if temp < 0.2: 89 | batch_images = batch_images + ( 90 | ((-batch_mask.unsqueeze(1).float()) + 1) 91 | * torch.normal(mean=0.5, std=torch.ones(batch_images.shape)).to(device) 92 | ) 93 | return batch_images 94 | 95 | def prune_heads(self, heads_to_prune): 96 | """ Prunes heads of the model. 97 | heads_to_prune: dict of {layer_num: list of heads to prune in this layer} 98 | """ 99 | for layer, heads in heads_to_prune.items(): 100 | self.encoder.layer[layer].attention.prune_heads(heads) 101 | 102 | def reset_heads(self, heads_to_reset): 103 | """ Resets (re-initializes) heads of the model.
104 | heads_to_prune: dict of {layer_num: list of heads to prune in this layer} 105 | """ 106 | for layer, heads in heads_to_reset.items(): 107 | self.encoder.layer[layer].attention.reset_heads(heads) 108 | 109 | def forward(self, batch_images, batch_mask=None, feature_mask=None): 110 | 111 | """ 112 | Replace masked pixels with 0s 113 | If ResNet 114 | | compute features 115 | | downscale the mask 116 | Replace masked pixels/features by MSK token 117 | Use Bert encoder 118 | """ 119 | device = batch_images.device 120 | 121 | # compute ResNet features 122 | if self.with_resnet: 123 | 124 | # replace masked pixels with 0, batch_images has NCHW format 125 | batch_features_unmasked = self.extract_feature(batch_images) 126 | 127 | if batch_mask is not None: 128 | batch_images = self.random_masking(batch_images, batch_mask, device) 129 | batch_features = self.extract_feature(batch_images) 130 | else: 131 | batch_features = batch_features_unmasked 132 | 133 | # downscale the mask 134 | if batch_mask is not None: 135 | # downsample the mask 136 | # mask any downsampled pixel if it contained one masked pixel originialy 137 | feature_mask = ~( 138 | F.max_pool2d((~batch_mask).float(), self.feature_downscale_factor).byte() 139 | ) 140 | # reshape from NCHW to NHWC 141 | batch_features = batch_features.permute(0, 2, 3, 1) 142 | 143 | elif self.pooling_concatenate_size > 1: 144 | 145 | def downsample_concatenate(X, kernel): 146 | """X is of shape B x H x W x C 147 | return shape B x (kernel*H) x (kernel*W) x (kernel*kernel*C) 148 | """ 149 | b, h, w, c = X.shape 150 | Y = X.contiguous().view(b, h, w // kernel, c * kernel) 151 | Y = Y.permute(0, 2, 1, 3).contiguous() 152 | Y = Y.view(b, w // kernel, h // kernel, kernel * kernel * c).contiguous() 153 | Y = Y.permute(0, 2, 1, 3).contiguous() 154 | return Y 155 | 156 | # reshape from NCHW to NHWC 157 | batch_features = batch_images.permute(0, 2, 3, 1) 158 | batch_features = downsample_concatenate(batch_features, self.pooling_concatenate_size) 159 | feature_mask = None 160 | if batch_mask is not None: 161 | feature_mask = batch_mask[ 162 | :, :: self.pooling_concatenate_size, :: self.pooling_concatenate_size 163 | ] 164 | 165 | else: 166 | batch_features = batch_images 167 | feature_mask = batch_mask 168 | # reshape from NCHW to NHWC 169 | batch_features = batch_features.permute(0, 2, 3, 1) 170 | 171 | # feature upscale to BERT dimension 172 | batch_features = self.features_upscale(batch_features) 173 | 174 | b, w, h, _ = batch_features.shape 175 | 176 | all_attentions, all_representations = self.encoder( 177 | batch_features, 178 | attention_mask=self.attention_mask, 179 | output_all_encoded_layers=False, 180 | ) 181 | 182 | representations = all_representations[0] 183 | 184 | # mean pool for representation (features for classification) 185 | cls_representation = representations.view(b, -1, representations.shape[-1]).mean(dim=1) 186 | cls_prediction = self.classifier(cls_representation) 187 | 188 | if self.output_attentions: 189 | return cls_prediction, all_attentions 190 | else: 191 | return cls_prediction 192 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | boto3==1.9.234 2 | tabulate==0.8.3 3 | tensorboardX==1.8 4 | termcolor==1.1.0 5 | thop==0.0.31 6 | -------------------------------------------------------------------------------- /runs/learned/run.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # OUTPUTDIR is directory containing this run.sh script 4 | OUTPUTDIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 5 | 6 | python train.py \ 7 | --num_hidden_layers 6 \ 8 | --num_attention_heads 9 \ 9 | --optimizer_cosine_lr True \ 10 | --optimizer_warmup_ratio 0.05 \ 11 | --batch_size 100 \ 12 | --num_epochs 300 \ 13 | --hidden_size 400 \ 14 | --use_learned_2d_encoding True \ 15 | --use_gaussian_attention False \ 16 | --num_keep_checkpoints 30 \ 17 | --output_dir $OUTPUTDIR 18 | -------------------------------------------------------------------------------- /runs/many-heads/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # OUTPUTDIR is directory containing this run.sh script 4 | OUTPUTDIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 5 | 6 | python train.py \ 7 | --num_hidden_layers 6 \ 8 | --num_attention_heads 16 \ 9 | --optimizer_cosine_lr True \ 10 | --optimizer_warmup_ratio 0.05 \ 11 | --batch_size 100 \ 12 | --num_epochs 300 \ 13 | --hidden_size 400 \ 14 | --attention_isotropic_gaussian True \ 15 | --num_keep_checkpoints 30 \ 16 | --output_dir $OUTPUTDIR 17 | -------------------------------------------------------------------------------- /runs/quadratic-generalized-pruned/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # OUTPUTDIR is directory containing this run.sh script 4 | OUTPUTDIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 5 | 6 | python train.py \ 7 | --num_hidden_layers 6 \ 8 | --num_attention_heads 9 \ 9 | --optimizer_cosine_lr True \ 10 | --optimizer_learning_rate 0.01 \ 11 | --optimizer_warmup_ratio 0.05 \ 12 | --batch_size 100 \ 13 | --num_epochs 100 \ 14 | --hidden_size 400 \ 15 | --attention_isotropic_gaussian False \ 16 | --num_keep_checkpoints 10 \ 17 | --prune_degenerated_heads True \ 18 | --load_checkpoint_file ./quadratic-generalized/final.checkpoint \ 19 | --output_dir $OUTPUTDIR 20 | -------------------------------------------------------------------------------- /runs/quadratic-generalized/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # OUTPUTDIR is directory containing this run.sh script 4 | OUTPUTDIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 5 | 6 | python train.py \ 7 | --num_hidden_layers 6 \ 8 | --num_attention_heads 9 \ 9 | --optimizer_cosine_lr True \ 10 | --optimizer_warmup_ratio 0.05 \ 11 | --batch_size 100 \ 12 | --num_epochs 300 \ 13 | --hidden_size 400 \ 14 | --attention_isotropic_gaussian False \ 15 | --num_keep_checkpoints 30 \ 16 | --output_dir $OUTPUTDIR 17 | -------------------------------------------------------------------------------- /runs/quadratic/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # OUTPUTDIR is directory containing this run.sh script 4 | OUTPUTDIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 5 | 6 | python train.py \ 7 | --num_hidden_layers 6 \ 8 | --num_attention_heads 9 \ 9 | --optimizer_cosine_lr True \ 10 | --optimizer_warmup_ratio 0.05 \ 11 | --batch_size 100 \ 12 | --num_epochs 300 \ 13 | --hidden_size 400 \ 14 | --attention_isotropic_gaussian True \ 15 | --num_keep_checkpoints 30 \ 16 | --output_dir $OUTPUTDIR 17 | 
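# Note: attention_isotropic_gaussian=True constrains each gaussian attention head to a single
# scalar spread; runs/quadratic-generalized sets it to False and learns a full 2x2 covariance per head.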
-------------------------------------------------------------------------------- /runs/resnet/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # OUTPUTDIR is directory containing this run.sh script 4 | OUTPUTDIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 5 | 6 | python train.py \ 7 | --model "resnet18" \ 8 | --batch_size 100 \ 9 | --num_epochs 300 \ 10 | --output_dir $OUTPUTDIR 11 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import TensorDataset 7 | import torchvision 8 | from tqdm import tqdm 9 | import yaml 10 | import enum 11 | from enum import Enum 12 | import argparse 13 | from tensorboardX import SummaryWriter 14 | from collections import OrderedDict 15 | from termcolor import colored 16 | import tabulate 17 | 18 | import models 19 | from utils.data import MaskedDataset 20 | from utils.logging import get_num_parameter, human_format, DummySummaryWriter, sizeof_fmt 21 | from utils.plotting import plot_attention_positions_all_layers 22 | from utils.config import parse_cli_overides 23 | from utils.learning_rate import linear_warmup_cosine_lr_scheduler 24 | import utils.accumulators 25 | 26 | 27 | # fmt: off 28 | config = OrderedDict( 29 | dataset="Cifar10", 30 | model="bert", 31 | load_checkpoint_file=None, 32 | no_cuda=False, 33 | 34 | # === OPTIMIZER === 35 | optimizer="SGD", 36 | optimizer_cosine_lr=False, 37 | optimizer_warmup_ratio=0.0, # period of linear increase for lr scheduler 38 | optimizer_decay_at_epochs=[80, 150, 250], 39 | optimizer_decay_with_factor=10.0, 40 | optimizer_learning_rate=0.1, 41 | optimizer_momentum=0.9, 42 | optimizer_weight_decay=0.0001, 43 | batch_size=300, 44 | num_epochs=300, 45 | seed=42, 46 | 47 | # === From BERT === 48 | vocab_size_or_config_json_file=-1, 49 | hidden_size=128, # 768, 50 | position_encoding_size=-1, # dimension of the position embedding for relative attention, if -1 will default to hidden_size 51 | num_hidden_layers=2, 52 | num_attention_heads=8, 53 | intermediate_size=512, 54 | hidden_act="gelu", 55 | hidden_dropout_prob=0.1, 56 | attention_probs_dropout_prob=0.1, 57 | max_position_embeddings=16, 58 | type_vocab_size=2, 59 | initializer_range=0.02, 60 | layer_norm_eps=1e-12, 61 | 62 | # === BERT IMAGE=== 63 | add_positional_encoding_to_input=False, 64 | use_learned_2d_encoding=False, 65 | share_position_encoding=False, # share learned relative position encoding for all layers 66 | use_attention_data=False, # use attention between pixel values instead of only positional (q.k attention) 67 | query_positional_score=False, # use q.r attention (see Ramachandran, 2019) 68 | use_gaussian_attention=True, 69 | attention_isotropic_gaussian=False, 70 | prune_degenerated_heads=False, # remove heads with Sigma^{-1} close to 0 or very singular (kappa > 1000) at epoch 0 71 | reset_degenerated_heads=False, # reinitialize randomly the heads mentioned above 72 | fix_original_heads_position=False, # original heads (not pruned/reinit) position are fixed to their original value 73 | fix_original_heads_weights=False, # original heads (not pruned/reinit) value matrix are fixed to their original value 74 | gaussian_spread_regularizer=0., # penalize singular covariance gaussian attention 75 | 76 | gaussian_init_sigma_std=0.01, 77 | 
gaussian_init_mu_std=2., 78 | attention_gaussian_blur_trick=False, # use a computational trick for gaussian attention to avoid computing the attention probas 79 | pooling_concatenate_size=2, # concatenate the pixels value by patch of pooling_concatenate_size x pooling_concatenate_size to redude dimension 80 | pooling_use_resnet=False, 81 | 82 | # === LOGGING === 83 | only_list_parameters=False, 84 | num_keep_checkpoints=0, 85 | plot_attention_positions=True, 86 | output_dir="./output.tmp", 87 | ) 88 | # fmt: on 89 | 90 | output_dir = "./output.tmp" # Can be overwritten by a script calling this 91 | 92 | 93 | def main(): 94 | """ 95 | Train a model 96 | You can either call this script directly (using the default parameters), 97 | or import it as a module, override config and run main() 98 | :return: scalar of the best accuracy 99 | """ 100 | 101 | """ 102 | Directory structure: 103 | 104 | output_dir 105 | |-- config.yaml 106 | |-- best.checkpoint 107 | |-- last.checkpoint 108 | |-- tensorboard logs... 109 | """ 110 | 111 | global output_dir 112 | output_dir = config["output_dir"] 113 | os.makedirs(output_dir, exist_ok = True) 114 | 115 | # save config in YAML file 116 | store_config() 117 | 118 | # create tensorboard writter 119 | writer = SummaryWriter(logdir=output_dir, max_queue=100, flush_secs=10) 120 | print(f"Tensorboard logs saved in '{output_dir}'") 121 | 122 | # Set the seed 123 | torch.manual_seed(config["seed"]) 124 | np.random.seed(config["seed"]) 125 | 126 | # We will run on CUDA if there is a GPU available 127 | device = torch.device("cuda" if not config["no_cuda"] and torch.cuda.is_available() else "cpu") 128 | 129 | # Configure the dataset, model and the optimizer based on the global 130 | # `config` dictionary. 131 | training_loader, test_loader = get_dataset(test_batch_size=config["batch_size"]) 132 | model = get_model(device) 133 | 134 | print_parameters(model) 135 | if config["only_list_parameters"]: 136 | print_flops(model) 137 | 138 | if config["load_checkpoint_file"] is not None: 139 | restore_checkpoint(config["load_checkpoint_file"], model, device) 140 | 141 | # for each layer, which heads position to block list[list[int]] 142 | original_heads_per_layer = None 143 | 144 | if config["prune_degenerated_heads"]: 145 | assert config["model"] == "bert" and config["use_gaussian_attention"] 146 | with torch.no_grad(): 147 | heads_to_prune = find_degenerated_heads(model) 148 | model.prune_heads(heads_to_prune) 149 | original_heads_per_layer = [ 150 | torch.tensor(list(range(model.encoder.layer[layer_idx].attention.self.num_attention_heads))) 151 | for layer_idx in range(config["num_hidden_layers"]) 152 | ] 153 | 154 | print_parameters(model) 155 | print_flops(model) 156 | 157 | if config["reset_degenerated_heads"]: 158 | assert config["model"] == "bert" and config["use_gaussian_attention"] 159 | with torch.no_grad(): 160 | heads_to_reset = find_degenerated_heads(model) 161 | model.reset_heads(heads_to_reset) 162 | original_heads_per_layer = [ 163 | torch.tensor([ 164 | head_idx 165 | for head_idx in range(model.encoder.layer[layer_idx].attention.self.num_attention_heads) 166 | if head_idx not in heads_to_reset.get(layer_idx, []) 167 | ]) 168 | for layer_idx in range(config["num_hidden_layers"]) 169 | ] 170 | 171 | if config["only_list_parameters"]: 172 | exit() 173 | 174 | max_steps = config["num_epochs"] 175 | if config["optimizer_cosine_lr"]: 176 | max_steps *= len(training_loader.dataset) // config["batch_size"] + 1 177 | 178 | optimizer, scheduler = 
get_optimizer(model.named_parameters(), max_steps) 179 | criterion = torch.nn.CrossEntropyLoss() 180 | 181 | # We keep track of the best accuracy so far to store checkpoints 182 | best_accuracy_so_far = utils.accumulators.Max() 183 | checkpoint_every_n_epoch = None 184 | if config["num_keep_checkpoints"] > 0: 185 | checkpoint_every_n_epoch = max(1, config["num_epochs"] // config["num_keep_checkpoints"]) 186 | else: 187 | checkpoint_every_n_epoch = 999999999999 188 | global_step = 0 189 | 190 | for epoch in range(config["num_epochs"]): 191 | print("Epoch {:03d}".format(epoch)) 192 | 193 | if ( 194 | "bert" in config["model"] 195 | and config["plot_attention_positions"] 196 | and (config["use_gaussian_attention"] or config["use_learned_2d_encoding"]) 197 | ): 198 | if not config["attention_gaussian_blur_trick"]: 199 | plot_attention_positions_all_layers(model, 9, writer, epoch) 200 | else: 201 | # TODO plot gaussian without attention weights 202 | pass 203 | 204 | # Enable training mode (automatic differentiation + batch norm) 205 | model.train() 206 | 207 | # Update the optimizer's learning rate 208 | if config["optimizer_cosine_lr"]: 209 | scheduler.step(global_step) 210 | else: 211 | scheduler.step() 212 | writer.add_scalar("train/lr", scheduler.get_lr()[0], global_step) 213 | 214 | # Keep track of statistics during training 215 | mean_train_accuracy = utils.accumulators.Mean() 216 | mean_train_loss = utils.accumulators.Mean() 217 | 218 | for batch_x, batch_y in tqdm(training_loader): 219 | 220 | batch_x, batch_y = batch_x.to(device), batch_y.to(device) 221 | 222 | batch_size, _, width, height = batch_x.shape 223 | 224 | # Compute gradients for the batch 225 | optimizer.zero_grad() 226 | 227 | if config["pooling_use_resnet"]: 228 | # , image_out, reconstruction, reconstruction_mask 229 | prediction = model(batch_x) # , batch_mask) 230 | else: 231 | # prediction, image_out 232 | prediction = model(batch_x) # , batch_mask) 233 | # reconstruction = batch_x 234 | # reconstruction_mask = batch_mask 235 | 236 | classification_loss = criterion(prediction, batch_y) 237 | loss = classification_loss 238 | 239 | if config["gaussian_spread_regularizer"] > 0: 240 | gaussian_regularizer_loss = config["gaussian_spread_regularizer"] * get_singular_gaussian_penalty(model) 241 | loss += gaussian_regularizer_loss 242 | 243 | acc = accuracy(prediction, batch_y) 244 | 245 | loss.backward() 246 | 247 | # set blocked gradient to 0 248 | if config["fix_original_heads_position"] and original_heads_per_layer is not None: 249 | for layer_idx, heads_to_fix in enumerate(original_heads_per_layer): 250 | model.encoder.layer[layer_idx].attention.self.attention_spreads.grad[heads_to_fix].zero_() 251 | model.encoder.layer[layer_idx].attention.self.attention_centers.grad[heads_to_fix].zero_() 252 | 253 | if config["fix_original_heads_weights"] and original_heads_per_layer is not None: 254 | for layer_idx, heads_to_fix in enumerate(original_heads_per_layer): 255 | layer = model.encoder.layer[layer_idx] 256 | n_head = layer.attention.self.num_attention_heads 257 | d_head = layer.attention.self.attention_head_size 258 | mask = torch.zeros([n_head, d_head], dtype=torch.bool) 259 | for head in heads_to_fix: 260 | mask[head] = 1 261 | mask = mask.view(-1) 262 | layer.attention.self.value.weight.grad[:, mask].zero_() 263 | 264 | # Do an optimizer step 265 | optimizer.step() 266 | 267 | writer.add_scalar("train/loss", loss, global_step) 268 | writer.add_scalar("train/classification-loss", classification_loss, global_step) 269 | if 
config["gaussian_spread_regularizer"] > 0: 270 | writer.add_scalar("train/gaussian_regularizer_loss", gaussian_regularizer_loss, global_step) 271 | writer.add_scalar("train/accuracy", acc, global_step) 272 | 273 | global_step += 1 274 | 275 | # Store the statistics 276 | mean_train_loss.add(loss.item(), weight=len(batch_x)) 277 | mean_train_accuracy.add(acc.item(), weight=len(batch_x)) 278 | 279 | 280 | # Log training stats 281 | log_metric( 282 | "accuracy", {"epoch": epoch, "value": mean_train_accuracy.value()}, {"split": "train"} 283 | ) 284 | log_metric( 285 | "cross_entropy", {"epoch": epoch, "value": mean_train_loss.value()}, {"split": "train"} 286 | ) 287 | log_metric("lr", {"epoch": epoch, "value": scheduler.get_lr()[0]}, {}) 288 | 289 | # Evaluation 290 | with torch.no_grad(): 291 | model.eval() 292 | mean_test_accuracy = utils.accumulators.Mean() 293 | mean_test_loss = utils.accumulators.Mean() 294 | for batch_x, batch_y in test_loader: 295 | batch_x, batch_y = batch_x.to(device), batch_y.to(device) 296 | prediction = model(batch_x) 297 | loss = criterion(prediction, batch_y) 298 | acc = accuracy(prediction, batch_y) 299 | mean_test_loss.add(loss.item(), weight=len(batch_x)) 300 | mean_test_accuracy.add(acc.item(), weight=len(batch_x)) 301 | 302 | # Log test stats 303 | log_metric( 304 | "accuracy", {"epoch": epoch, "value": mean_test_accuracy.value()}, {"split": "test"} 305 | ) 306 | log_metric( 307 | "cross_entropy", {"epoch": epoch, "value": mean_test_loss.value()}, {"split": "test"} 308 | ) 309 | writer.add_scalar("eval/classification_loss", mean_test_loss.value(), epoch) 310 | writer.add_scalar("eval/accuracy", mean_test_accuracy.value(), epoch) 311 | 312 | # Store checkpoints for the best model so far 313 | is_best_so_far = best_accuracy_so_far.add(mean_test_accuracy.value()) 314 | if is_best_so_far: 315 | store_checkpoint("best.checkpoint", model, epoch, mean_test_accuracy.value()) 316 | if epoch % checkpoint_every_n_epoch == 0: 317 | store_checkpoint("{:04d}.checkpoint".format(epoch), model, epoch, mean_test_accuracy.value()) 318 | 319 | # Store a final checkpoint 320 | store_checkpoint( 321 | "final.checkpoint", model, config["num_epochs"] - 1, mean_test_accuracy.value() 322 | ) 323 | writer.close() 324 | 325 | # Return the optimal accuracy, could be used for learning rate tuning 326 | return best_accuracy_so_far.value() 327 | 328 | 329 | def accuracy(predicted_logits, reference): 330 | """Compute the ratio of correctly predicted labels""" 331 | labels = torch.argmax(predicted_logits, 1) 332 | correct_predictions = labels.eq(reference) 333 | return correct_predictions.sum().float() / correct_predictions.nelement() 334 | 335 | 336 | def log_metric(name, values, tags): 337 | """ 338 | Log timeseries data. 339 | Placeholder implementation. 340 | This function should be overwritten by any script that runs this as a module. 
341 | """ 342 | print("{name}: {values} ({tags})".format(name=name, values=values, tags=tags)) 343 | 344 | 345 | def get_dataset(test_batch_size=100, shuffle_train=True, num_workers=2, data_root="./data"): 346 | """ 347 | Create dataset loaders for the chosen dataset 348 | :return: Tuple (training_loader, test_loader) 349 | """ 350 | if config["dataset"] == "Cifar10": 351 | dataset = torchvision.datasets.CIFAR10 352 | elif config["dataset"] == "Cifar100": 353 | dataset = torchvision.datasets.CIFAR100 354 | elif config["dataset"].startswith("/"): 355 | train_data = torch.load(config["dataset"] + ".train") 356 | test_data = torch.load(config["dataset"] + ".test") 357 | training_set = TensorDataset(train_data["data"], train_data["target"]) 358 | test_set = TensorDataset(test_data["data"], test_data["target"]) 359 | 360 | training_loader = torch.utils.data.DataLoader( 361 | training_set, 362 | batch_size=config["batch_size"], 363 | shuffle=shuffle_train, 364 | num_workers=num_workers, 365 | ) 366 | test_loader = torch.utils.data.DataLoader( 367 | test_set, batch_size=test_batch_size, shuffle=False, num_workers=num_workers 368 | ) 369 | 370 | return training_loader, test_loader 371 | else: 372 | raise ValueError("Unexpected value for config[dataset] {}".format(config["dataset"])) 373 | 374 | data_mean = (0.4914, 0.4822, 0.4465) 375 | data_stddev = (0.2023, 0.1994, 0.2010) 376 | 377 | transform_train = torchvision.transforms.Compose( 378 | [ 379 | torchvision.transforms.RandomCrop(32, padding=4), 380 | torchvision.transforms.RandomHorizontalFlip(), 381 | torchvision.transforms.ToTensor(), 382 | torchvision.transforms.Normalize(data_mean, data_stddev), 383 | ] 384 | ) 385 | 386 | transform_test = torchvision.transforms.Compose( 387 | [ 388 | torchvision.transforms.ToTensor(), 389 | torchvision.transforms.Normalize(data_mean, data_stddev), 390 | ] 391 | ) 392 | 393 | training_set = dataset(root=data_root, train=True, download=True, transform=transform_train) 394 | test_set = dataset(root=data_root, train=False, download=True, transform=transform_test) 395 | 396 | training_loader = torch.utils.data.DataLoader( 397 | training_set, 398 | batch_size=config["batch_size"], 399 | shuffle=shuffle_train, 400 | num_workers=num_workers, 401 | ) 402 | test_loader = torch.utils.data.DataLoader( 403 | test_set, batch_size=test_batch_size, shuffle=False, num_workers=num_workers 404 | ) 405 | 406 | return training_loader, test_loader 407 | 408 | def split_dict(d, first_predicate): 409 | """split the dictionary d into 2 dictionaries, first one contains elements validating first_predicate""" 410 | first, second = OrderedDict(), OrderedDict() 411 | for key, value in d.items(): 412 | if first_predicate(key): 413 | first[key] = value 414 | else: 415 | second[key] = value 416 | return first, second 417 | 418 | def get_optimizer(model_named_parameters, max_steps): 419 | """ 420 | Create an optimizer for a given model 421 | :param model_parameters: a list of parameters to be trained 422 | :return: Tuple (optimizer, scheduler) 423 | """ 424 | if config["optimizer"] == "SGD": 425 | without_weight_decay, with_weight_decay = split_dict( 426 | OrderedDict(model_named_parameters), 427 | lambda name: "attention_spreads" in name or "attention_centers" in name 428 | ) 429 | 430 | optimizer = torch.optim.SGD( 431 | [ 432 | {"params": with_weight_decay.values()}, 433 | {"params": without_weight_decay.values(), "weight_decay": 0.} 434 | ], 435 | lr=config["optimizer_learning_rate"], 436 | momentum=config["optimizer_momentum"], 437 | 
weight_decay=config["optimizer_weight_decay"], 438 | ) 439 | elif config["optimizer"] == "Adam": 440 | optimizer = torch.optim.Adam(model_named_parameters.values(), lr=config["optimizer_learning_rate"]) 441 | else: 442 | raise ValueError("Unexpected value for optimizer") 443 | 444 | if config["optimizer"] == "Adam": 445 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda e: 1.) 446 | print("Adam optimizer ignore all learning rate schedules.") 447 | elif config["optimizer_cosine_lr"]: 448 | scheduler = linear_warmup_cosine_lr_scheduler( 449 | optimizer, config["optimizer_warmup_ratio"], max_steps 450 | ) 451 | 452 | else: 453 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 454 | optimizer, 455 | milestones=config["optimizer_decay_at_epochs"], 456 | gamma=1.0 / config["optimizer_decay_with_factor"], 457 | ) 458 | 459 | return optimizer, scheduler 460 | 461 | 462 | def get_model(device): 463 | """ 464 | :param device: instance of torch.device 465 | :return: An instance of torch.nn.Module 466 | """ 467 | num_classes = 2 468 | if config["dataset"] == "Cifar100": 469 | num_classes = 100 470 | elif config["dataset"] == "Cifar10": 471 | num_classes = 10 472 | 473 | model = { 474 | "vgg11": lambda: models.VGG("VGG11", num_classes, batch_norm=False), 475 | "vgg11_bn": lambda: models.VGG("VGG11", num_classes, batch_norm=True), 476 | "vgg13": lambda: models.VGG("VGG13", num_classes, batch_norm=False), 477 | "vgg13_bn": lambda: models.VGG("VGG13", num_classes, batch_norm=True), 478 | "vgg16": lambda: models.VGG("VGG16", num_classes, batch_norm=False), 479 | "vgg16_bn": lambda: models.VGG("VGG16", num_classes, batch_norm=True), 480 | "vgg19": lambda: models.VGG("VGG19", num_classes, batch_norm=False), 481 | "vgg19_bn": lambda: models.VGG("VGG19", num_classes, batch_norm=True), 482 | "resnet10": lambda: models.ResNet10(num_classes=num_classes), 483 | "resnet18": lambda: models.ResNet18(num_classes=num_classes), 484 | "resnet34": lambda: models.ResNet34(num_classes=num_classes), 485 | "resnet50": lambda: models.ResNet50(num_classes=num_classes), 486 | "resnet101": lambda: models.ResNet101(num_classes=num_classes), 487 | "resnet152": lambda: models.ResNet152(num_classes=num_classes), 488 | "bert": lambda: models.BertImage(config, num_classes=num_classes), 489 | }[config["model"]]() 490 | 491 | model.to(device) 492 | if device == torch.device("cuda"): 493 | print("Use DataParallel if multi-GPU") 494 | model = torch.nn.DataParallel(model) 495 | torch.backends.cudnn.benchmark = True 496 | 497 | return model 498 | 499 | def print_parameters(model): 500 | # compute number of parameters 501 | num_params, _ = get_num_parameter(model, trainable=False) 502 | num_bytes = num_params * 32 // 8 # assume float32 for all 503 | print(f"Number of parameters: {human_format(num_params)} ({sizeof_fmt(num_bytes)} for float32)") 504 | num_trainable_params, trainable_parameters = get_num_parameter(model, trainable=True) 505 | print("Number of trainable parameters:", human_format(num_trainable_params)) 506 | 507 | if config["only_list_parameters"]: 508 | # Print detailed number of parameters 509 | print(tabulate.tabulate(trainable_parameters)) 510 | 511 | 512 | def print_flops(model): 513 | shape = None 514 | if config["dataset"] in ["Cifar10", "Cifar100"]: 515 | shape = (1, 3, 32, 32) 516 | else: 517 | print(f"Unknown dataset {config['dataset']} input size to compute # FLOPS") 518 | return 519 | 520 | try: 521 | from thop import profile 522 | except: 523 | print("Please `pip install thop` to compute # FLOPS") 
524 | return 525 | 526 | model = model.train() 527 | input_data = torch.rand(*shape) 528 | num_flops, num_params = profile(model, inputs=(input_data, )) 529 | print("Number of FLOPS:", human_format(num_flops)) 530 | 531 | 532 | def find_degenerated_heads(model): 533 | """ 534 | returns a dict of degenerated head per layer like {layer_idx -> [head_idx, ...]} 535 | """ 536 | model_params = dict(model.named_parameters()) 537 | degenerated_heads = OrderedDict() 538 | degenerated_reasons = [] 539 | 540 | for layer_idx in range(config["num_hidden_layers"]): 541 | prune_heads = [] 542 | sigmas_half_inv = model_params["encoder.layer.{}.attention.self.attention_spreads".format(layer_idx)] 543 | 544 | for head_idx in range(config["num_attention_heads"]): 545 | head_is_degenerated = False 546 | 547 | 548 | if config["attention_isotropic_gaussian"]: 549 | sigma_inv = sigmas_half_inv[head_idx] 550 | if sigma_inv ** 2 < 1e-5: 551 | degenerated_reasons.append("Sigma too low -> uniform attention: sigma**-2= {}".format(sigma_inv ** 2)) 552 | head_is_degenerated = True 553 | else: 554 | sigma_half_inv = sigmas_half_inv[head_idx] 555 | sigma_inv = sigma_half_inv.transpose(0, 1) @ sigma_half_inv 556 | eig_values = torch.eig(sigma_inv)[0][:, 0].abs() 557 | condition_number = eig_values.max() / eig_values.min() 558 | 559 | if condition_number > 1000: 560 | degenerated_reasons.append("Covariance matrix is ill defined: condition number = {}".format(condition_number)) 561 | head_is_degenerated = True 562 | elif eig_values.max() < 1e-5: 563 | degenerated_reasons.append("Covariance matrix is close to 0: largest eigen value = {}".format(eig_values.max())) 564 | head_is_degenerated = True 565 | 566 | if head_is_degenerated: 567 | prune_heads.append(head_idx) 568 | 569 | if prune_heads: 570 | degenerated_heads[layer_idx] = prune_heads 571 | 572 | if degenerated_heads: 573 | print("Degenerated heads:") 574 | reasons = iter(degenerated_reasons) 575 | table = [(layer, head, next(reasons)) for layer, heads in degenerated_heads.items() for head in heads] 576 | print(tabulate.tabulate(table, headers=["layer", "head", "reason"])) 577 | 578 | return degenerated_heads 579 | 580 | def get_singular_gaussian_penalty(model): 581 | """Return scalar high when attention covariance get very singular 582 | """ 583 | if config["attention_isotropic_gaussian"]: 584 | # TODO move at setup 585 | print("Singular gaussian penalty ignored as `attention_isotropic_gaussian` is True") 586 | return 0 587 | 588 | condition_numbers = [] 589 | for layer in model.encoder.layer: 590 | for sigma_half_inv in layer.attention.self.attention_spreads: 591 | sigma_inv = sigma_half_inv.transpose(0, 1) @ sigma_half_inv 592 | eig_values = torch.eig(sigma_inv)[0][:, 0].abs() 593 | condition_number = eig_values.max() / eig_values.min() 594 | condition_numbers.append(condition_number) 595 | 596 | return torch.mean((torch.tensor(condition_numbers) - 1) ** 2) 597 | 598 | def store_config(): 599 | path = os.path.join(output_dir, "config.yaml") 600 | with open(path, "w") as f: 601 | yaml.dump(dict(config), f, sort_keys=False) 602 | 603 | 604 | def store_checkpoint(filename, model, epoch, test_accuracy): 605 | """Store a checkpoint file to the output directory""" 606 | path = os.path.join(output_dir, filename) 607 | 608 | # Ensure the output directory exists 609 | directory = os.path.dirname(path) 610 | if not os.path.isdir(directory): 611 | os.makedirs(directory, exist_ok=True) 612 | 613 | # remove buffer from checkpoint 614 | # TODO should not hard code 615 | def 
keep_state_dict_keys(key): 616 | if "self.R" in key: 617 | return False 618 | return True 619 | 620 | time.sleep( 621 | 1 622 | ) # workaround for RuntimeError('Unknown Error -1') https://github.com/pytorch/pytorch/issues/10577 623 | torch.save( 624 | { 625 | "epoch": epoch, 626 | "test_accuracy": test_accuracy, 627 | "model_state_dict": OrderedDict([ 628 | (key, value) for key, value in model.state_dict().items() if keep_state_dict_keys(key) 629 | ]), 630 | }, 631 | path, 632 | ) 633 | 634 | def restore_checkpoint(filename, model, device): 635 | """Load model from a checkpoint""" 636 | print("Loading model parameters from '{}'".format(filename)) 637 | with open(filename, "rb") as f: 638 | checkpoint_data = torch.load(f, map_location=device) 639 | 640 | try: 641 | model.load_state_dict(checkpoint_data["model_state_dict"]) 642 | except RuntimeError as e: 643 | print(colored("Missing state_dict keys in checkpoint", "red"), e) 644 | print("Retry import with current model values for missing keys.") 645 | state = model.state_dict() 646 | state.update(checkpoint_data["model_state_dict"]) 647 | model.load_state_dict(state) 648 | 649 | 650 | if __name__ == "__main__": 651 | # if directly called from CLI (not as module) 652 | # we parse the parameters overides 653 | config = parse_cli_overides(config) 654 | main() 655 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfml/attention-cnn/21483bba7d8e3ff1dc104c1b311a44a84d5c9db4/utils/__init__.py -------------------------------------------------------------------------------- /utils/accumulators.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from copy import deepcopy 3 | 4 | class Mean: 5 | """ 6 | Running average of the values that are 'add'ed 7 | """ 8 | def __init__(self, update_weight=1): 9 | """ 10 | :param update_weight: 1 for normal, 2 for t-average 11 | """ 12 | self.average = None 13 | self.counter = 0 14 | self.update_weight = update_weight 15 | 16 | def add(self, value, weight=1): 17 | """Add a value to the accumulator""" 18 | self.counter += weight 19 | if self.average is None: 20 | self.average = deepcopy(value) 21 | else: 22 | delta = value - self.average 23 | self.average += delta * self.update_weight * weight / (self.counter + self.update_weight - 1) 24 | if isinstance(self.average, torch.Tensor): 25 | self.average.detach() 26 | 27 | def value(self): 28 | """Access the current running average""" 29 | return self.average 30 | 31 | 32 | class Max: 33 | """ 34 | Keeps track of the max of all the values that are 'add'ed 35 | """ 36 | def __init__(self): 37 | self.max = None 38 | 39 | def add(self, value): 40 | """ 41 | Add a value to the accumulator. 
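Values are compared with `>` and stored via deepcopy, so the plain floats used for test accuracy in train.py work fine.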
42 | :return: `true` if the provided value became the new max 43 | """ 44 | if self.max is None or value > self.max: 45 | self.max = deepcopy(value) 46 | return True 47 | else: 48 | return False 49 | 50 | def value(self): 51 | """Access the current running average""" 52 | return self.max 53 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | from termcolor import colored 2 | import argparse 3 | import enum 4 | 5 | 6 | def parse_cli_overides(config: dict): 7 | """ 8 | Parse args from CLI and override config dictionary entries 9 | """ 10 | parser = argparse.ArgumentParser() 11 | for key, value in config.items(): 12 | parser.add_argument(f"--{key}") 13 | args = vars(parser.parse_args()) 14 | 15 | def print_config_override(key, old_value, new_value, first_config_overide): 16 | if first_config_overide: 17 | print(colored("Config overrides:", "red")) 18 | print(f" {key:25s} -> {new_value} (instead of {old_value})") 19 | 20 | def cast_argument(key, old_value, new_value): 21 | try: 22 | if new_value is None: 23 | return None 24 | if type(old_value) is int: 25 | return int(new_value) 26 | if type(old_value) is float: 27 | return float(new_value) 28 | if type(old_value) is str: 29 | return new_value 30 | if type(old_value) is bool: 31 | return new_value.lower() in ("yes", "true", "t", "1") 32 | if issubclass(old_value.__class__, enum.Enum): 33 | return old_value.__class__(new_value) 34 | if old_value is None: 35 | return new_value # assume string 36 | raise ValueError() 37 | except Exception: 38 | raise ValueError(f"Unable to parse config key '{key}' with value '{new_value}'") 39 | 40 | first_config_overide = True 41 | for key, original_value in config.items(): 42 | override_value = cast_argument(key, original_value, args[key]) 43 | if override_value is not None and override_value != original_value: 44 | config[key] = override_value 45 | print_config_override(key, original_value, override_value, first_config_overide) 46 | first_config_overide = False 47 | 48 | return config 49 | -------------------------------------------------------------------------------- /utils/data.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torch 3 | import numpy as np 4 | 5 | 6 | class MaskedDataset(Dataset): 7 | """ 8 | Wrap a dataset of images and append a random mask to each sample 9 | """ 10 | 11 | def __init__(self, dataset, mask_size): 12 | self.dataset = dataset 13 | self.mask_size = mask_size 14 | 15 | def __getitem__(self, item): 16 | sample = self.dataset[item] 17 | image = sample[0] 18 | _, width, height = image.shape 19 | 20 | batch_mask = torch.ones([width, height], dtype=torch.uint8) 21 | mask_left = np.random.randint(0, width - self.mask_size) 22 | mask_top = np.random.randint(0, height - self.mask_size) 23 | batch_mask[mask_left : mask_left + self.mask_size, mask_top : mask_top + self.mask_size] = 0 24 | 25 | return sample + (batch_mask,) 26 | 27 | def __len__(self): 28 | return len(self.dataset) 29 | 30 | -------------------------------------------------------------------------------- /utils/learning_rate.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from torch.optim.lr_scheduler import LambdaLR 4 | 5 | 6 | def linear_warmup_cosine_lr_scheduler(optimizer, warmup_time_ratio, T_max): 7 | T_warmup = int(T_max * 
warmup_time_ratio) 8 | 9 | def lr_lambda(epoch): 10 | # linear warm up 11 | if epoch < T_warmup: 12 | return epoch / T_warmup 13 | else: 14 | progress_0_1 = (epoch - T_warmup) / (T_max - T_warmup) 15 | cosine_decay = 0.5 * (1 + math.cos(math.pi * progress_0_1)) 16 | return cosine_decay 17 | 18 | return LambdaLR(optimizer, lr_lambda=lr_lambda) 19 | 20 | -------------------------------------------------------------------------------- /utils/logging.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from tensorboardX import SummaryWriter 4 | 5 | 6 | def get_num_parameter(model, trainable=False): 7 | if trainable: 8 | params = [(n, p) for (n, p) in model.named_parameters() if p.requires_grad] 9 | else: 10 | params = [(n, p) for (n, p) in model.named_parameters()] 11 | 12 | total_params = sum(p.numel() for (n, p) in params) 13 | num_param_list = [(n, p.numel()) for (n, p) in params] 14 | 15 | return total_params, num_param_list 16 | 17 | 18 | def human_format(num): 19 | num = float("{:.3g}".format(num)) 20 | magnitude = 0 21 | while abs(num) >= 1000: 22 | magnitude += 1 23 | num /= 1000.0 24 | return "{}{}".format( 25 | "{:f}".format(num).rstrip("0").rstrip("."), ["", "K", "M", "B", "T"][magnitude] 26 | ) 27 | 28 | 29 | def sizeof_fmt(num, suffix="B"): 30 | for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]: 31 | if abs(num) < 1024.0: 32 | return "%3.1f%s%s" % (num, unit, suffix) 33 | num /= 1024.0 34 | return "%.1f%s%s" % (num, "Yi", suffix) 35 | 36 | 37 | class DummySummaryWriter: 38 | """Mock a TensorboardX summary writer but does not do anything""" 39 | 40 | def __init__(self): 41 | def noop(*args, **kwargs): 42 | pass 43 | 44 | s = SummaryWriter() 45 | for f in dir(s): 46 | if not f.startswith("_"): 47 | self.__setattr__(f, noop) 48 | 49 | 50 | class JSONLogger: 51 | """ 52 | Very simple prototype logger that will store the values to a JSON file 53 | """ 54 | 55 | def __init__(self, filename, auto_save=True): 56 | """ 57 | :param filename: ending with .json 58 | :param auto_save: save the JSON file after every addition 59 | """ 60 | self.filename = filename 61 | self.values = [] 62 | self.auto_save = auto_save 63 | 64 | # Ensure the output directory exists 65 | directory = os.path.dirname(self.filename) 66 | if not os.path.isdir(directory): 67 | os.makedirs(directory, exist_ok=True) 68 | 69 | def log_metric(self, name, values, tags): 70 | """ 71 | Store a scalar metric 72 | 73 | :param name: measurement, like 'accuracy' 74 | :param values: dictionary, like { epoch: 3, value: 0.23 } 75 | :param tags: dictionary, like { split: train } 76 | """ 77 | self.values.append({"measurement": name, **values, **tags}) 78 | print("{name}: {values} ({tags})".format(name=name, values=values, tags=tags)) 79 | if self.auto_save: 80 | self.save() 81 | 82 | def save(self): 83 | """ 84 | Save the internal memory to a file 85 | """ 86 | with open(self.filename, "w") as fp: 87 | json.dump(self.values, fp, indent=" ") 88 | -------------------------------------------------------------------------------- /utils/plotting.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | from scipy import interpolate 4 | import matplotlib.colors as mcolors 5 | from matplotlib.patches import Ellipse, Rectangle 6 | import itertools 7 | 8 | 9 | def plot_grid_query_pix(width, ax=None): 10 | if ax is None: 11 | plt.figure() 12 | ax = plt.gca() 13 | 14 | 
ax.set_xticks(np.arange(-width / 2, width / 2)) # , minor=True) 15 | ax.set_aspect(1) 16 | ax.set_yticks(np.arange(-width / 2, width / 2)) # , minor=True) 17 | ax.tick_params( 18 | axis="both", 19 | which="both", 20 | bottom=False, 21 | top=False, 22 | left=False, 23 | labelbottom=False, 24 | labelleft=False, 25 | ) 26 | ax.grid(True, alpha=0.5) 27 | 28 | # query pixel 29 | querry_pix = Rectangle(xy=(-0.5,-0.5), 30 | width=1, 31 | height=1, 32 | edgecolor="black", 33 | fc='None', 34 | lw=2) 35 | 36 | ax.add_patch(querry_pix); 37 | 38 | ax.set_xlim(-width / 2, width / 2) 39 | ax.set_ylim(-width / 2, width / 2) 40 | ax.set_aspect("equal") 41 | 42 | def plot_attention_layer(model, layer_idx, width, ax=None): 43 | """Plot the 2D attention probabilities of all heads on an image 44 | of layer layer_idx 45 | """ 46 | if ax is None: 47 | fig, ax = plt.subplots() 48 | 49 | attention = model.encoder.layer[layer_idx].attention.self 50 | attention_probs = attention.get_attention_probs(width + 2, width + 2) 51 | 52 | contours = np.array([0.9, 0.5]) 53 | linestyles = [":", "-"] 54 | flat_colors = ["#3498db", "#f1c40f", "#2ecc71", "#e74c3c", "#e67e22", "#9b59b6", "#34495e", "#1abc9c", "#95a5a6"] 55 | 56 | if ax is None: 57 | fig, ax = plt.subplots() 58 | 59 | shape = attention_probs.shape 60 | # remove batch size if present 61 | if len(shape) == 6: 62 | shape = shape[1:] 63 | height, width, num_heads, _, _ = shape 64 | 65 | attention_at_center = attention_probs[width // 2, height // 2] 66 | attention_at_center = attention_at_center.detach().cpu().numpy() 67 | 68 | # compute integral of distribution for thresholding 69 | n = 1000 70 | t = np.linspace(0, attention_at_center.max(), n) 71 | integral = ((attention_at_center >= t[:, None, None, None]) * attention_at_center).sum( 72 | axis=(-1, -2) 73 | ) 74 | 75 | plot_grid_query_pix(width - 2, ax) 76 | 77 | for h, color in zip(range(num_heads), itertools.cycle(flat_colors)): 78 | f = interpolate.interp1d(integral[:, h], t, fill_value=(1, 0), bounds_error=False) 79 | t_contours = f(contours) 80 | 81 | # remove duplicate contours if any 82 | keep_contour = np.concatenate([np.array([True]), np.diff(t_contours) > 0]) 83 | t_contours = t_contours[keep_contour] 84 | 85 | for t_contour, linestyle in zip(t_contours, linestyles): 86 | ax.contour( 87 | np.arange(-width // 2, width // 2) + 1, 88 | np.arange(-height // 2, height // 2) + 1, 89 | attention_at_center[h], 90 | [t_contour], 91 | extent=[- width // 2, width // 2 + 1, - height // 2, height // 2 + 1], 92 | colors=color, 93 | linestyles=linestyle 94 | ) 95 | 96 | return ax 97 | 98 | 99 | def plot_attention_positions_all_layers(model, width, tensorboard_writer=None, global_step=None): 100 | 101 | for layer_idx in range(len(model.encoder.layer)): 102 | fig, ax = plt.subplots() 103 | plot_attention_layer(model, layer_idx, width, ax=ax) 104 | 105 | ax.set_title(f"Layer {layer_idx + 1}") 106 | if tensorboard_writer: 107 | tensorboard_writer.add_figure(f"attention/layer{layer_idx}", fig, global_step=global_step) 108 | plt.close(fig) 109 | --------------------------------------------------------------------------------
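train.py can be driven either from the command line, as in the run.sh scripts above, or by importing it as a module, overriding its config dictionary and log_metric hook, and calling main(), as its docstring describes. A minimal sketch of the second route (the output path and the overridden values here are illustrative, not part of the repository):

    import train
    from utils.logging import JSONLogger

    train.config["model"] = "resnet18"               # any key of the config dict can be overridden
    train.config["num_epochs"] = 10
    train.config["output_dir"] = "./output.resnet18"
    train.log_metric = JSONLogger("./output.resnet18/metrics.json").log_metric
    best_accuracy = train.main()                     # returns the best test accuracy reached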