├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── .vscode
│   └── settings.json
├── LICENSE
├── LOGO.png
├── README.md
├── setup.py
└── vat_pytorch
    ├── __init__.py
    ├── alice.py
    ├── alicepp.py
    ├── models
    │   ├── __init__.py
    │   ├── alice_classification_model.py
    │   ├── alicepp_classification_model.py
    │   ├── extracted_model.py
    │   ├── extracted_roberta.py
    │   └── smart_classification_model.py
    ├── smart.py
    └── utils.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | permissions:
16 | contents: read
17 |
18 | jobs:
19 | deploy:
20 |
21 | runs-on: ubuntu-latest
22 |
23 | steps:
24 | - uses: actions/checkout@v3
25 | - name: Set up Python
26 | uses: actions/setup-python@v3
27 | with:
28 | python-version: '3.x'
29 | - name: Install dependencies
30 | run: |
31 | python -m pip install --upgrade pip
32 | pip install build
33 | - name: Build package
34 | run: python -m build
35 | - name: Publish package
36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
37 | with:
38 | user: __token__
39 | password: ${{ secrets.PYPI_API_TOKEN }}
40 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.formatting.provider": "black"
3 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 archinet.ai
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/LOGO.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/archinetai/vat-pytorch/34df8043969c0d2ccb7e4189ca337d91b7270c2c/LOGO.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | A collection of VAT (Virtual Adversarial Training) methods, in PyTorch.
4 |
5 | ## Install
6 |
7 | ```bash
8 | $ pip install vat-pytorch
9 | ```
10 |
11 | [PyPI](https://pypi.org/project/vat-pytorch/)
12 |
13 |
14 | ## API
15 |
16 | ### SMART
17 | The SMART paper proposes to find the noise that, when added to the embedding layer, maximally perturbs the logits, and then to apply a loss that keeps those perturbed logits as close as possible to the predicted logits.
18 |
19 | ```py
20 | from vat_pytorch import SMARTLoss, inf_norm
21 |
22 | loss = SMARTLoss(
23 | model: nn.Module,
24 | loss_fn: Callable,
25 | loss_last_fn: Callable = None,
26 | norm_fn: Callable = inf_norm,
27 | num_steps: int = 1,
28 | step_size: float = 1e-3,
29 | epsilon: float = 1e-6,
30 | noise_var: float = 1e-5
31 | )
32 | ```
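
As a rough usage sketch (assuming `model` is an extracted model as in the Usage section below), the loss takes the input embeddings and the unperturbed logits and returns a scalar regularizer that is added to the task loss:

```py
embed = model.get_embeddings(input_ids)   # (b, s, d) input embeddings
state = model(embed)                      # (b, n) unperturbed logits
adv_loss = loss(embed, state)             # scalar SMART regularizer
```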
33 |
34 | ### ALICE
35 |
36 | The ALICE paper is analogous to the SMART paper, but adds an additional term to make sure that the perturbed logits are as close as possible to both the predicted logits *and* the ground truth labels.
37 |
38 | ```py
39 | from vat_pytorch import ALICELoss, inf_norm
40 |
41 | loss = ALICELoss(
42 | model: nn.Module,
43 | loss_fn: Callable,
44 | num_classes: int,
45 | loss_last_fn: Callable = None,
46 | gold_loss_fn: Callable = None,
47 | gold_loss_last_fn: Callable = None,
48 | norm_fn: Callable = inf_norm,
49 | alpha: float = 1,
50 | num_steps: int = 1,
51 | step_size: float = 1e-3,
52 | epsilon: float = 1e-6,
53 | noise_var: float = 1e-5,
54 | )
55 | ```
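
A minimal call sketch, mirroring `ALICELoss.forward` in `vat_pytorch/alice.py` further below: the loss now also takes the ground-truth labels and returns the label term plus `alpha` times the virtual term.

```py
# embed and state as in the SMART sketch above; labels: (b,) integer class ids
adv_loss = loss(embed, state, labels)
```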
56 |
57 | ### ALICE++
58 |
59 | The ALICE++ paper is analogous to the ALICE paper, but instead of adding noise to the embedding layer, at each iteration it picks a random layer of the network and adds the noise to that layer's hidden states.
60 |
61 | ```py
62 | from vat_pytorch import ALICEPPLoss, ALICEPPModule, inf_norm
63 |
64 | loss = ALICEPPLoss(
65 | model: ALICEPPModule,
66 | num_classes: int,
67 | loss_fn: Callable,
68 | num_layers: int,
69 | max_layer: int = None,
70 | loss_last_fn: Callable = None,
71 | gold_loss_fn: Callable = None,
72 | gold_loss_last_fn: Callable = None,
73 | norm_fn: Callable = inf_norm,
74 | alpha: float = 1,
75 | num_steps: int = 1,
76 | step_size: float = 1e-3,
77 | epsilon: float = 1e-6,
78 | noise_var: float = 1e-5,
79 | )
80 | ```
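
Because the perturbation is applied at a randomly chosen layer, `ALICEPPLoss` is called with the list of hidden states rather than with a single embedding tensor (see `ALICEPPLoss.forward` in `vat_pytorch/alicepp.py` below; the extracted model is assumed to return its hidden states, as `ExtractedRoBERTa` does in the Usage section):

```py
logits, hidden_states = model(embeddings, with_hidden_states = True)
adv_loss = loss(hidden_states, logits, labels)
```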
81 |
82 | ## Usage (Classification)
83 |
84 | ### Extract Model
85 | The first thing we have to do is extract the chunk of the model that we want to perturb adversarially. A generic example with Hugging Face's RoBERTa for sequence classification is given below.
86 |
87 | ```py
88 | import torch.nn as nn
89 | from transformers import AutoModelForSequenceClassification
90 |
91 | class ExtractedRoBERTa(nn.Module):
92 |
93 | def __init__(self):
94 | super().__init__()
95 | model = AutoModelForSequenceClassification.from_pretrained('roberta-base')
96 | self.roberta = model.roberta
97 | self.layers = model.roberta.encoder.layer
98 | self.classifier = model.classifier
99 | self.attention_mask = None
100 | self.num_layers = len(self.layers) - 1
101 |
102 | def forward(self, hidden, with_hidden_states = False, start_layer = 0):
103 |         """ Forwards the hidden value from the given start_layer to the logits. """
104 | hidden_states = [hidden]
105 |
106 | for layer in self.layers[start_layer:]:
107 | hidden = layer(hidden, attention_mask = self.attention_mask)[0]
108 | hidden_states += [hidden]
109 |
110 | logits = self.classifier(hidden)
111 |
112 | return (logits, hidden_states) if with_hidden_states else logits
113 |
114 | def get_embeddings(self, input_ids):
115 |         """ Computes the first embedding layer given input_ids """
116 | return self.roberta.embeddings(input_ids)
117 |
118 | def set_attention_mask(self, attention_mask):
119 | """ Sets the correct mask on all subsequent forward passes """
120 | self.attention_mask = self.roberta.get_extended_attention_mask(
121 | attention_mask,
122 | input_shape = attention_mask.shape,
123 | device = attention_mask.device
124 | ) # (b, 1, 1, s)
125 | ```
126 | The `set_attention_mask` function fixes the attention mask for all subsequent forward calls; this is necessary whenever we want to apply an attention mask together with any VAT loss. The `start_layer` parameter of the forward function is needed only when using `ALICEPPLoss`, since that loss must be able to change the start layer internally.
127 |
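A minimal sketch of the extracted model used on its own, assuming `input_ids` and `attention_mask` come from a tokenizer as in the Wrapped Model Usage section below:

```py
extracted_model = ExtractedRoBERTa()
extracted_model.set_attention_mask(attention_mask)       # (b, s) mask, stored for later forward calls
embeddings = extracted_model.get_embeddings(input_ids)   # (b, s, d)
logits, hidden_states = extracted_model(embeddings, with_hidden_states = True)
# Resuming from an intermediate hidden state reproduces the same logits
logits_again = extracted_model(hidden_states[3], start_layer = 3)
```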
128 |
129 | ### SMART
130 |
131 | ```py
132 | import torch.nn as nn
133 | import torch.nn.functional as F
134 | from vat_pytorch import SMARTLoss, kl_loss, sym_kl_loss
135 |
136 | class SMARTClassificationModel(nn.Module):
137 | # b: batch_size, s: sequence_length, d: hidden_size , n: num_labels
138 |
139 | def __init__(self, extracted_model, weight = 1.0):
140 | super().__init__()
141 | self.model = extracted_model
142 | self.weight = weight
143 | self.vat_loss = SMARTLoss(model = extracted_model, loss_fn = kl_loss, loss_last_fn = sym_kl_loss)
144 |
145 | def forward(self, input_ids, attention_mask, labels):
146 | """ input_ids: (b, s), attention_mask: (b, s), labels: (b,) """
147 | # Get input embeddings
148 | embeddings = self.model.get_embeddings(input_ids)
149 | # Set mask and compute logits
150 | self.model.set_attention_mask(attention_mask)
151 | logits = self.model(embeddings)
152 | # Compute CE loss
153 | ce_loss = F.cross_entropy(logits.view(-1, 2), labels.view(-1))
154 | # Compute VAT loss
155 | vat_loss = self.vat_loss(embeddings, logits)
156 | # Merge losses
157 | loss = ce_loss + self.weight * vat_loss
158 | return logits, loss
159 | ```
160 |
161 | ### ALICE
162 |
163 | ```py
164 | import torch.nn as nn
165 | import torch.nn.functional as F
166 | from vat_pytorch import ALICELoss, kl_loss
167 |
168 | class ALICEClassificationModel(nn.Module):
169 | # b: batch_size, s: sequence_length, d: hidden_size , n: num_labels
170 |
171 | def __init__(self, extracted_model):
172 | super().__init__()
173 | self.model = extracted_model
174 | self.vat_loss = ALICELoss(model = extracted_model, loss_fn = kl_loss, num_classes = 2)
175 |
176 | def forward(self, input_ids, attention_mask, labels):
177 | """ input_ids: (b, s), attention_mask: (b, s), labels: (b,) """
178 | # Get input embeddings
179 | embeddings = self.model.get_embeddings(input_ids)
180 | # Set iteration specific data (e.g. attention mask)
181 | self.model.set_attention_mask(attention_mask)
182 | # Compute logits
183 | logits = self.model(embeddings)
184 | # Compute VAT loss
185 | loss = self.vat_loss(embeddings, logits, labels)
186 | return logits, loss
187 | ```
188 |
189 | ### ALICE++
190 |
191 | ```py
192 | import torch.nn as nn
193 | import torch.nn.functional as F
194 | from vat_pytorch import ALICEPPLoss, kl_loss
195 |
196 | class ALICEPPClassificationModel(nn.Module):
197 | # b: batch_size, s: sequence_length, d: hidden_size , n: num_labels
198 |
199 | def __init__(self, extracted_model):
200 | super().__init__()
201 | self.model = extracted_model
202 | self.vat_loss = ALICEPPLoss(
203 | model = extracted_model,
204 | loss_fn = kl_loss,
205 | num_layers = self.model.num_layers,
206 | num_classes = 2
207 | )
208 |
209 | def forward(self, input_ids, attention_mask, labels):
210 | """ input_ids: (b, s), attention_mask: (b, s), labels: (b,) """
211 | # Get input embeddings
212 | embeddings = self.model.get_embeddings(input_ids)
213 | # Set iteration specific data (e.g. attention mask)
214 | self.model.set_attention_mask(attention_mask)
215 | # Compute logits
216 | logits, hidden_states = self.model(embeddings, with_hidden_states = True)
217 | # Compute VAT loss
218 | loss = self.vat_loss(hidden_states, logits, labels)
219 | return logits, loss
220 | ```
221 |
222 | Note that `extracted_model` must provide a forward method with the signature `forward(self, hidden: Tensor, *, start_layer: int) -> Tensor`. The interface `ALICEPPModule` (`from vat_pytorch import ALICEPPModule`) can be used in place of `nn.Module` as the base class of the extracted model to make sure that this method is present, as sketched below.
223 |
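A minimal skeleton of that interface (the class name is illustrative; `ExtractedRoBERTa` above already satisfies it):

```py
from torch import Tensor
from vat_pytorch import ALICEPPModule

class MyExtractedModel(ALICEPPModule):

    def forward(self, hidden: Tensor, *, start_layer: int = 0) -> Tensor:
        # Forward `hidden` through the remaining layers, starting at `start_layer`, and return the logits
        raise NotImplementedError
```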
224 |
225 | ### Wrapped Model Usage
226 | Any of the above losses can be used as follows with the extracted model.
227 | ```py
228 | import torch
229 | from transformers import AutoTokenizer
230 |
231 | extracted_model = ExtractedRoBERTa()
232 | tokenizer = AutoTokenizer.from_pretrained('roberta-base')
233 | # Pick one:
234 | model = SMARTClassificationModel(extracted_model)
235 | model = ALICEClassificationModel(extracted_model)
236 | model = ALICEPPClassificationModel(extracted_model)
237 | # Compute inputs
238 | text = ["This text belongs to class 1...", "This text belongs to class 0..."]
239 | inputs = tokenizer(text, return_tensors='pt')
240 | labels = torch.tensor([1, 0])
241 | # Compute logits and loss
242 | logits, loss = model(input_ids = inputs['input_ids'], attention_mask = inputs['attention_mask'], labels = labels)
243 | # To fine-tune, repeat this over many training steps (see the sketch after this block)
244 | loss.backward()
245 | ```
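
A hedged sketch of such a training loop (the optimizer, hyperparameters, and `train_dataset` are assumptions, not part of this library):

```py
from torch.optim import AdamW
from torch.utils.data import DataLoader

optimizer = AdamW(model.parameters(), lr = 2e-5)
loader = DataLoader(train_dataset, batch_size = 16, shuffle = True)  # yields (texts, labels) batches

for texts, labels in loader:
    inputs = tokenizer(list(texts), return_tensors = 'pt', padding = True, truncation = True)
    logits, loss = model(input_ids = inputs['input_ids'], attention_mask = inputs['attention_mask'], labels = labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```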
246 |
247 | ## Citations
248 |
249 | ```bibtex
250 | @inproceedings{Jiang2020SMARTRA,
251 | title={SMART: Robust and Efficient Fine-Tuning for Pre-trained Natural Language Models through Principled Regularized Optimization},
252 | author={Haoming Jiang and Pengcheng He and Weizhu Chen and Xiaodong Liu and Jianfeng Gao and Tuo Zhao},
253 | booktitle={ACL},
254 | year={2020}
255 | }
256 | ```
257 |
258 | ```bibtex
259 | @article{Pereira2020AdversarialTF,
260 | title={Adversarial Training for Commonsense Inference},
261 | author={Lis Kanashiro Pereira and Xiaodong Liu and Fei Cheng and Masayuki Asahara and Ichiro Kobayashi},
262 | journal={ArXiv},
263 | year={2020},
264 | volume={abs/2005.08156}
265 | }
266 | ```
267 |
268 | ```bibtex
269 | @inproceedings{Pereira2021ALICEAT,
270 | title={ALICE++: Adversarial Training for Robust and Effective Temporal Reasoning},
271 | author={Lis Kanashiro Pereira and Fei Cheng and Masayuki Asahara and Ichiro Kobayashi},
272 | booktitle={PACLIC},
273 | year={2021}
274 | }
275 | ```
276 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name="vat-pytorch",
5 | packages=find_packages(exclude=[]),
6 | version="0.0.9",
7 | license="MIT",
8 | description="Virtual Adversarial Training - Pytorch",
9 | long_description_content_type="text/markdown",
10 | author="Archinet",
11 | author_email="archinetai@protonmail.com",
12 | url="https://github.com/archinetai/vat-pytorch",
13 | keywords=[
14 | "artificial intelligence",
15 | "deep learning",
16 | "fine-tuning",
17 | "pre-trained",
18 | ],
19 | install_requires=[
20 | "torch>=1.6",
21 |         "data-science-types>=0.2",
22 |         "transformers>=4.0.0",
23 | ],
24 | classifiers=[
25 | "Development Status :: 4 - Beta",
26 | "Intended Audience :: Developers",
27 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
28 | "License :: OSI Approved :: MIT License",
29 | "Programming Language :: Python :: 3.6",
30 | ],
31 | )
32 |
--------------------------------------------------------------------------------
/vat_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from .smart import SMARTLoss
2 | from .alice import ALICELoss
3 | from .alicepp import ALICEPPLoss, ALICEPPModule
4 | from .utils import kl_loss, sym_kl_loss, js_loss, inf_norm, default
5 | from .models import *
--------------------------------------------------------------------------------
/vat_pytorch/alice.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Callable
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch import Tensor
7 | from itertools import count
8 | from .utils import default, inf_norm
9 |
10 |
11 | class ALICELoss(nn.Module):
12 | def __init__(
13 | self,
14 | model: nn.Module,
15 | loss_fn: Callable,
16 | num_classes: int,
17 | loss_last_fn: Callable = None,
18 | gold_loss_fn: Callable = None,
19 | gold_loss_last_fn: Callable = None,
20 | norm_fn: Callable = inf_norm,
21 | alpha: float = 1,
22 | num_steps: int = 1,
23 | step_size: float = 1e-3,
24 | epsilon: float = 1e-6,
25 | noise_var: float = 1e-5,
26 | ) -> None:
27 | super().__init__()
28 | self.model = model
29 | self.num_classes = num_classes
30 | self.loss_fn = loss_fn
31 | self.loss_last_fn = default(loss_last_fn, loss_fn)
32 | self.gold_loss_fn = default(gold_loss_fn, loss_fn)
33 | self.gold_loss_last_fn = default(
34 | default(gold_loss_last_fn, self.gold_loss_fn), self.loss_last_fn
35 | )
36 | self.norm_fn = norm_fn
37 | self.alpha = alpha
38 | self.num_steps = num_steps
39 | self.step_size = step_size
40 | self.epsilon = epsilon
41 | self.noise_var = noise_var
42 |
43 | def forward(self, embed: Tensor, state: Tensor, labels: Tensor) -> Tensor:
44 |
45 | virtual_loss = self.get_perturbed_loss(
46 | embed, state, loss_fn=self.loss_fn, loss_last_fn=self.loss_last_fn
47 | )
48 |
49 | labels_loss = self.get_perturbed_loss(
50 | embed,
51 | state=F.one_hot(labels, num_classes=self.num_classes).float(),
52 | loss_fn=self.gold_loss_fn,
53 | loss_last_fn=self.gold_loss_last_fn,
54 | )
55 |
56 | return labels_loss + self.alpha * virtual_loss
57 |
58 | @torch.enable_grad()
59 | def get_perturbed_loss(
60 | self, embed: Tensor, state: Tensor, loss_fn: Callable, loss_last_fn: Callable
61 | ):
62 | noise = torch.randn_like(embed, requires_grad=True) * self.noise_var
63 |
64 | # Indefinite loop with counter
65 | for i in count():
66 | # Compute perturbed embed and states
67 | embed_perturbed = embed + noise
68 | state_perturbed = self.model(embed_perturbed)
69 | # Return final loss if last step (undetached state)
70 | if i == self.num_steps:
71 | return loss_last_fn(state_perturbed, state)
72 | # Compute perturbation loss (detached state)
73 | loss = loss_fn(state_perturbed, state.detach())
74 | # Compute noise gradient ∂loss/∂noise
75 | (noise_gradient,) = torch.autograd.grad(loss, noise)
76 | # Move noise towards gradient to change state as much as possible
77 | step = noise + self.step_size * noise_gradient
78 | # Normalize new noise step into norm induced ball
79 | step_norm = self.norm_fn(step)
80 | noise = step / (step_norm + self.epsilon)
81 | # Reset noise gradients for next step
82 | noise = noise.detach().requires_grad_()
83 |
--------------------------------------------------------------------------------
/vat_pytorch/alicepp.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union, Callable
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch import Tensor
7 | from itertools import count
8 | from .utils import default, inf_norm
9 |
10 |
11 | class ALICEPPModule(nn.Module):
12 | """ Interface for the model provided to ALICEPPLoss """
13 |
14 | def __init__(self):
15 | super().__init__()
16 |
17 | def forward(self, hidden: Tensor, *, start_layer: int) -> Tensor:
18 | raise NotImplementedError()
19 |
20 |
21 | class ALICEPPLoss(nn.Module):
22 |
23 | def __init__(
24 | self,
25 | model: ALICEPPModule,
26 | num_classes: int,
27 | loss_fn: Callable,
28 | num_layers: int,
29 | max_layer: int = None,
30 | loss_last_fn: Callable = None,
31 | gold_loss_fn: Callable = None,
32 | gold_loss_last_fn: Callable = None,
33 | norm_fn: Callable = inf_norm,
34 | alpha: float = 1,
35 | num_steps: int = 1,
36 | step_size: float = 1e-3,
37 | epsilon: float = 1e-6,
38 | noise_var: float = 1e-5,
39 | ) -> None:
40 | super().__init__()
41 | self.model = model
42 | self.num_classes = num_classes
43 | self.loss_fn = loss_fn
44 | self.num_layers = num_layers
45 | self.max_layer = min(default(max_layer, num_layers), num_layers)
46 | self.loss_last_fn = default(loss_last_fn, loss_fn)
47 | self.gold_loss_fn = default(gold_loss_fn, loss_fn)
48 | self.gold_loss_last_fn = default(default(gold_loss_last_fn, self.gold_loss_fn), self.loss_last_fn)
49 | self.norm_fn = norm_fn
50 | self.alpha = alpha
51 | self.num_steps = num_steps
52 | self.step_size = step_size
53 | self.epsilon = epsilon
54 | self.noise_var = noise_var
55 |
56 | def forward(self, hiddens: List[Tensor], state: Tensor, labels: Tensor) -> Tensor:
57 |
58 | # Pick random layer on which we apply the perturbation
59 | random_layer_id = torch.randint(low = 0, high = self.max_layer, size = (1,))[0]
60 |
61 | virtual_loss = self.get_perturbed_loss(
62 | hidden = hiddens[random_layer_id],
63 | state = state,
64 | layer_id = random_layer_id,
65 | loss_fn = self.loss_fn,
66 | loss_last_fn = self.loss_last_fn
67 | )
68 |
69 | label_loss = self.get_perturbed_loss(
70 | hidden = hiddens[random_layer_id],
71 | state = F.one_hot(labels, num_classes=self.num_classes).float(),
72 | layer_id = random_layer_id,
73 | loss_fn = self.gold_loss_fn,
74 | loss_last_fn = self.gold_loss_last_fn
75 | )
76 |
77 | return label_loss + self.alpha * virtual_loss
78 |
79 | @torch.enable_grad()
80 | def get_perturbed_loss(
81 | self,
82 | hidden: Tensor,
83 | state: Tensor,
84 | layer_id: int,
85 | loss_fn: Callable,
86 | loss_last_fn: Callable
87 | ):
88 | noise = torch.randn_like(hidden, requires_grad = True) * self.noise_var
89 |
90 | # Indefinite loop with counter
91 | for i in count():
92 | # Compute perturbed hidden and states
93 | hidden_perturbed = hidden + noise
94 | state_perturbed = self.model(hidden_perturbed, start_layer = layer_id)
95 | # Return final loss if last step (undetached state)
96 | if i == self.num_steps:
97 | return loss_last_fn(state_perturbed, state)
98 | # Compute perturbation loss (detached state)
99 | loss = loss_fn(state_perturbed, state.detach())
100 | # Compute noise gradient ∂loss/∂noise
101 | noise_gradient, = torch.autograd.grad(loss, noise)
102 | # Move noise towards gradient to change state as much as possible
103 | step = noise + self.step_size * noise_gradient
104 | # Normalize new noise step into norm induced ball
105 | step_norm = self.norm_fn(step)
106 | noise = step / (step_norm + self.epsilon)
107 | # Reset noise gradients for next step
108 | noise = noise.detach().requires_grad_()
109 |
--------------------------------------------------------------------------------
/vat_pytorch/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .extracted_model import ExtractedModel
2 | from .extracted_roberta import ExtractedRoBERTa
3 | from .smart_classification_model import SMARTClassificationModel
4 | from .alice_classification_model import ALICEClassificationModel
5 | from .alicepp_classification_model import ALICEPPClassificationModel
--------------------------------------------------------------------------------
/vat_pytorch/models/alice_classification_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from vat_pytorch import ALICELoss, kl_loss
4 |
5 | class ALICEClassificationModel(nn.Module):
6 | # b: batch_size, s: sequence_length, d: hidden_size , n: num_labels
7 |
8 | def __init__(self, extracted_model):
9 | super().__init__()
10 | self.model = extracted_model
11 | self.vat_loss = ALICELoss(model = extracted_model, loss_fn = kl_loss, num_classes = 2)
12 |
13 | def forward(self, input_ids, attention_mask, labels):
14 | """ input_ids: (b, s), attention_mask: (b, s), labels: (b,) """
15 | # Get input embeddings
16 | embeddings = self.model.get_embeddings(input_ids)
17 | # Set iteration specific data (e.g. attention mask)
18 | self.model.set_attention_mask(attention_mask)
19 | # Compute logits
20 | logits = self.model(embeddings)
21 | # Compute VAT loss
22 | loss = self.vat_loss(embeddings, logits, labels)
23 | return logits, loss
--------------------------------------------------------------------------------
/vat_pytorch/models/alicepp_classification_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from vat_pytorch import ALICEPPLoss, kl_loss
4 |
5 | class ALICEPPClassificationModel(nn.Module):
6 | # b: batch_size, s: sequence_length, d: hidden_size , n: num_labels
7 |
8 | def __init__(self, extracted_model):
9 | super().__init__()
10 | self.model = extracted_model
11 | self.vat_loss = ALICEPPLoss(
12 | model = extracted_model,
13 | loss_fn = kl_loss,
14 | num_layers = self.model.num_layers,
15 | num_classes = 2
16 | )
17 |
18 | def forward(self, input_ids, attention_mask, labels):
19 | """ input_ids: (b, s), attention_mask: (b, s), labels: (b,) """
20 | # Get input embeddings
21 | embeddings = self.model.get_embeddings(input_ids)
22 | # Set iteration specific data (e.g. attention mask)
23 | self.model.set_attention_mask(attention_mask)
24 | # Compute logits
25 | logits, hidden_states = self.model(embeddings, with_hidden_states = True)
26 | # Compute VAT loss
27 | loss = self.vat_loss(hidden_states, logits, labels)
28 | return logits, loss
--------------------------------------------------------------------------------
/vat_pytorch/models/extracted_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | class ExtractedModel(nn.Module):
4 | """ Interface to be used with extracted models """
5 |
6 | def forward(self, hidden, *, start_layer = 0):
7 |         """ Forwards the hidden value from the given start_layer to the logits. """
8 | raise NotImplementedError()
9 |
10 | def get_embeddings(self, input_ids):
11 |         """ Computes the first embedding layer given input_ids """
12 | raise NotImplementedError()
13 |
14 | def set_attention_mask(self, attention_mask):
15 | """ Sets the correct mask on all subsequent forward passes """
16 | # This is optional
--------------------------------------------------------------------------------
/vat_pytorch/models/extracted_roberta.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from transformers import AutoModelForSequenceClassification
3 | from .extracted_model import ExtractedModel
4 |
5 | class ExtractedRoBERTa(ExtractedModel):
6 |
7 | def __init__(self):
8 | super().__init__()
9 | model = AutoModelForSequenceClassification.from_pretrained('roberta-base')
10 | self.roberta = model.roberta
11 | self.layers = model.roberta.encoder.layer
12 | self.classifier = model.classifier
13 | self.attention_mask = None
14 | self.num_layers = len(self.layers) - 1
15 |
16 | def forward(self, hidden, with_hidden_states = False, start_layer = 0):
17 |         """ Forwards the hidden value from the given start_layer to the logits. """
18 | hidden_states = [hidden]
19 |
20 | for layer in self.layers[start_layer:]:
21 | hidden = layer(hidden, attention_mask = self.attention_mask)[0]
22 | hidden_states += [hidden]
23 |
24 | logits = self.classifier(hidden)
25 |
26 | return (logits, hidden_states) if with_hidden_states else logits
27 |
28 | def get_embeddings(self, input_ids):
29 |         """ Computes the first embedding layer given input_ids """
30 | return self.roberta.embeddings(input_ids)
31 |
32 | def set_attention_mask(self, attention_mask):
33 | """ Sets the correct mask on all subsequent forward passes """
34 | self.attention_mask = self.roberta.get_extended_attention_mask(
35 | attention_mask,
36 | input_shape = attention_mask.shape,
37 | device = attention_mask.device
38 | ) # (b, 1, 1, s)
--------------------------------------------------------------------------------
/vat_pytorch/models/smart_classification_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from vat_pytorch import SMARTLoss, kl_loss, sym_kl_loss
4 |
5 | class SMARTClassificationModel(nn.Module):
6 | # b: batch_size, s: sequence_length, d: hidden_size , n: num_labels
7 |
8 | def __init__(self, extracted_model, weight = 1.0):
9 | super().__init__()
10 | self.model = extracted_model
11 | self.weight = weight
12 | self.vat_loss = SMARTLoss(model = extracted_model, loss_fn = kl_loss, loss_last_fn = sym_kl_loss)
13 |
14 | def forward(self, input_ids, attention_mask, labels):
15 | """ input_ids: (b, s), attention_mask: (b, s), labels: (b,) """
16 | # Get input embeddings
17 | embeddings = self.model.get_embeddings(input_ids)
18 | # Set mask and compute logits
19 | self.model.set_attention_mask(attention_mask)
20 | logits = self.model(embeddings)
21 | # Compute CE loss
22 | ce_loss = F.cross_entropy(logits.view(-1, 2), labels.view(-1))
23 | # Compute VAT loss
24 | vat_loss = self.vat_loss(embeddings, logits)
25 | # Merge losses
26 | loss = ce_loss + self.weight * vat_loss
27 | return logits, loss
--------------------------------------------------------------------------------
/vat_pytorch/smart.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Callable
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch import Tensor
7 | from itertools import count
8 | from .utils import default, inf_norm
9 |
10 |
11 | class SMARTLoss(nn.Module):
12 |
13 | def __init__(
14 | self,
15 | model: nn.Module,
16 | loss_fn: Callable,
17 | loss_last_fn: Callable = None,
18 | norm_fn: Callable = inf_norm,
19 | num_steps: int = 1,
20 | step_size: float = 1e-3,
21 | epsilon: float = 1e-6,
22 | noise_var: float = 1e-5
23 | ) -> None:
24 | super().__init__()
25 | self.model = model
26 | self.loss_fn = loss_fn
27 | self.loss_last_fn = default(loss_last_fn, loss_fn)
28 | self.norm_fn = norm_fn
29 | self.num_steps = num_steps
30 | self.step_size = step_size
31 | self.epsilon = epsilon
32 | self.noise_var = noise_var
33 |
34 | @torch.enable_grad()
35 | def forward(self, embed: Tensor, state: Tensor):
36 | noise = torch.randn_like(embed, requires_grad = True) * self.noise_var
37 |
38 | # Indefinite loop with counter
39 | for i in count():
40 | # Compute perturbed embed and states
41 | embed_perturbed = embed + noise
42 | state_perturbed = self.model(embed_perturbed)
43 | # Return final loss if last step (undetached state)
44 | if i == self.num_steps:
45 | return self.loss_last_fn(state_perturbed, state)
46 | # Compute perturbation loss (detached state)
47 | loss = self.loss_fn(state_perturbed, state.detach())
48 | # Compute noise gradient ∂loss/∂noise
49 | noise_gradient, = torch.autograd.grad(loss, noise)
50 | # Move noise towards gradient to change state as much as possible
51 | step = noise + self.step_size * noise_gradient
52 | # Normalize new noise step into norm induced ball
53 | step_norm = self.norm_fn(step)
54 | noise = step / (step_norm + self.epsilon)
55 | # Reset noise gradients for next step
56 | noise = noise.detach().requires_grad_()
57 |
--------------------------------------------------------------------------------
/vat_pytorch/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def exists(value):
6 | return value is not None
7 |
8 |
9 | def default(value, default):
10 | if exists(value):
11 | return value
12 | return default
13 |
14 |
15 | def inf_norm(x):
16 | return torch.norm(x, p=float("inf"), dim=-1, keepdim=True)
17 |
18 |
19 | def kl_loss(input, target, reduction="batchmean"):
20 | return F.kl_div(
21 | F.log_softmax(input, dim=-1),
22 | F.softmax(target, dim=-1),
23 | reduction=reduction,
24 | )
25 |
26 |
27 | def sym_kl_loss(input, target, reduction="batchmean", alpha=1.0):
28 | return alpha * F.kl_div(
29 | F.log_softmax(input, dim=-1),
30 | F.softmax(target.detach(), dim=-1),
31 | reduction=reduction,
32 | ) + F.kl_div(
33 | F.log_softmax(target, dim=-1),
34 | F.softmax(input.detach(), dim=-1),
35 | reduction=reduction,
36 | )
37 |
38 |
39 | def js_loss(input, target, reduction="batchmean", alpha=1.0):
40 | mean_proba = 0.5 * (
41 | F.softmax(input.detach(), dim=-1) + F.softmax(target.detach(), dim=-1)
42 | )
43 | return alpha * (
44 | F.kl_div(F.log_softmax(input, dim=-1), mean_proba, reduction=reduction)
45 | + F.kl_div(F.log_softmax(target, dim=-1), mean_proba, reduction=reduction)
46 | )
47 |
--------------------------------------------------------------------------------