├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── SMART.png
├── setup.py
└── smart_pytorch
    ├── __init__.py
    ├── loss.py
    └── smart_pytorch.py

/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

permissions:
  contents: read

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v3
    - name: Set up Python
      uses: actions/setup-python@v3
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 archinet.ai

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SMART - PyTorch

A PyTorch implementation of SMART, a regularization technique to fine-tune pretrained (language) models. You might also be interested in vat-pytorch, a more general collection of virtual adversarial training (VAT) methods in PyTorch.
## Install

```bash
$ pip install smart-pytorch
```

[![PyPI - Python Version](https://img.shields.io/pypi/v/smart-pytorch?style=flat&colorA=0f0f0f&colorB=0f0f0f)](https://pypi.org/project/smart-pytorch/)

## Usage

### Minimal Example

```py
import torch
import torch.nn as nn
from smart_pytorch import SMARTLoss

# Define function that will be perturbed (usually our network)
eval_fn = torch.nn.Linear(in_features=10, out_features=20)

# Define loss function between states
loss_fn = nn.MSELoss()

# Initialize regularization loss
regularizer = SMARTLoss(eval_fn = eval_fn, loss_fn = loss_fn)

# Compute initial input embed and output state
embed = torch.rand(1, 10) # [batch_size, in_features]
state = eval_fn(embed)    # [batch_size, out_features]

# Compute regularization loss
loss = regularizer(embed, state)
loss # tensor(0.0922578126, grad_fn=<MseLossBackward0>)
```

Here, `eval_fn` is a function (usually a neural network) that takes an embedding `embed` as input and produces one or more states `state` as output. Internally, this function is used to perturb the input `embed` with noise and obtain a perturbed `state`, which is then compared with the initially provided `state`.
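### Using the Regularizer with a Task Loss

During fine-tuning, the SMART term is typically added to the task loss with a small weight. The sketch below shows one way to wire this up; the optimizer, dummy data, and `weight` coefficient are illustrative assumptions, not part of the library API:

```py
import torch
import torch.nn as nn
from smart_pytorch import SMARTLoss

eval_fn = nn.Linear(in_features=10, out_features=20)
loss_fn = nn.MSELoss()
regularizer = SMARTLoss(eval_fn=eval_fn, loss_fn=loss_fn)

optimizer = torch.optim.Adam(eval_fn.parameters(), lr=1e-4)  # Hypothetical optimizer
weight = 0.02  # Hypothetical regularization strength

embed = torch.rand(8, 10)   # Dummy batch of input embeddings
target = torch.rand(8, 20)  # Dummy regression targets

state = eval_fn(embed)
loss = loss_fn(state, target)                     # Task loss
loss = loss + weight * regularizer(embed, state)  # Plus SMART regularization
loss.backward()
optimizer.step()
optimizer.zero_grad()
```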
### Full API Example

```python
import torch
import torch.nn as nn
from smart_pytorch import SMARTLoss

# Define function that will be perturbed (usually our network)
eval_fn = torch.nn.Linear(in_features=10, out_features=20)

# Define loss function between states
loss_fn = nn.MSELoss()

# Norm used to normalize the gradient
inf_norm = lambda x: torch.norm(x, p=float('inf'), dim=-1, keepdim=True)

# Initialize regularization loss
regularizer = SMARTLoss(
    eval_fn = eval_fn,
    loss_fn = loss_fn, # Loss to apply between perturbed and true state
    loss_last_fn = loss_fn, # Loss to apply between perturbed and true state on the last iteration (default = loss_fn)
    norm_fn = inf_norm, # Norm used to normalize the gradient (default = inf_norm)
    num_steps = 1, # Number of optimization steps to find noise (default = 1)
    step_size = 1e-3, # Step size to improve noise (default = 1e-3)
    epsilon = 1e-6, # Noise norm constraint (default = 1e-6)
    noise_var = 1e-5 # Initial noise variance (default = 1e-5)
)

# Compute initial input embed and output state
embed = torch.rand(1, 10) # [batch_size, in_features]
state = eval_fn(embed)    # [batch_size, out_features]

# Compute regularization loss
loss = regularizer(embed, state)
loss # tensor(0.0432184562, grad_fn=<MseLossBackward0>)
```

### RoBERTa Classification Example

This example demonstrates how to wrap a RoBERTa classifier from Hugging Face for use with SMART.

```py
import torch
import torch.nn as nn
import torch.nn.functional as F
from smart_pytorch import SMARTLoss, kl_loss, sym_kl_loss
from transformers import AutoTokenizer, AutoModelForSequenceClassification

class SMARTRobertaClassificationModel(nn.Module):

    def __init__(self, model, weight = 0.02):
        super().__init__()
        self.model = model
        self.weight = weight

    def forward(self, input_ids, attention_mask, labels):

        # Get initial embeddings
        embed = self.model.roberta.embeddings(input_ids)

        # Define eval function
        def eval(embed):
            outputs = self.model.roberta(inputs_embeds=embed, attention_mask=attention_mask)
            pooled = outputs[0]
            logits = self.model.classifier(pooled)
            return logits

        # Define SMART loss
        smart_loss_fn = SMARTLoss(eval_fn = eval, loss_fn = kl_loss, loss_last_fn = sym_kl_loss)
        # Compute initial (unperturbed) state
        state = eval(embed)
        # Apply classification loss
        loss = F.cross_entropy(state.view(-1, 2), labels.view(-1))
        # Apply SMART loss
        loss += self.weight * smart_loss_fn(embed, state)

        return state, loss


tokenizer = AutoTokenizer.from_pretrained('roberta-base')
model = AutoModelForSequenceClassification.from_pretrained('roberta-base')

model_smart = SMARTRobertaClassificationModel(model)

# Compute inputs
text = ["This text belongs to class 1...", "This text belongs to class 0..."]
inputs = tokenizer(text, return_tensors='pt')
labels = torch.tensor([1, 0])

# Compute output and loss
state, loss = model_smart(input_ids = inputs['input_ids'], attention_mask = inputs['attention_mask'], labels = labels)
print(state.shape, loss) # torch.Size([2, 2]) tensor(0.6980957389, grad_fn=<AddBackward0>)
```
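### Provided Loss Functions

Besides `kl_loss` and `sym_kl_loss` used above, the package also exports `js_loss` (Jensen-Shannon divergence). Each takes raw logits and applies softmax internally. A quick sketch on dummy logits (the shapes here are arbitrary assumptions):

```py
import torch
from smart_pytorch import kl_loss, sym_kl_loss, js_loss

logits_perturbed = torch.randn(2, 5, requires_grad=True)
logits_original = torch.randn(2, 5)

print(kl_loss(logits_perturbed, logits_original))      # KL divergence, 'batchmean' reduction
print(sym_kl_loss(logits_perturbed, logits_original))  # Symmetrized KL, 'sum' reduction
print(js_loss(logits_perturbed, logits_original))      # Jensen-Shannon, 'sum' reduction
```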
## Citations

```bibtex
@inproceedings{Jiang2020SMARTRA,
    title={SMART: Robust and Efficient Fine-Tuning for Pre-trained Natural Language Models through Principled Regularized Optimization},
    author={Haoming Jiang and Pengcheng He and Weizhu Chen and Xiaodong Liu and Jianfeng Gao and Tuo Zhao},
    booktitle={ACL},
    year={2020}
}
```
--------------------------------------------------------------------------------
/SMART.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/archinetai/smart-pytorch/e96d8630dc58e1dce8540f61f91016849925ebfe/SMART.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
    name = 'smart-pytorch',
    packages = find_packages(exclude=[]),
    version = '0.0.4',
    license='MIT',
    description = 'SMART Fine-Tuning - Pytorch',
    long_description_content_type = 'text/markdown',
    author = 'Flavio Schneider',
    author_email = 'archinetai@protonmail.com',
    url = 'https://github.com/archinetai/smart-pytorch',
    keywords = [
        'artificial intelligence',
        'deep learning',
        'fine-tuning',
        'pre-trained',
    ],
    install_requires=[
        'torch>=1.6',
        'data-science-types>=0.2'
    ],
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
--------------------------------------------------------------------------------
/smart_pytorch/__init__.py:
--------------------------------------------------------------------------------
from .smart_pytorch import SMARTLoss
from .loss import kl_loss, sym_kl_loss, js_loss
--------------------------------------------------------------------------------
/smart_pytorch/loss.py:
--------------------------------------------------------------------------------
import torch.nn.functional as F

def kl_loss(input, target, reduction='batchmean'):
    # KL divergence between the softmax distributions of two sets of logits
    return F.kl_div(
        F.log_softmax(input, dim=-1),
        F.softmax(target, dim=-1),
        reduction=reduction,
    )

def sym_kl_loss(input, target, reduction='sum', alpha=1.0):
    # Symmetrized KL: each direction detaches the opposite distribution
    return alpha * F.kl_div(
        F.log_softmax(input, dim=-1),
        F.softmax(target.detach(), dim=-1),
        reduction=reduction,
    ) + F.kl_div(
        F.log_softmax(target, dim=-1),
        F.softmax(input.detach(), dim=-1),
        reduction=reduction,
    )

def js_loss(input, target, reduction='sum', alpha=1.0):
    # Jensen-Shannon divergence: KL of each distribution to their (detached) mean
    mean_proba = 0.5 * (F.softmax(input.detach(), dim=-1) + F.softmax(target.detach(), dim=-1))
    return alpha * (F.kl_div(
        F.log_softmax(input, dim=-1),
        mean_proba,
        reduction=reduction
    ) + F.kl_div(
        F.log_softmax(target, dim=-1),
        mean_proba,
        reduction=reduction
    ))
--------------------------------------------------------------------------------
/smart_pytorch/smart_pytorch.py:
--------------------------------------------------------------------------------
from typing import Callable, Optional

import torch
import torch.nn as nn
from torch import Tensor
from itertools import count

def exists(val):
    return val is not None

def default(val, d):
    if exists(val):
        return val
    return d

def inf_norm(x):
    return torch.norm(x, p=float('inf'), dim=-1, keepdim=True)

class SMARTLoss(nn.Module):

    def __init__(
        self,
        eval_fn: Callable,
        loss_fn: Callable,
        loss_last_fn: Optional[Callable] = None,
        norm_fn: Callable = inf_norm,
        num_steps: int = 1,
        step_size: float = 1e-3,
        epsilon: float = 1e-6,
        noise_var: float = 1e-5
    ) -> None:
        super().__init__()
        self.eval_fn = eval_fn
        self.loss_fn = loss_fn
        self.loss_last_fn = default(loss_last_fn, loss_fn)
        self.norm_fn = norm_fn
        self.num_steps = num_steps
        self.step_size = step_size
        self.epsilon = epsilon
        self.noise_var = noise_var
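    # Summary of the inner loop below: starting from small Gaussian noise on the
    # input embedding, take `num_steps` gradient ascent steps on the noise to
    # maximize the divergence between perturbed and clean states, re-normalizing
    # the noise after each step, and return `loss_last_fn` evaluated at the
    # final perturbation.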
    def forward(self, embed: Tensor, state: Tensor) -> Tensor:
        noise = torch.randn_like(embed, requires_grad=True) * self.noise_var

        # Indefinite loop with counter
        for i in count():
            # Compute perturbed embed and states
            embed_perturbed = embed + noise
            state_perturbed = self.eval_fn(embed_perturbed)
            # Return final loss if last step (undetached state)
            if i == self.num_steps:
                return self.loss_last_fn(state_perturbed, state)
            # Compute perturbation loss (detached state)
            loss = self.loss_fn(state_perturbed, state.detach())
            # Compute noise gradient ∂loss/∂noise
            noise_gradient, = torch.autograd.grad(loss, noise)
            # Move noise towards gradient to change state as much as possible
            step = noise + self.step_size * noise_gradient
            # Normalize new noise step into norm induced ball
            step_norm = self.norm_fn(step)
            noise = step / (step_norm + self.epsilon)
            # Reset noise gradients for next step
            noise = noise.detach().requires_grad_()
--------------------------------------------------------------------------------