├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── examples └── mnist.py ├── figures └── margins.png ├── jacobian ├── __init__.py └── jacobian.py └── setup.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to jacobian_regularizer 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to jacobian_regularizer, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Implementation of Jacobian Regularization 2 | 3 | This library provides a PyTorch implementation of the Jacobian Regularization described in the paper "Robust Learning with Jacobian Regularization" 4 | [arxiv:1908.02729](https://arxiv.org/abs/1908.02729). 5 | 6 | Jacobian regularization is a model-agnostic way of increasing classification margins, improving robustness to white and adversarial noise without severely hurting clean model performance. The implementation here also automatically supports GPU acceleration. 7 | 8 | For additional information, please see [1]. 9 | 10 | 11 | 12 |
13 | ![Classification margins for different regularizers](./figures/margins.png) 14 |
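Concretely, the regularizer implemented in [`jacobian/jacobian.py`](./jacobian/jacobian.py) is half the batch-averaged squared Frobenius norm of the network's input-output Jacobian, estimated with random projections. As a rough sketch in our own notation (where $f$ is the network, $B$ the batch size, $C$ the output dimension, and $\hat{v}_i$ random unit vectors; these symbols are ours, not necessarily the paper's),

$$
R_{\mathrm{JR}} = \frac{1}{2B}\sum_{b=1}^{B}\left\|\frac{\partial f(x_b)}{\partial x_b}\right\|_F^2 \;\approx\; \frac{1}{2B}\,\frac{C}{n_{\mathrm{proj}}}\sum_{b=1}^{B}\sum_{i=1}^{n_{\mathrm{proj}}}\left\|\hat{v}_i^{\top}\frac{\partial f(x_b)}{\partial x_b}\right\|^2 .
$$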
15 | 16 | --- 17 | 18 | ## Installation 19 | ``` 20 | pip install git+https://github.com/facebookresearch/jacobian_regularizer 21 | ``` 22 | 23 | ## Usage 24 | This library provides a simple subclass of `torch.nn.Module` that implements Jacobian regularization. After installation, first import the regularization loss 25 | ```python 26 | from jacobian import JacobianReg 27 | import torch.nn as nn 28 | ``` 29 | where we have also imported `torch.nn` so that we can include a standard supervised classification loss. 30 | 31 | To use Jacobian regularization, we initialize the Jacobian regularization at the same time we initialize our loss criterion 32 | ```python 33 | criterion = nn.CrossEntropyLoss() # supervised classification loss 34 | reg = JacobianReg() # Jacobian regularization 35 | lambda_JR = 0.01 # hyperparameter 36 | ``` 37 | where we have also included a hyperparameter `lambda_JR` controlling the relative strength of the regularization. 38 | 39 | Let's assume we also have a model `model`, a data loader `loader`, an optimizer `optimizer`, and a `device`, which is either `torch.device("cpu")` for CPU training or `torch.device("cuda:0")` for GPU training. Then, to use Jacobian regularization, our training loop might look like this 40 | ```python 41 | for idx, (data, target) in enumerate(loader): 42 | 43 | data, target = data.to(device), target.to(device) 44 | data.requires_grad = True # this is essential! 45 | 46 | optimizer.zero_grad() 47 | 48 | output = model(data) # forward pass 49 | 50 | loss_super = criterion(output, target) # supervised loss 51 | R = reg(data, output) # Jacobian regularization 52 | loss = loss_super + lambda_JR*R # full loss 53 | 54 | loss.backward() # computes gradients 55 | 56 | optimizer.step() 57 | ``` 58 | Backpropagation of the full loss occurs in the call `loss.backward()` so long as `data.requires_grad = True` was set at the top of the training loop. **Note:** this is important any time the Jacobian regularization is evaluated, whether during model training or model evaluation. (Even just computing the Jacobian loss requires gradients!) 59 | 60 | As implied, this Jacobian regularization is compatible with both CPU and GPU training, may be combined with other losses and regularizations, and will work with any model, optimizer, or dataset. 61 | 62 | ### Keyword Arguments 63 | - n (int, optional): determines the number of random projections. If n=-1, then it is set to the dimension of the output space and the projection is non-random and orthonormal, yielding the exact result. For any reasonable batch size, the default (n=1) should be sufficient. 64 | ```python 65 | reg = JacobianReg() # default has 1 projection 66 | 67 | # you can also specify the number of projections 68 | # this should be much less than the number of classes 69 | n_proj = 3 70 | reg_proj = JacobianReg(n=n_proj) 71 | 72 | # alternatively, you can get the full Jacobian 73 | # which takes C times as long as n_proj=1, where C is the number of classes 74 | reg_full = JacobianReg(n=-1) 75 | ``` 76 | 77 | ## Examples 78 | An example script that uses Jacobian regularization for simple MLP training on MNIST is given in the [`examples`](./examples) directory in the file [`mnist.py`](./examples/mnist.py). If you execute the script after installing this package 79 | ``` 80 | python mnist.py 81 | ``` 82 | you should start to see output like this 83 | ``` 84 | Training epoch 1.
85 | [1, 100] supervised loss: 0.687, Jacobian loss: 3.383 86 | [1, 200] supervised loss: 0.373, Jacobian loss: 2.128 87 | [1, 300] supervised loss: 0.317, Jacobian loss: 1.769 88 | [1, 400] supervised loss: 0.287, Jacobian loss: 1.553 89 | [1, 500] supervised loss: 0.276, Jacobian loss: 1.459 90 | ``` 91 | showing the Jacobian beginning to decrease as well as the supervised loss. After 5 epochs, the training will conclude and the output will show an evaluation on the test set before and after training 92 | ``` 93 | Test set results on MNIST with lambda_JR=0.100. 94 | 95 | Before training: 96 | accuracy: 827/10000=0.083 97 | supervised loss: 2.675 98 | Jacobian loss: 3.656 99 | total loss: 3.041 100 | 101 | After 5 epochs of training: 102 | accuracy: 9702/10000=0.970 103 | supervised loss: 0.027 104 | Jacobian loss: 0.977 105 | total loss: 0.125 106 | ``` 107 | showing that the model will learn to generalize and at the same time will regularize the Jacobian for greater robustness. 108 | 109 | Please look at the example file [`mnist.py`](./examples/mnist.py) for additional details. 110 | 111 | ## License 112 | jacobian_regularizer is licensed under the MIT license found in the LICENSE file. 113 | 114 | ## References 115 | [1] Judy Hoffman, Daniel A. Roberts, and Sho Yaida, "Robust Learning with Jacobian Regularization," 2019. [arxiv:1908.02729 [stat.ML]](https://arxiv.org/abs/1908.02729) 116 | 117 | --- 118 | 119 | If you found this useful, please consider citing 120 | ``` 121 | @article{hry2019jacobian, 122 | author = "Hoffman, Judy and Roberts, Daniel A. and Yaida, Sho", 123 | title = "Robust Learning with Jacobian Regularization", 124 | year = "2019", 125 | eprint = "1908.02729", 126 | archivePrefix = "arXiv", 127 | primaryClass = "stat.ML", 128 | } 129 | ``` 130 | -------------------------------------------------------------------------------- /examples/mnist.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | 7 | Example script training a simple MLP on MNIST 8 | demonstrating the PyTorch implementation of 9 | Jacobian regularization described in [1]. 10 | 11 | [1] Judy Hoffman, Daniel A. Roberts, and Sho Yaida, 12 | "Robust Learning with Jacobian Regularization," 2019. 13 | [arxiv:1908.02729](https://arxiv.org/abs/1908.02729) 14 | ''' 15 | from __future__ import division 16 | import time 17 | import sys 18 | 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | import torch.optim as optim 23 | from torchvision import datasets, transforms 24 | 25 | from jacobian import JacobianReg 26 | 27 | class MLP(nn.Module): 28 | ''' 29 | Simple MLP to demonstrate Jacobian regularization. 
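Flattens each input image (default 1 x 28 x 28) and applies three fully connected layers, 784 -> 200 -> 200 -> 10, with ReLU activations after the first two layers and Xavier-uniform weight initialization.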
30 | ''' 31 | def __init__(self, in_channel=1, im_size=28, num_classes=10, 32 | fc_channel1=200, fc_channel2=200): 33 | super(MLP, self).__init__() 34 | 35 | # Parameter setup 36 | compression=in_channel*im_size*im_size 37 | self.compression=compression 38 | 39 | # Structure 40 | self.fc1 = nn.Linear(compression, fc_channel1) 41 | self.fc2 = nn.Linear(fc_channel1, fc_channel2) 42 | self.fc3 = nn.Linear(fc_channel2, num_classes) 43 | 44 | # Initialization protocol 45 | nn.init.xavier_uniform_(self.fc1.weight) 46 | nn.init.xavier_uniform_(self.fc2.weight) 47 | nn.init.xavier_uniform_(self.fc3.weight) 48 | 49 | def forward(self, x): 50 | x = x.view(-1, self.compression) 51 | x = F.relu(self.fc1(x)) 52 | x = F.relu(self.fc2(x)) 53 | x = self.fc3(x) 54 | return x 55 | 56 | def eval(device, model, loader, criterion, lambda_JR): 57 | ''' 58 | Evaluate a model on a dataset with Jacobian regularization. 59 | 60 | Arguments: 61 | device (torch.device): specifies cpu or gpu training 62 | model (nn.Module): the neural network to evaluate 63 | loader (DataLoader): a loader for the dataset to eval 64 | criterion (nn.Module): the supervised loss function 65 | lambda_JR (float): the Jacobian regularization weight 66 | 67 | Returns: 68 | correct (int): the number correct 69 | total (int): the total number of examples 70 | loss_super (float): the average supervised loss 71 | loss_JR (float): the average Jacobian regularization loss 72 | loss (float): the average total combined loss 73 | ''' 74 | 75 | correct = 0 76 | total = 0 77 | loss_super_avg = 0 78 | loss_JR_avg = 0 79 | loss_avg = 0 80 | 81 | # for eval, let's compute the Jacobian exactly 82 | # so n, the number of projections, is set to -1. 83 | reg_full = JacobianReg(n=-1) 84 | for data, targets in loader: 85 | data = data.to(device) 86 | data.requires_grad = True # this is essential! 87 | targets = targets.to(device) 88 | output = model(data) 89 | _, predicted = torch.max(output, 1) 90 | correct += (predicted == targets).sum().item() 91 | total += targets.size(0) 92 | loss_super = criterion(output, targets) # supervised loss 93 | loss_JR = reg_full(data, output) # Jacobian regularization 94 | loss = loss_super + lambda_JR*loss_JR # full loss 95 | loss_super_avg += loss_super.item()*targets.size(0) 96 | loss_JR_avg += loss_JR.item()*targets.size(0) 97 | loss_avg += loss.item()*targets.size(0) 98 | loss_super_avg /= total 99 | loss_JR_avg /= total 100 | loss_avg /= total 101 | return correct, total, loss_super_avg, loss_JR_avg, loss_avg # report dataset-averaged losses 102 | 103 | def main(): 104 | ''' 105 | Train MNIST with Jacobian regularization.
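Trains the MLP for 5 epochs with SGD (lr=0.01, momentum=0.9, weight decay 5e-4) and lambda_JR=0.1, evaluating accuracy and losses on the MNIST test set both before and after training.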
106 | ''' 107 | seed = 1 108 | batch_size = 64 109 | epochs = 5 110 | 111 | lambda_JR = .1 112 | 113 | # number of projections, default is n_proj=1 114 | # should be greater than 0 and less than sqrt(# of classes) 115 | # can also set n_proj=-1 to compute the full jacobian 116 | # which is computationally inefficient 117 | n_proj = 1 118 | 119 | # setup devices 120 | torch.manual_seed(seed) 121 | if torch.cuda.is_available(): 122 | device = torch.device("cuda:0") 123 | torch.cuda.manual_seed(seed) 124 | else: 125 | device = torch.device("cpu") 126 | 127 | # load MNIST trainset and testset 128 | mnist_mean = (0.1307,) 129 | mnist_std = (0.3081,) 130 | transform = transforms.Compose( 131 | [transforms.ToTensor(), transforms.Normalize(mnist_mean, mnist_std)] 132 | ) 133 | trainset = datasets.MNIST(root='./data', train=True, 134 | download=True, transform=transform 135 | ) 136 | trainloader = torch.utils.data.DataLoader( 137 | trainset, batch_size=batch_size, shuffle=True 138 | ) 139 | testset = datasets.MNIST(root='./data', train=False, 140 | download=True, transform=transform 141 | ) 142 | testloader = torch.utils.data.DataLoader( 143 | testset, batch_size=batch_size, shuffle=True 144 | ) 145 | 146 | # initialize the model 147 | model = MLP() 148 | model.to(device) 149 | 150 | # initialize the loss and regularization 151 | criterion = nn.CrossEntropyLoss() 152 | reg = JacobianReg(n=n_proj) # if n_proj = 1, the argument is unnecessary 153 | 154 | # initialize the optimizer 155 | # including additional regularization, L^2 weight decay 156 | optimizer = optim.SGD(model.parameters(), 157 | lr=0.01, momentum=0.9, weight_decay=5e-4 158 | ) 159 | 160 | # eval on testset before any training 161 | correct_i, total, loss_super_i, loss_JR_i, loss_i = eval( 162 | device, model, testloader, criterion, lambda_JR 163 | ) 164 | 165 | # train 166 | for epoch in range(epochs): 167 | print('Training epoch %d.' % (epoch + 1) ) 168 | running_loss_super = 0.0 169 | running_loss_JR = 0.0 170 | for idx, (data, target) in enumerate(trainloader): 171 | 172 | data, target = data.to(device), target.to(device) 173 | data.requires_grad = True # this is essential! 
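# without input gradients enabled, autograd cannot differentiate the output
# with respect to the input, and the Jacobian regularization term cannot be computed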
174 | 175 | optimizer.zero_grad() 176 | 177 | output = model(data) # forward pass 178 | 179 | loss_super = criterion(output, target) # supervised loss 180 | loss_JR = reg(data, output) # Jacobian regularization 181 | loss = loss_super + lambda_JR*loss_JR # full loss 182 | 183 | loss.backward() # computes gradients 184 | 185 | optimizer.step() 186 | 187 | # print running statistics 188 | running_loss_super += loss_super.item() 189 | running_loss_JR += loss_JR.item() 190 | if idx % 100 == 99: # print every 100 mini-batches 191 | print('[%d, %5d] supervised loss: %.3f, Jacobian loss: %.3f' % 192 | ( 193 | epoch + 1, 194 | idx + 1, 195 | running_loss_super / 100, 196 | running_loss_JR / 100, 197 | ) 198 | ) 199 | running_loss_super = 0.0 200 | running_loss_JR = 0.0 201 | 202 | # eval on testset after training 203 | correct_f, total, loss_super_f, loss_JR_f, loss_f = eval( 204 | device, model, testloader, criterion, lambda_JR 205 | ) 206 | 207 | # print results 208 | print('\nTest set results on MNIST with lambda_JR=%.3f.\n' % lambda_JR) 209 | print('Before training:') 210 | print('\taccuracy: %d/%d=%.3f' % (correct_i, total, correct_i/total)) 211 | print('\tsupervised loss: %.3f' % loss_super_i) 212 | print('\tJacobian loss: %.3f' % loss_JR_i) 213 | print('\ttotal loss: %.3f' % loss_i) 214 | 215 | print('\nAfter %d epochs of training:' % epochs) 216 | print('\taccuracy: %d/%d=%.3f' % (correct_f, total, correct_f/total)) 217 | print('\tsupervised loss: %.3f' % loss_super_f) 218 | print('\tJacobian loss: %.3f' % loss_JR_f) 219 | print('\ttotal loss: %.3f' % loss_f) 220 | 221 | if __name__ == '__main__': 222 | main() 223 | -------------------------------------------------------------------------------- /figures/margins.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/jacobian_regularizer/32bb044c4c0163c908ef3c166d07d4ab2a248e07/figures/margins.png -------------------------------------------------------------------------------- /jacobian/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | ''' 7 | 8 | from .jacobian import JacobianReg 9 | name = "jacobian" 10 | -------------------------------------------------------------------------------- /jacobian/jacobian.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | 7 | PyTorch implementation of Jacobian regularization described in [1]. 8 | 9 | [1] Judy Hoffman, Daniel A. Roberts, and Sho Yaida, 10 | "Robust Learning with Jacobian Regularization," 2019. 11 | [arxiv:1908.02729](https://arxiv.org/abs/1908.02729) 12 | ''' 13 | from __future__ import division 14 | import torch 15 | import torch.nn as nn 16 | import torch.autograd as autograd 17 | 18 | 19 | class JacobianReg(nn.Module): 20 | ''' 21 | Loss criterion that computes the trace of the square of the Jacobian. 22 | 23 | Arguments: 24 | n (int, optional): determines the number of random projections. 25 | If n=-1, then it is set to the dimension of the output 26 | space and projection is non-random and orthonormal, yielding 27 | the exact result. 
For any reasonable batch size, the default 28 | (n=1) should be sufficient. 29 | ''' 30 | def __init__(self, n=1): 31 | assert n == -1 or n > 0 32 | super(JacobianReg, self).__init__() 33 | self.n = n 34 | 35 | def forward(self, x, y): 36 | ''' 37 | computes (1/2) tr[(dy/dx)(dy/dx)^T], i.e. half the squared Frobenius norm of the Jacobian dy/dx 38 | ''' 39 | B,C = y.shape 40 | if self.n == -1: 41 | num_proj = C 42 | else: 43 | num_proj = self.n 44 | J2 = 0 45 | for ii in range(num_proj): 46 | if self.n == -1: 47 | # orthonormal vector, sequentially spanned 48 | v = torch.zeros(B, C) 49 | v[:, ii] = 1 50 | else: 51 | # random properly-normalized vector for each sample 52 | v = self._random_vector(C=C, B=B) 53 | if x.is_cuda: 54 | v = v.cuda() 55 | Jv = self._jacobian_vector_product(y, x, v, create_graph=True) 56 | J2 += C*torch.norm(Jv)**2 / (num_proj*B) 57 | R = (1/2)*J2 58 | return R 59 | 60 | def _random_vector(self, C, B): 61 | ''' 62 | creates a batch of B random unit vectors of dimension C 63 | (the overall factor of C^(1/2) needed for the projection formula is supplied by the factor of C in forward) 64 | ''' 65 | if C == 1: 66 | return torch.ones(B) 67 | v = torch.randn(B, C) 68 | # normalize each row to unit norm 69 | vnorm = torch.norm(v, 2, 1, True) 70 | v = v / vnorm 71 | return v 72 | 73 | def _jacobian_vector_product(self, y, x, v, create_graph=False): 74 | ''' 75 | Produce jacobian-vector product dy/dx dot v. 76 | 77 | Note that if you want to differentiate it, 78 | you need to set create_graph=True 79 | ''' 80 | flat_y = y.reshape(-1) 81 | flat_v = v.reshape(-1) 82 | grad_x, = torch.autograd.grad(flat_y, x, flat_v, 83 | retain_graph=True, 84 | create_graph=create_graph) 85 | return grad_x 86 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | ''' 7 | 8 | import setuptools 9 | 10 | with open("README.md", "r") as fh: 11 | long_description = fh.read() 12 | 13 | setuptools.setup( 14 | name="jacobian", 15 | version="1.0.0", 16 | author="Judy Hoffman, Daniel A. Roberts, and Sho Yaida", 17 | author_email="judy@gatech.edu, daniel.adam.roberts@gmail.com, shoyaida@fb.com", 18 | description="Jacobian regularization in PyTorch.", 19 | long_description=long_description, 20 | long_description_content_type="text/markdown", 21 | url="https://github.com/facebookresearch/jacobian_regularizer", 22 | packages=['jacobian'], 23 | classifiers=[ 24 | "Programming Language :: Python :: 3", 25 | "License :: OSI Approved :: MIT License", 26 | "Operating System :: OS Independent", 27 | ], 28 | ) 29 | --------------------------------------------------------------------------------