├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── SMART.png
├── setup.py
└── smart_pytorch
    ├── __init__.py
    ├── loss.py
    └── smart_pytorch.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | permissions:
16 | contents: read
17 |
18 | jobs:
19 | deploy:
20 |
21 | runs-on: ubuntu-latest
22 |
23 | steps:
24 | - uses: actions/checkout@v3
25 | - name: Set up Python
26 | uses: actions/setup-python@v3
27 | with:
28 | python-version: '3.x'
29 | - name: Install dependencies
30 | run: |
31 | python -m pip install --upgrade pip
32 | pip install build
33 | - name: Build package
34 | run: python -m build
35 | - name: Publish package
36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
37 | with:
38 | user: __token__
39 | password: ${{ secrets.PYPI_API_TOKEN }}
40 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 archinet.ai
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # SMART - PyTorch
6 |
7 | A PyTorch implementation of SMART, a regularization technique for fine-tuning pretrained (language) models. You might also be interested in vat-pytorch, a more general collection of virtual adversarial training (VAT) methods in PyTorch.
8 |
9 | ## Install
10 |
11 | ```bash
12 | $ pip install smart-pytorch
13 | ```
14 |
15 | [PyPI package](https://pypi.org/project/smart-pytorch/)
16 |
17 | ## Usage
18 |
19 | ### Minimal Example
20 |
21 | ```py
22 | import torch
23 | import torch.nn as nn
24 | from smart_pytorch import SMARTLoss
25 |
26 | # Define function that will be perturbed (usually our network)
27 | eval_fn = torch.nn.Linear(in_features=10, out_features=20)
28 |
29 | # Define loss function between states
30 | loss_fn = nn.MSELoss()
31 |
32 | # Initialize regularization loss
33 | regularizer = SMARTLoss(eval_fn = eval_fn, loss_fn = loss_fn)
34 |
35 | # Compute initial input embed and output state
36 | embed = torch.rand(1, 10) # [batch_size, in_features]
37 | state = eval_fn(embed) # [batch_size, out_features]
38 |
39 | # Compute regularization loss
40 | loss = regularizer(embed, state)
41 | loss # tensor(0.0922578126, grad_fn=<MseLossBackward0>)
42 | ```
43 |
44 | Here `eval_fn` is a function (usually a neural network) that takes an embedding `embed` as input and produces one or more output states `state`. Internally, `SMARTLoss` perturbs `embed` with noise, passes the perturbed embedding through `eval_fn` to obtain a perturbed state, and compares that state with the `state` provided initially.
45 |
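Any callable that maps an embedding to an output state can serve as `eval_fn`. As a minimal illustrative sketch (not from the original examples), the same regularizer can wrap a small multi-layer network:

```py
import torch
import torch.nn as nn
from smart_pytorch import SMARTLoss

# Hypothetical eval function: a small MLP instead of a single linear layer
mlp = nn.Sequential(
    nn.Linear(10, 32),
    nn.ReLU(),
    nn.Linear(32, 20),
)

regularizer = SMARTLoss(eval_fn = mlp, loss_fn = nn.MSELoss())

embed = torch.rand(8, 10)   # [batch_size, in_features]
state = mlp(embed)          # [batch_size, out_features]
loss = regularizer(embed, state)
```
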
46 | ### Full API Example
47 | ```python
48 | import torch
49 | import torch.nn as nn
50 | from smart_pytorch import SMARTLoss
51 |
52 | # Define function that will be perturbed (usually our network)
53 | eval_fn = torch.nn.Linear(in_features=10, out_features=20)
54 |
55 | # Define loss function between states
56 | loss_fn = nn.MSELoss()
57 |
58 | # Norm used to normalize the gradient
59 | inf_norm = lambda x: torch.norm(x, p=float('inf'), dim=-1, keepdim=True)
60 |
61 | # Initialize regularization loss
62 | regularizer = SMARTLoss(
63 | eval_fn = eval_fn,
64 | loss_fn = loss_fn, # Loss to apply between perturbed and true state
65 | loss_last_fn = loss_fn, # Loss to apply between perturbed and true state on the last iteration (default = loss_fn)
66 | norm_fn = inf_norm, # Norm used to normalize the gradient (default = inf_norm)
67 | num_steps = 1, # Number of optimization steps to find noise (default = 1)
68 | step_size = 1e-3, # Step size to improve noise (default = 1e-3)
69 | epsilon = 1e-6, # Small constant added to the noise norm for numerical stability (default = 1e-6)
70 | noise_var = 1e-5 # Initial noise variance (default = 1e-5)
71 | )
72 |
73 | # Compute initial input embed and output state
74 | embed = torch.rand(1, 10) # [batch_size, in_features]
75 | state = eval_fn(embed) # [batch_size, out_features]
76 |
77 | # Compute regularization loss
78 | loss = regularizer(embed, state)
79 | loss # tensor(0.0432184562, grad_fn=<MseLossBackward0>)
80 | ```
81 |
82 | ### RoBERTa Classification Example
83 |
84 | This example demonstrates how to wrap a RoBERTa classifier from Hugging Face for use with SMART.
85 |
86 | ```py
87 | import torch
import torch.nn as nn
import torch.nn.functional as F
from smart_pytorch import SMARTLoss, kl_loss, sym_kl_loss
88 | from transformers import AutoTokenizer, AutoModelForSequenceClassification
89 |
90 | class SMARTRobertaClassificationModel(nn.Module):
91 |
92 | def __init__(self, model, weight = 0.02):
93 | super().__init__()
94 | self.model = model
95 | self.weight = weight
96 |
97 | def forward(self, input_ids, attention_mask, labels):
98 |
99 | # Get initial embeddings
100 | embed = self.model.roberta.embeddings(input_ids)
101 |
102 | # Define eval function
103 | def eval(embed):
104 | outputs = self.model.roberta(inputs_embeds=embed, attention_mask=attention_mask)
105 | pooled = outputs[0]
106 | logits = self.model.classifier(pooled)
107 | return logits
108 |
109 | # Define SMART loss
110 | smart_loss_fn = SMARTLoss(eval_fn = eval, loss_fn = kl_loss, loss_last_fn = sym_kl_loss)
111 | # Compute initial (unperturbed) state
112 | state = eval(embed)
113 | # Apply classification loss
114 | loss = F.cross_entropy(state.view(-1, 2), labels.view(-1))
115 | # Apply smart loss
116 | loss += self.weight * smart_loss_fn(embed, state)
117 |
118 | return state, loss
119 |
120 |
121 | tokenizer = AutoTokenizer.from_pretrained('roberta-base')
122 | model = AutoModelForSequenceClassification.from_pretrained('roberta-base')
123 |
124 | model_smart = SMARTRobertaClassificationModel(model)
125 | # Compute inputs
126 | text = ["This text belongs to class 1...", "This text belongs to class 0..."]
127 | inputs = tokenizer(text, return_tensors='pt')
128 | labels = torch.tensor([1, 0])
129 |
130 | # Compute output and loss
131 | state, loss = model_smart(input_ids = inputs['input_ids'], attention_mask = inputs['attention_mask'], labels = labels)
132 | print(state.shape, loss) # torch.Size([2, 2]) tensor(0.6980957389, grad_fn=<AddBackward0>)
133 | ```
134 |
135 |
136 |
137 |
138 | ## Citations
139 |
140 | ```bibtex
141 | @inproceedings{Jiang2020SMARTRA,
142 | title={SMART: Robust and Efficient Fine-Tuning for Pre-trained Natural Language Models through Principled Regularized Optimization},
143 | author={Haoming Jiang and Pengcheng He and Weizhu Chen and Xiaodong Liu and Jianfeng Gao and Tuo Zhao},
144 | booktitle={ACL},
145 | year={2020}
146 | }
147 | ```
148 |
--------------------------------------------------------------------------------
/SMART.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/archinetai/smart-pytorch/e96d8630dc58e1dce8540f61f91016849925ebfe/SMART.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'smart-pytorch',
5 | packages = find_packages(exclude=[]),
6 | version = '0.0.4',
7 | license='MIT',
8 | description = 'SMART Fine-Tuning - Pytorch',
9 | long_description_content_type = 'text/markdown',
10 | author = 'Flavio Schneider',
11 | author_email = 'archinetai@protonmail.com',
12 | url = 'https://github.com/archinetai/smart-pytorch',
13 | keywords = [
14 | 'artificial intelligence',
15 | 'deep learning',
16 | 'fine-tuning',
17 | 'pre-trained',
18 | ],
19 | install_requires=[
20 | 'torch>=1.6',
21 | 'data-science-types>=0.2'
22 | ],
23 | classifiers=[
24 | 'Development Status :: 4 - Beta',
25 | 'Intended Audience :: Developers',
26 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
27 | 'License :: OSI Approved :: MIT License',
28 | 'Programming Language :: Python :: 3.6',
29 | ],
30 | )
31 |
--------------------------------------------------------------------------------
/smart_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from .smart_pytorch import SMARTLoss
2 | from .loss import kl_loss, sym_kl_loss, js_loss
--------------------------------------------------------------------------------
/smart_pytorch/loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 |
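# KL divergence KL(softmax(target) || softmax(input)); both arguments are logits.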
3 | def kl_loss(input, target, reduction='batchmean'):
4 | return F.kl_div(
5 | F.log_softmax(input, dim=-1),
6 | F.softmax(target, dim=-1),
7 | reduction=reduction,
8 | )
9 |
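# Symmetric KL divergence between input and target logits; each term detaches the
# opposing distribution so gradients flow through both arguments.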
10 | def sym_kl_loss(input, target, reduction='sum', alpha=1.0):
11 | return alpha * F.kl_div(
12 | F.log_softmax(input, dim=-1),
13 | F.softmax(target.detach(), dim=-1),
14 | reduction=reduction,
15 | ) + F.kl_div(
16 | F.log_softmax(target, dim=-1),
17 | F.softmax(input.detach(), dim=-1),
18 | reduction=reduction,
19 | )
20 |
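# Jensen-Shannon-style loss: sum of KL divergences between the detached mean
# distribution and each of the two softmaxed logit distributions.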
21 | def js_loss(input, target, reduction='sum', alpha=1.0):
22 | mean_proba = 0.5 * (F.softmax(input.detach(), dim=-1) + F.softmax(target.detach(), dim=-1))
23 | return alpha * (F.kl_div(
24 | F.log_softmax(input, dim=-1),
25 | mean_proba,
26 | reduction=reduction
27 | ) + F.kl_div(
28 | F.log_softmax(target, dim=-1),
29 | mean_proba,
30 | reduction=reduction
31 | ))
--------------------------------------------------------------------------------
/smart_pytorch/smart_pytorch.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Callable
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch import Tensor
7 | from itertools import count
8 |
9 | def exists(val):
10 | return val is not None
11 |
12 | def default(val, d):
13 | if exists(val):
14 | return val
15 | return d
16 |
17 | def inf_norm(x):
18 | return torch.norm(x, p=float('inf'), dim=-1, keepdim=True)
19 |
20 | class SMARTLoss(nn.Module):
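    """SMART regularization loss.

    Finds an adversarial perturbation of `embed` by taking `num_steps` gradient
    ascent steps on the noise (projected with `norm_fn`), then returns
    `loss_last_fn` between the perturbed output of `eval_fn` and the original `state`.
    """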
21 |
22 | def __init__(
23 | self,
24 | eval_fn: Callable,
25 | loss_fn: Callable,
26 | loss_last_fn: Callable = None,
27 | norm_fn: Callable = inf_norm,
28 | num_steps: int = 1,
29 | step_size: float = 1e-3,
30 | epsilon: float = 1e-6,
31 | noise_var: float = 1e-5
32 | ) -> None:
33 | super().__init__()
34 | self.eval_fn = eval_fn
35 | self.loss_fn = loss_fn
36 | self.loss_last_fn = default(loss_last_fn, loss_fn)
37 | self.norm_fn = norm_fn
38 | self.num_steps = num_steps
39 | self.step_size = step_size
40 | self.epsilon = epsilon
41 | self.noise_var = noise_var
42 |
43 | def forward(self, embed: Tensor, state: Tensor) -> Tensor:
44 | noise = torch.randn_like(embed, requires_grad=True) * self.noise_var
45 |
46 | # Indefinite loop with counter
47 | for i in count():
48 | # Compute perturbed embed and states
49 | embed_perturbed = embed + noise
50 | state_perturbed = self.eval_fn(embed_perturbed)
51 | # Return final loss if last step (undetached state)
52 | if i == self.num_steps:
53 | return self.loss_last_fn(state_perturbed, state)
54 | # Compute perturbation loss (detached state)
55 | loss = self.loss_fn(state_perturbed, state.detach())
56 | # Compute noise gradient ∂loss/∂noise
57 | noise_gradient, = torch.autograd.grad(loss, noise)
58 | # Move noise towards gradient to change state as much as possible
59 | step = noise + self.step_size * noise_gradient
60 | # Normalize new noise step into norm induced ball
61 | step_norm = self.norm_fn(step)
62 | noise = step / (step_norm + self.epsilon)
63 | # Reset noise gradients for next step
64 | noise = noise.detach().requires_grad_()
--------------------------------------------------------------------------------