├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── .vscode └── settings.json ├── LICENSE ├── LOGO.png ├── README.md ├── setup.py └── vat_pytorch ├── __init__.py ├── alice.py ├── alicepp.py ├── models ├── __init__.py ├── alice_classification_model.py ├── alicepp_classification_model.py ├── extracted_model.py ├── extracted_roberta.py └── smart_classification_model.py ├── smart.py └── utils.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.formatting.provider": "black" 3 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 archinet.ai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /LOGO.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/archinetai/vat-pytorch/34df8043969c0d2ccb7e4189ca337d91b7270c2c/LOGO.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | A collection of VAT (Virtual Adversarial Training) methods, in PyTorch. 4 | 5 | ## Install 6 | 7 | ```bash 8 | $ pip install vat-pytorch 9 | ``` 10 | 11 | [![PyPI - Python Version](https://img.shields.io/pypi/v/vat-pytorch?style=flat&colorA=0f0f0f&colorB=0f0f0f)](https://pypi.org/project/vat-pytorch/) 12 | 13 | 14 | ## API 15 | 16 | ### SMART 17 | The SMART paper proposes to find the noise that maximally perturbs the logits when added to the embedding layer, and to use a loss function to make sure that the perturbed logits are as close as possible to the predicted logits. 18 | 19 | ```py 20 | from vat_pytorch import SMARTLoss, inf_norm 21 | 22 | loss = SMARTLoss( 23 | model: nn.Module, 24 | loss_fn: Callable, 25 | loss_last_fn: Callable = None, 26 | norm_fn: Callable = inf_norm, 27 | num_steps: int = 1, 28 | step_size: float = 1e-3, 29 | epsilon: float = 1e-6, 30 | noise_var: float = 1e-5 31 | ) 32 | ``` 33 | 34 | ### ALICE 35 | 36 | The ALICE paper is analogous to the SMART paper, but adds an additional term to make sure that the perturbed logits are as close as possible to both the predicted logits *and* the ground truth labels. 37 | 38 | ```py 39 | from vat_pytorch import ALICELoss, inf_norm 40 | 41 | loss = ALICELoss( 42 | model: nn.Module, 43 | loss_fn: Callable, 44 | num_classes: int, 45 | loss_last_fn: Callable = None, 46 | gold_loss_fn: Callable = None, 47 | gold_loss_last_fn: Callable = None, 48 | norm_fn: Callable = inf_norm, 49 | alpha: float = 1, 50 | num_steps: int = 1, 51 | step_size: float = 1e-3, 52 | epsilon: float = 1e-6, 53 | noise_var: float = 1e-5, 54 | ) 55 | ``` 56 | 57 | ### ALICE++ 58 | 59 | The ALICE++ paper is analogous to the ALICE paper, but instead of adding noise to the embedding layer, it picks a random layer from the network at each iteration on which to add the noise. 60 | 61 | ```py 62 | from vat_pytorch import ALICEPPLoss, ALICEPPModule, inf_norm 63 | 64 | loss = ALICEPPLoss( 65 | model: ALICEPPModule, 66 | num_classes: int, 67 | loss_fn: Callable, 68 | num_layers: int, 69 | max_layer: int = None, 70 | loss_last_fn: Callable = None, 71 | gold_loss_fn: Callable = None, 72 | gold_loss_last_fn: Callable = None, 73 | norm_fn: Callable = inf_norm, 74 | alpha: float = 1, 75 | num_steps: int = 1, 76 | step_size: float = 1e-3, 77 | epsilon: float = 1e-6, 78 | noise_var: float = 1e-5, 79 | ) 80 | ``` 81 | 82 | ## Usage (Classification) 83 | 84 | ### Extract Model 85 | The first thing we have to do is extract the chunk of the model that we want to perturb adversarially. A generic example with Huggingface's RoBERTa for sequence classification is given.
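Concretely, the extracted chunk only needs to expose three methods: a `forward` that maps a hidden state (or the input embeddings) to logits, a `get_embeddings` that computes the first embedding layer, and, optionally, a `set_attention_mask`. A condensed sketch of the `ExtractedModel` interface shipped in `vat_pytorch/models/extracted_model.py`:

```py
import torch.nn as nn

class ExtractedModel(nn.Module):
    """ Interface to be used with extracted models """

    def forward(self, hidden, *, start_layer = 0):
        """ Forwards a hidden state (or the input embeddings) to the logits """
        raise NotImplementedError()

    def get_embeddings(self, input_ids):
        """ Computes the first embedding layer given input_ids """
        raise NotImplementedError()

    def set_attention_mask(self, attention_mask):
        """ Optionally fixes the mask used on all subsequent forward passes """
        # This method is optional
```

The RoBERTa example below implements this interface (and adds a `with_hidden_states` flag that ALICE++ relies on).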
86 | 87 | ```py 88 | import torch.nn as nn 89 | from transformers import AutoModelForSequenceClassification 90 | 91 | class ExtractedRoBERTa(nn.Module): 92 | 93 | def __init__(self): 94 | super().__init__() 95 | model = AutoModelForSequenceClassification.from_pretrained('roberta-base') 96 | self.roberta = model.roberta 97 | self.layers = model.roberta.encoder.layer 98 | self.classifier = model.classifier 99 | self.attention_mask = None 100 | self.num_layers = len(self.layers) - 1 101 | 102 | def forward(self, hidden, with_hidden_states = False, start_layer = 0): 103 | """ Forwards the hidden value from self.start_layer layer to the logits. """ 104 | hidden_states = [hidden] 105 | 106 | for layer in self.layers[start_layer:]: 107 | hidden = layer(hidden, attention_mask = self.attention_mask)[0] 108 | hidden_states += [hidden] 109 | 110 | logits = self.classifier(hidden) 111 | 112 | return (logits, hidden_states) if with_hidden_states else logits 113 | 114 | def get_embeddings(self, input_ids): 115 | """ Computes first embedding layer given inputs_ids """ 116 | return self.roberta.embeddings(input_ids) 117 | 118 | def set_attention_mask(self, attention_mask): 119 | """ Sets the correct mask on all subsequent forward passes """ 120 | self.attention_mask = self.roberta.get_extended_attention_mask( 121 | attention_mask, 122 | input_shape = attention_mask.shape, 123 | device = attention_mask.device 124 | ) # (b, 1, 1, s) 125 | ``` 126 | The function `set_attention_mask` fixes the attention mask for all subsequent forward calls; this is necessary whenever we want to apply an attention mask together with any of the VAT losses. The parameter `start_layer` in the forward function is only needed when using `ALICEPPLoss`, since that loss must be able to change the start layer internally.
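As a quick check, the extracted model can be driven on its own before any VAT loss is attached; a minimal sketch, assuming the matching `roberta-base` tokenizer and the `ExtractedRoBERTa` class defined above:

```py
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('roberta-base')
model = ExtractedRoBERTa()

inputs = tokenizer(["A quick sanity check."], return_tensors='pt')
# Compute the input embeddings once and fix the attention mask for later calls
embeddings = model.get_embeddings(inputs['input_ids'])  # (b, s, d)
model.set_attention_mask(inputs['attention_mask'])
# Full pass from the embedding layer to the logits, keeping every hidden state
logits, hidden_states = model(embeddings, with_hidden_states = True)
# Re-run only the upper layers, as ALICEPPLoss does when it perturbs an intermediate layer
logits_upper = model(hidden_states[3], start_layer = 3)
```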
127 | 128 | 129 | ### SMART 130 | 131 | ```py 132 | import torch.nn as nn 133 | import torch.nn.functional as F 134 | from vat_pytorch import SMARTLoss, kl_loss, sym_kl_loss 135 | 136 | class SMARTClassificationModel(nn.Module): 137 | # b: batch_size, s: sequence_length, d: hidden_size , n: num_labels 138 | 139 | def __init__(self, extracted_model, weight = 1.0): 140 | super().__init__() 141 | self.model = extracted_model 142 | self.weight = weight 143 | self.vat_loss = SMARTLoss(model = extracted_model, loss_fn = kl_loss, loss_last_fn = sym_kl_loss) 144 | 145 | def forward(self, input_ids, attention_mask, labels): 146 | """ input_ids: (b, s), attention_mask: (b, s), labels: (b,) """ 147 | # Get input embeddings 148 | embeddings = self.model.get_embeddings(input_ids) 149 | # Set mask and compute logits 150 | self.model.set_attention_mask(attention_mask) 151 | logits = self.model(embeddings) 152 | # Compute CE loss 153 | ce_loss = F.cross_entropy(logits.view(-1, 2), labels.view(-1)) 154 | # Compute VAT loss 155 | vat_loss = self.vat_loss(embeddings, logits) 156 | # Merge losses 157 | loss = ce_loss + self.weight * vat_loss 158 | return logits, loss 159 | ``` 160 | 161 | ### ALICE 162 | 163 | ```py 164 | import torch.nn as nn 165 | import torch.nn.functional as F 166 | from vat_pytorch import ALICELoss, kl_loss 167 | 168 | class ALICEClassificationModel(nn.Module): 169 | # b: batch_size, s: sequence_length, d: hidden_size , n: num_labels 170 | 171 | def __init__(self, extracted_model): 172 | super().__init__() 173 | self.model = extracted_model 174 | self.vat_loss = ALICELoss(model = extracted_model, loss_fn = kl_loss, num_classes = 2) 175 | 176 | def forward(self, input_ids, attention_mask, labels): 177 | """ input_ids: (b, s), attention_mask: (b, s), labels: (b,) """ 178 | # Get input embeddings 179 | embeddings = self.model.get_embeddings(input_ids) 180 | # Set iteration specific data (e.g. attention mask) 181 | self.model.set_attention_mask(attention_mask) 182 | # Compute logits 183 | logits = self.model(embeddings) 184 | # Compute VAT loss 185 | loss = self.vat_loss(embeddings, logits, labels) 186 | return logits, loss 187 | ``` 188 | 189 | ### ALICE++ 190 | 191 | ```py 192 | import torch.nn as nn 193 | import torch.nn.functional as F 194 | from vat_pytorch import ALICEPPLoss, kl_loss 195 | 196 | class ALICEPPClassificationModel(nn.Module): 197 | # b: batch_size, s: sequence_length, d: hidden_size , n: num_labels 198 | 199 | def __init__(self, extracted_model): 200 | super().__init__() 201 | self.model = extracted_model 202 | self.vat_loss = ALICEPPLoss( 203 | model = extracted_model, 204 | loss_fn = kl_loss, 205 | num_layers = self.model.num_layers, 206 | num_classes = 2 207 | ) 208 | 209 | def forward(self, input_ids, attention_mask, labels): 210 | """ input_ids: (b, s), attention_mask: (b, s), labels: (b,) """ 211 | # Get input embeddings 212 | embeddings = self.model.get_embeddings(input_ids) 213 | # Set iteration specific data (e.g. 
attention mask) 214 | self.model.set_attention_mask(attention_mask) 215 | # Compute logits 216 | logits, hidden_states = self.model(embeddings, with_hidden_states = True) 217 | # Compute VAT loss 218 | loss = self.vat_loss(hidden_states, logits, labels) 219 | return logits, loss 220 | ``` 221 | 222 | Note that `extracted_model` must provide a forward method with the signature `forward(self, hidden: Tensor, *, start_layer: int) -> Tensor`; the `ALICEPPModule` interface (`from vat_pytorch import ALICEPPModule`) can be used in place of the `nn.Module` class on the extracted model to make sure that this method is present. 223 | 224 | 225 | ### Wrapped Model Usage 226 | Any of the above losses can be used as follows with the extracted model. 227 | ```py 228 | import torch 229 | from transformers import AutoTokenizer 230 | 231 | extracted_model = ExtractedRoBERTa() 232 | tokenizer = AutoTokenizer.from_pretrained('roberta-base') 233 | # Pick one: 234 | model = SMARTClassificationModel(extracted_model) 235 | model = ALICEClassificationModel(extracted_model) 236 | model = ALICEPPClassificationModel(extracted_model) 237 | # Compute inputs 238 | text = ["This text belongs to class 1...", "This text belongs to class 0..."] 239 | inputs = tokenizer(text, return_tensors='pt') 240 | labels = torch.tensor([1, 0]) 241 | # Compute logits and loss 242 | logits, loss = model(input_ids = inputs['input_ids'], attention_mask = inputs['attention_mask'], labels = labels) 243 | # To fine-tune, do this for many steps 244 | loss.backward() 245 | ``` 246 | 247 | ## Citations 248 | 249 | ```bibtex 250 | @inproceedings{Jiang2020SMARTRA, 251 | title={SMART: Robust and Efficient Fine-Tuning for Pre-trained Natural Language Models through Principled Regularized Optimization}, 252 | author={Haoming Jiang and Pengcheng He and Weizhu Chen and Xiaodong Liu and Jianfeng Gao and Tuo Zhao}, 253 | booktitle={ACL}, 254 | year={2020} 255 | } 256 | ``` 257 | 258 | ```bibtex 259 | @article{Pereira2020AdversarialTF, 260 | title={Adversarial Training for Commonsense Inference}, 261 | author={Lis Kanashiro Pereira and Xiaodong Liu and Fei Cheng and Masayuki Asahara and Ichiro Kobayashi}, 262 | journal={ArXiv}, 263 | year={2020}, 264 | volume={abs/2005.08156} 265 | } 266 | ``` 267 | 268 | ```bibtex 269 | @inproceedings{Pereira2021ALICEAT, 270 | title={ALICE++: Adversarial Training for Robust and Effective Temporal Reasoning}, 271 | author={Lis Kanashiro Pereira and Fei Cheng and Masayuki Asahara and Ichiro Kobayashi}, 272 | booktitle={PACLIC}, 273 | year={2021} 274 | } 275 | ``` 276 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="vat-pytorch", 5 | packages=find_packages(exclude=[]), 6 | version="0.0.9", 7 | license="MIT", 8 | description="Virtual Adversarial Training - Pytorch", 9 | long_description_content_type="text/markdown", 10 | author="Archinet", 11 | author_email="archinetai@protonmail.com", 12 | url="https://github.com/archinetai/vat-pytorch", 13 | keywords=[ 14 | "artificial intelligence", 15 | "deep learning", 16 | "fine-tuning", 17 | "pre-trained", 18 | ], 19 | install_requires=[ 20 | "torch>=1.6", 21 | "data-science-types>=0.2", 22 | "transformers>=4.0.0", 23 | ], 24 | classifiers=[ 25 | "Development Status :: 4 - Beta", 26 | "Intended Audience :: Developers", 27 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 28 | 
"License :: OSI Approved :: MIT License", 29 | "Programming Language :: Python :: 3.6", 30 | ], 31 | ) 32 | -------------------------------------------------------------------------------- /vat_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from .smart import SMARTLoss 2 | from .alice import ALICELoss 3 | from .alicepp import ALICEPPLoss, ALICEPPModule 4 | from .utils import kl_loss, sym_kl_loss, js_loss, inf_norm, default 5 | from .models import * -------------------------------------------------------------------------------- /vat_pytorch/alice.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Callable 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch import Tensor 7 | from itertools import count 8 | from .utils import default, inf_norm 9 | 10 | 11 | class ALICELoss(nn.Module): 12 | def __init__( 13 | self, 14 | model: nn.Module, 15 | loss_fn: Callable, 16 | num_classes: int, 17 | loss_last_fn: Callable = None, 18 | gold_loss_fn: Callable = None, 19 | gold_loss_last_fn: Callable = None, 20 | norm_fn: Callable = inf_norm, 21 | alpha: float = 1, 22 | num_steps: int = 1, 23 | step_size: float = 1e-3, 24 | epsilon: float = 1e-6, 25 | noise_var: float = 1e-5, 26 | ) -> None: 27 | super().__init__() 28 | self.model = model 29 | self.num_classes = num_classes 30 | self.loss_fn = loss_fn 31 | self.loss_last_fn = default(loss_last_fn, loss_fn) 32 | self.gold_loss_fn = default(gold_loss_fn, loss_fn) 33 | self.gold_loss_last_fn = default( 34 | default(gold_loss_last_fn, self.gold_loss_fn), self.loss_last_fn 35 | ) 36 | self.norm_fn = norm_fn 37 | self.alpha = alpha 38 | self.num_steps = num_steps 39 | self.step_size = step_size 40 | self.epsilon = epsilon 41 | self.noise_var = noise_var 42 | 43 | def forward(self, embed: Tensor, state: Tensor, labels: Tensor) -> Tensor: 44 | 45 | virtual_loss = self.get_perturbed_loss( 46 | embed, state, loss_fn=self.loss_fn, loss_last_fn=self.loss_last_fn 47 | ) 48 | 49 | labels_loss = self.get_perturbed_loss( 50 | embed, 51 | state=F.one_hot(labels, num_classes=self.num_classes).float(), 52 | loss_fn=self.gold_loss_fn, 53 | loss_last_fn=self.gold_loss_last_fn, 54 | ) 55 | 56 | return labels_loss + self.alpha * virtual_loss 57 | 58 | @torch.enable_grad() 59 | def get_perturbed_loss( 60 | self, embed: Tensor, state: Tensor, loss_fn: Callable, loss_last_fn: Callable 61 | ): 62 | noise = torch.randn_like(embed, requires_grad=True) * self.noise_var 63 | 64 | # Indefinite loop with counter 65 | for i in count(): 66 | # Compute perturbed embed and states 67 | embed_perturbed = embed + noise 68 | state_perturbed = self.model(embed_perturbed) 69 | # Return final loss if last step (undetached state) 70 | if i == self.num_steps: 71 | return loss_last_fn(state_perturbed, state) 72 | # Compute perturbation loss (detached state) 73 | loss = loss_fn(state_perturbed, state.detach()) 74 | # Compute noise gradient ∂loss/∂noise 75 | (noise_gradient,) = torch.autograd.grad(loss, noise) 76 | # Move noise towards gradient to change state as much as possible 77 | step = noise + self.step_size * noise_gradient 78 | # Normalize new noise step into norm induced ball 79 | step_norm = self.norm_fn(step) 80 | noise = step / (step_norm + self.epsilon) 81 | # Reset noise gradients for next step 82 | noise = noise.detach().requires_grad_() 83 | -------------------------------------------------------------------------------- 
/vat_pytorch/alicepp.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union, Callable 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch import Tensor 7 | from itertools import count 8 | from .utils import default, inf_norm 9 | 10 | 11 | class ALICEPPModule(nn.Module): 12 | """ Interface for the model provided to ALICEPPLoss """ 13 | 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def forward(self, hidden: Tensor, *, start_layer: int) -> Tensor: 18 | raise NotImplementedError() 19 | 20 | 21 | class ALICEPPLoss(nn.Module): 22 | 23 | def __init__( 24 | self, 25 | model: ALICEPPModule, 26 | num_classes: int, 27 | loss_fn: Callable, 28 | num_layers: int, 29 | max_layer: int = None, 30 | loss_last_fn: Callable = None, 31 | gold_loss_fn: Callable = None, 32 | gold_loss_last_fn: Callable = None, 33 | norm_fn: Callable = inf_norm, 34 | alpha: float = 1, 35 | num_steps: int = 1, 36 | step_size: float = 1e-3, 37 | epsilon: float = 1e-6, 38 | noise_var: float = 1e-5, 39 | ) -> None: 40 | super().__init__() 41 | self.model = model 42 | self.num_classes = num_classes 43 | self.loss_fn = loss_fn 44 | self.num_layers = num_layers 45 | self.max_layer = min(default(max_layer, num_layers), num_layers) 46 | self.loss_last_fn = default(loss_last_fn, loss_fn) 47 | self.gold_loss_fn = default(gold_loss_fn, loss_fn) 48 | self.gold_loss_last_fn = default(default(gold_loss_last_fn, self.gold_loss_fn), self.loss_last_fn) 49 | self.norm_fn = norm_fn 50 | self.alpha = alpha 51 | self.num_steps = num_steps 52 | self.step_size = step_size 53 | self.epsilon = epsilon 54 | self.noise_var = noise_var 55 | 56 | def forward(self, hiddens: List[Tensor], state: Tensor, labels: Tensor) -> Tensor: 57 | 58 | # Pick random layer on which we apply the perturbation 59 | random_layer_id = torch.randint(low = 0, high = self.max_layer, size = (1,))[0] 60 | 61 | virtual_loss = self.get_perturbed_loss( 62 | hidden = hiddens[random_layer_id], 63 | state = state, 64 | layer_id = random_layer_id, 65 | loss_fn = self.loss_fn, 66 | loss_last_fn = self.loss_last_fn 67 | ) 68 | 69 | label_loss = self.get_perturbed_loss( 70 | hidden = hiddens[random_layer_id], 71 | state = F.one_hot(labels, num_classes=self.num_classes).float(), 72 | layer_id = random_layer_id, 73 | loss_fn = self.gold_loss_fn, 74 | loss_last_fn = self.gold_loss_last_fn 75 | ) 76 | 77 | return label_loss + self.alpha * virtual_loss 78 | 79 | @torch.enable_grad() 80 | def get_perturbed_loss( 81 | self, 82 | hidden: Tensor, 83 | state: Tensor, 84 | layer_id: int, 85 | loss_fn: Callable, 86 | loss_last_fn: Callable 87 | ): 88 | noise = torch.randn_like(hidden, requires_grad = True) * self.noise_var 89 | 90 | # Indefinite loop with counter 91 | for i in count(): 92 | # Compute perturbed hidden and states 93 | hidden_perturbed = hidden + noise 94 | state_perturbed = self.model(hidden_perturbed, start_layer = layer_id) 95 | # Return final loss if last step (undetached state) 96 | if i == self.num_steps: 97 | return loss_last_fn(state_perturbed, state) 98 | # Compute perturbation loss (detached state) 99 | loss = loss_fn(state_perturbed, state.detach()) 100 | # Compute noise gradient ∂loss/∂noise 101 | noise_gradient, = torch.autograd.grad(loss, noise) 102 | # Move noise towards gradient to change state as much as possible 103 | step = noise + self.step_size * noise_gradient 104 | # Normalize new noise step into norm induced ball 105 | step_norm = 
self.norm_fn(step) 106 | noise = step / (step_norm + self.epsilon) 107 | # Reset noise gradients for next step 108 | noise = noise.detach().requires_grad_() 109 | -------------------------------------------------------------------------------- /vat_pytorch/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .extracted_model import ExtractedModel 2 | from .extracted_roberta import ExtractedRoBERTa 3 | from .smart_classification_model import SMARTClassificationModel 4 | from .alice_classification_model import ALICEClassificationModel 5 | from .alicepp_classification_model import ALICEPPClassificationModel -------------------------------------------------------------------------------- /vat_pytorch/models/alice_classification_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from vat_pytorch import ALICELoss, kl_loss 4 | 5 | class ALICEClassificationModel(nn.Module): 6 | # b: batch_size, s: sequence_length, d: hidden_size , n: num_labels 7 | 8 | def __init__(self, extracted_model): 9 | super().__init__() 10 | self.model = extracted_model 11 | self.vat_loss = ALICELoss(model = extracted_model, loss_fn = kl_loss, num_classes = 2) 12 | 13 | def forward(self, input_ids, attention_mask, labels): 14 | """ input_ids: (b, s), attention_mask: (b, s), labels: (b,) """ 15 | # Get input embeddings 16 | embeddings = self.model.get_embeddings(input_ids) 17 | # Set iteration specific data (e.g. attention mask) 18 | self.model.set_attention_mask(attention_mask) 19 | # Compute logits 20 | logits = self.model(embeddings) 21 | # Compute VAT loss 22 | loss = self.vat_loss(embeddings, logits, labels) 23 | return logits, loss -------------------------------------------------------------------------------- /vat_pytorch/models/alicepp_classification_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from vat_pytorch import ALICEPPLoss, kl_loss 4 | 5 | class ALICEPPClassificationModel(nn.Module): 6 | # b: batch_size, s: sequence_length, d: hidden_size , n: num_labels 7 | 8 | def __init__(self, extracted_model): 9 | super().__init__() 10 | self.model = extracted_model 11 | self.vat_loss = ALICEPPLoss( 12 | model = extracted_model, 13 | loss_fn = kl_loss, 14 | num_layers = self.model.num_layers, 15 | num_classes = 2 16 | ) 17 | 18 | def forward(self, input_ids, attention_mask, labels): 19 | """ input_ids: (b, s), attention_mask: (b, s), labels: (b,) """ 20 | # Get input embeddings 21 | embeddings = self.model.get_embeddings(input_ids) 22 | # Set iteration specific data (e.g. attention mask) 23 | self.model.set_attention_mask(attention_mask) 24 | # Compute logits 25 | logits, hidden_states = self.model(embeddings, with_hidden_states = True) 26 | # Compute VAT loss 27 | loss = self.vat_loss(hidden_states, logits, labels) 28 | return logits, loss -------------------------------------------------------------------------------- /vat_pytorch/models/extracted_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class ExtractedModel(nn.Module): 4 | """ Interface to be used with extracted models """ 5 | 6 | def forward(self, hidden, *, start_layer = 0): 7 | """ Forwards the hidden value from self.start_layer layer to the logits. 
""" 8 | raise NotImplementedError() 9 | 10 | def get_embeddings(self, input_ids): 11 | """ Computes first embedding layer given inputs_ids """ 12 | raise NotImplementedError() 13 | 14 | def set_attention_mask(self, attention_mask): 15 | """ Sets the correct mask on all subsequent forward passes """ 16 | # This is optional -------------------------------------------------------------------------------- /vat_pytorch/models/extracted_roberta.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from transformers import AutoModelForSequenceClassification 3 | from .extracted_model import ExtractedModel 4 | 5 | class ExtractedRoBERTa(ExtractedModel): 6 | 7 | def __init__(self): 8 | super().__init__() 9 | model = AutoModelForSequenceClassification.from_pretrained('roberta-base') 10 | self.roberta = model.roberta 11 | self.layers = model.roberta.encoder.layer 12 | self.classifier = model.classifier 13 | self.attention_mask = None 14 | self.num_layers = len(self.layers) - 1 15 | 16 | def forward(self, hidden, with_hidden_states = False, start_layer = 0): 17 | """ Forwards the hidden value from self.start_layer layer to the logits. """ 18 | hidden_states = [hidden] 19 | 20 | for layer in self.layers[start_layer:]: 21 | hidden = layer(hidden, attention_mask = self.attention_mask)[0] 22 | hidden_states += [hidden] 23 | 24 | logits = self.classifier(hidden) 25 | 26 | return (logits, hidden_states) if with_hidden_states else logits 27 | 28 | def get_embeddings(self, input_ids): 29 | """ Computes first embedding layer given inputs_ids """ 30 | return self.roberta.embeddings(input_ids) 31 | 32 | def set_attention_mask(self, attention_mask): 33 | """ Sets the correct mask on all subsequent forward passes """ 34 | self.attention_mask = self.roberta.get_extended_attention_mask( 35 | attention_mask, 36 | input_shape = attention_mask.shape, 37 | device = attention_mask.device 38 | ) # (b, 1, 1, s) -------------------------------------------------------------------------------- /vat_pytorch/models/smart_classification_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from vat_pytorch import SMARTLoss, kl_loss, sym_kl_loss 4 | 5 | class SMARTClassificationModel(nn.Module): 6 | # b: batch_size, s: sequence_length, d: hidden_size , n: num_labels 7 | 8 | def __init__(self, extracted_model, weight = 1.0): 9 | super().__init__() 10 | self.model = extracted_model 11 | self.weight = weight 12 | self.vat_loss = SMARTLoss(model = extracted_model, loss_fn = kl_loss, loss_last_fn = sym_kl_loss) 13 | 14 | def forward(self, input_ids, attention_mask, labels): 15 | """ input_ids: (b, s), attention_mask: (b, s), labels: (b,) """ 16 | # Get input embeddings 17 | embeddings = self.model.get_embeddings(input_ids) 18 | # Set mask and compute logits 19 | self.model.set_attention_mask(attention_mask) 20 | logits = self.model(embeddings) 21 | # Compute CE loss 22 | ce_loss = F.cross_entropy(logits.view(-1, 2), labels.view(-1)) 23 | # Compute VAT loss 24 | vat_loss = self.vat_loss(embeddings, logits) 25 | # Merge losses 26 | loss = ce_loss + self.weight * vat_loss 27 | return logits, loss -------------------------------------------------------------------------------- /vat_pytorch/smart.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Callable 2 | 3 | import torch 4 | import torch.nn as nn 5 | import 
torch.nn.functional as F 6 | from torch import Tensor 7 | from itertools import count 8 | from .utils import default, inf_norm 9 | 10 | 11 | class SMARTLoss(nn.Module): 12 | 13 | def __init__( 14 | self, 15 | model: nn.Module, 16 | loss_fn: Callable, 17 | loss_last_fn: Callable = None, 18 | norm_fn: Callable = inf_norm, 19 | num_steps: int = 1, 20 | step_size: float = 1e-3, 21 | epsilon: float = 1e-6, 22 | noise_var: float = 1e-5 23 | ) -> None: 24 | super().__init__() 25 | self.model = model 26 | self.loss_fn = loss_fn 27 | self.loss_last_fn = default(loss_last_fn, loss_fn) 28 | self.norm_fn = norm_fn 29 | self.num_steps = num_steps 30 | self.step_size = step_size 31 | self.epsilon = epsilon 32 | self.noise_var = noise_var 33 | 34 | @torch.enable_grad() 35 | def forward(self, embed: Tensor, state: Tensor): 36 | noise = torch.randn_like(embed, requires_grad = True) * self.noise_var 37 | 38 | # Indefinite loop with counter 39 | for i in count(): 40 | # Compute perturbed embed and states 41 | embed_perturbed = embed + noise 42 | state_perturbed = self.model(embed_perturbed) 43 | # Return final loss if last step (undetached state) 44 | if i == self.num_steps: 45 | return self.loss_last_fn(state_perturbed, state) 46 | # Compute perturbation loss (detached state) 47 | loss = self.loss_fn(state_perturbed, state.detach()) 48 | # Compute noise gradient ∂loss/∂noise 49 | noise_gradient, = torch.autograd.grad(loss, noise) 50 | # Move noise towards gradient to change state as much as possible 51 | step = noise + self.step_size * noise_gradient 52 | # Normalize new noise step into norm induced ball 53 | step_norm = self.norm_fn(step) 54 | noise = step / (step_norm + self.epsilon) 55 | # Reset noise gradients for next step 56 | noise = noise.detach().requires_grad_() 57 | -------------------------------------------------------------------------------- /vat_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def exists(value): 6 | return value is not None 7 | 8 | 9 | def default(value, default): 10 | if exists(value): 11 | return value 12 | return default 13 | 14 | 15 | def inf_norm(x): 16 | return torch.norm(x, p=float("inf"), dim=-1, keepdim=True) 17 | 18 | 19 | def kl_loss(input, target, reduction="batchmean"): 20 | return F.kl_div( 21 | F.log_softmax(input, dim=-1), 22 | F.softmax(target, dim=-1), 23 | reduction=reduction, 24 | ) 25 | 26 | 27 | def sym_kl_loss(input, target, reduction="batchmean", alpha=1.0): 28 | return alpha * F.kl_div( 29 | F.log_softmax(input, dim=-1), 30 | F.softmax(target.detach(), dim=-1), 31 | reduction=reduction, 32 | ) + F.kl_div( 33 | F.log_softmax(target, dim=-1), 34 | F.softmax(input.detach(), dim=-1), 35 | reduction=reduction, 36 | ) 37 | 38 | 39 | def js_loss(input, target, reduction="batchmean", alpha=1.0): 40 | mean_proba = 0.5 * ( 41 | F.softmax(input.detach(), dim=-1) + F.softmax(target.detach(), dim=-1) 42 | ) 43 | return alpha * ( 44 | F.kl_div(F.log_softmax(input, dim=-1), mean_proba, reduction=reduction) 45 | + F.kl_div(F.log_softmax(target, dim=-1), mean_proba, reduction=reduction) 46 | ) 47 | --------------------------------------------------------------------------------
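A brief usage sketch for the helpers above (shapes are illustrative); the loss helpers expect raw, unnormalized logits and apply softmax internally:

```py
import torch
from vat_pytorch import kl_loss, sym_kl_loss, js_loss, inf_norm

logits = torch.randn(8, 2)  # (b, n) raw logits
logits_perturbed = torch.randn(8, 2)

kl = kl_loss(logits_perturbed, logits)        # KL divergence between the two softmax distributions (batchmean)
sym = sym_kl_loss(logits_perturbed, logits)   # symmetric variant, used as loss_last_fn in the SMART example
js = js_loss(logits_perturbed, logits)        # Jensen-Shannon variant
norm = inf_norm(torch.randn(8, 16, 768))      # (8, 16, 1): infinity norm over the last dimension
```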