├── .gitignore ├── Makefile ├── READMATH.md ├── README.md ├── ckpts └── .gitkeep ├── data └── .gitkeep ├── scripts ├── __init__.py ├── analyze.py ├── analyze_verify.py ├── check_noise_propagation.py ├── plot_losses.py └── plot_preds_distn.py ├── src ├── __init__.py ├── attacks.py ├── datasets.py ├── lib │ ├── __init__.py │ ├── alexnet.py │ ├── cifar10_selftrained.py │ ├── classic_resnet.py │ ├── ds_imagenet.py │ ├── lenet.py │ ├── wide_resnet.py │ └── zipdata.py ├── models.py ├── noises │ ├── __init__.py │ ├── noises.py │ ├── test_noises.py │ └── utils.py ├── smooth.py ├── test.py ├── train.py └── verify.py ├── svgs ├── 1418cce7d60743be1c545cd950367159.svg ├── 1fa8048512f84790ef174f591d0cb851.svg ├── 336fefe2418749fabf50594e52f7b776.svg ├── 44c65658d6cd134b1599c29b31949f77.svg ├── 5dc1880e644c7b3a0e9fa954759762ea.svg ├── 8244067f9118b85361c6645cc9f1c526.svg ├── 839a0dc412c4f8670dd1064e0d6d412f.svg ├── 8d2d1eabb21bb41807292151fe468472.svg ├── DistributionVenn.png ├── b52b48d8661f69776e1b6650998d5067.svg ├── bb9e6385ceb6d4a2d83a5b51a3c870c9.svg ├── bd5b313d1d74ae2fc57ddb870603d84b.svg ├── e1085464f81e12de4a74d54d14eb5dc5.svg ├── e703845884313f30712bfc7262a5e65b.svg ├── ec90b4fe342a37de851db6db2b08d4f4.svg ├── envelopes.png ├── robust-radii.png └── table.png └── tutorial.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | *.pkl 2 | *.pyc 3 | *.torch 4 | *.npy 5 | *.csv 6 | *.eps 7 | *.swp 8 | *.pdf 9 | stdout.txt 10 | misc/ 11 | jobs/ 12 | data/*/* 13 | .vscode/ 14 | .ipynb_checkpoints/ 15 | pretrain/* 16 | Untitled*.ipynb 17 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | all: 2 | python3 -m readme2tex --branch master --nocdn --readme READMATH.md --output README.md 3 | 4 | -------------------------------------------------------------------------------- /READMATH.md: 
-------------------------------------------------------------------------------- 1 | # Randomized Smoothing of All Shapes and Sizes 2 | 3 | Last update: July 2020. 4 | 5 | --- 6 | 7 | Code to accompany our paper: 8 | 9 | **Randomized Smoothing of All Shapes and Sizes** 10 | *Greg Yang\*, Tony Duan\*, J. Edward Hu, Hadi Salman, Ilya Razenshteyn, Jerry Li.* 11 | International Conference on Machine Learning (ICML), 2020 [[Paper]](https://arxiv.org/abs/2002.08118) [[Blog Post]](http://decentdescent.org/rs4a1.html) 12 | 13 | Notably, we outperform existing provably $\ell_1$-robust classifiers on ImageNet and CIFAR-10. 14 | 15 | ![Table of SOTA results.](svgs/table.png) 16 | 17 | ![Figure of SOTA results.](svgs/envelopes.png) 18 | 19 | This library implements the algorithms in our paper for computing robust radii for different smoothing distributions against different adversaries; for example, distributions of the form $e^{-\|x\|_\infty^k}$ against the $\ell_1$ adversary. 20 | 21 | The following summarizes the (distribution, adversary) pairs covered here. 22 | 23 | ![Venn Diagram of Distributions and Adversaries.](svgs/DistributionVenn.png) 24 | 25 | We can compare the certified robust radius each of these distributions implies at a fixed level of $\hat\rho_\mathrm{lower}$, the lower bound on the probability that the classifier returns the top class under noise. Here all noises are instantiated for CIFAR-10 dimensionality ($d=3072$) and normalized to variance $\sigma^2 \triangleq \mathbb{E}[\|x\|_2^2]=1$. Note that the first two rows below certify for the $\ell_1$ adversary, while the last row certifies for the $\ell_2$ and $\ell_\infty$ adversaries. For more details, see our `tutorial.ipynb` notebook.
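To make the dependence on $\hat\rho_\mathrm{lower}$ concrete, here is a minimal, standalone sketch of three closed-form certificates: $\sigma\Phi^{-1}(\hat\rho_\mathrm{lower})$ for Gaussian noise against the $\ell_2$ adversary (Cohen et al. 2019), $-\lambda\log(2(1-\hat\rho_\mathrm{lower}))$ for i.i.d. Laplace noise of scale $\lambda$ against the $\ell_1$ adversary (Teng et al. 2020), and $2\lambda(\hat\rho_\mathrm{lower}-1/2)$ for uniform noise on $[-\lambda,\lambda]^d$ against the $\ell_1$ adversary. It is not the repository's `noise.certify`; the function names are hypothetical, and the conversions from a per-coordinate standard deviation $\sigma$ ($\lambda=\sqrt{3}\,\sigma$ for uniform, $\lambda=\sigma/\sqrt{2}$ for Laplace) are our assumption about the `sigma` parameterization, so check them against `src/noises/noises.py` before relying on them.

```python
import math
from statistics import NormalDist


def gaussian_l2_radius(rho_lower: float, sigma: float) -> float:
    """l2 radius certified by N(0, sigma^2 I) noise: sigma * Phi^{-1}(rho_lower)."""
    if rho_lower <= 0.5:
        return 0.0  # abstain: no certificate unless the top class has majority probability
    return sigma * NormalDist().inv_cdf(rho_lower)  # requires rho_lower < 1


def laplace_l1_radius(rho_lower: float, lambd: float) -> float:
    """l1 radius certified by i.i.d. Laplace(0, lambd) noise: -lambd * log(2 * (1 - rho_lower))."""
    if rho_lower <= 0.5:
        return 0.0
    return -lambd * math.log(2.0 * (1.0 - rho_lower))  # requires rho_lower < 1


def uniform_l1_radius(rho_lower: float, lambd: float) -> float:
    """l1 radius certified by Uniform([-lambd, lambd]^d) noise: 2 * lambd * (rho_lower - 1/2)."""
    if rho_lower <= 0.5:
        return 0.0
    return 2.0 * lambd * (rho_lower - 0.5)


if __name__ == "__main__":
    # ASSUMPTION: sigma is the per-coordinate standard deviation of the noise.
    sigma = 1.0
    for rho in (0.6, 0.9, 0.99):
        print(
            f"rho={rho}: "
            f"Gaussian l2={gaussian_l2_radius(rho, sigma):.3f}, "
            f"Laplace l1={laplace_l1_radius(rho, sigma / math.sqrt(2)):.3f}, "
            f"Uniform l1={uniform_l1_radius(rho, math.sqrt(3) * sigma):.3f}"
        )
```

As $\hat\rho_\mathrm{lower} \to 1$, the uniform radius saturates at $\lambda$, while the Gaussian and Laplace radii grow without bound; those two formulas also require $\hat\rho_\mathrm{lower} < 1$, and the sketch raises an error at exactly 1.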
26 | 27 | ![Certified Robust Radii of Distributions](svgs/robust-radii.png) 28 | 29 | ## Getting Started 30 | 31 | Clone our repository and install dependencies: 32 | 33 | ```shell 34 | git clone https://github.com/tonyduan/rs4a.git 35 | conda create --name rs4a python=3.6 36 | conda activate rs4a 37 | conda install numpy matplotlib pandas seaborn 38 | conda install pytorch torchvision cudatoolkit=10.0 -c pytorch 39 | pip install torchnet tqdm statsmodels dfply 40 | ``` 41 | 42 | ## Experiments 43 | 44 | To reproduce our SOTA $\ell_1$ results on CIFAR-10, we need to train models over 45 | $$ 46 | \sigma \in \{0.15, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 2.75, 3.0, 3.25, 3.5\}. 47 | $$ 48 | For each value, run the following: 49 | 50 | ```shell 51 | python3 -m src.train \ 52 | --model=WideResNet \ 53 | --noise=Uniform \ 54 | --sigma={sigma} \ 55 | --experiment-name=cifar_uniform_{sigma} 56 | 57 | python3 -m src.test \ 58 | --model=WideResNet \ 59 | --noise=Uniform \ 60 | --sigma={sigma} \ 61 | --experiment-name=cifar_uniform_{sigma} \ 62 | --sample-size-cert=100000 \ 63 | --sample-size-pred=64 \ 64 | --noise-batch-size=512 65 | ``` 66 | 67 | The training script will train the model via data augmentation with the specified noise and level of sigma, and save the model checkpoint to the directory `ckpts/experiment_name`. 68 | 69 | The testing script will load the model checkpoint from the `ckpts/experiment_name` directory, make predictions over the entire test set using the smoothed classifier, and certify the $\ell_1, \ell_2,$ and $\ell_\infty$ robust radii of these predictions. By default, we make predictions with $64$ samples and certify with $100,000$ samples at a failure probability of $\alpha=0.001$. 70 | 71 | To draw a comparison to the benchmark noises, re-run the above replacing `Uniform` with `Gaussian` and `Laplace`.
Then, to plot the figures and print the table of results (for the $\ell_1$ adversary), run our analysis script: 72 | 73 | ```shell 74 | python3 -m scripts.analyze --dir=ckpts --show --adv=1 75 | ``` 76 | 77 | Other noises must be instantiated with their own arguments when invoking the training and testing scripts. For example, to sample noise with density $\propto \|x\|_\infty^{-100}e^{-\|x\|_\infty^{10}}$, we would run: 78 | 79 | ```shell 80 | python3 -m src.train \ 81 | --noise=ExpInf \ 82 | --k=10 \ 83 | --j=100 \ 84 | --sigma=0.5 \ 85 | --experiment-name=cifar_expinf_0.5 86 | ``` 87 | 88 | ## Trained Models 89 | 90 | Our pre-trained models are available below. 91 | 92 | The following commands download all models into the `pretrain/` directory. 93 | 94 | ```shell 95 | mkdir -p pretrain 96 | wget --directory-prefix=pretrain http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_all.zip 97 | unzip -d pretrain pretrain/cifar_all.zip 98 | wget --directory-prefix=pretrain http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_all.zip 99 | unzip -d pretrain pretrain/imagenet_all.zip 100 | ``` 101 | 102 | ImageNet (ResNet-50): [[All Models, 2.3 GB]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_all.zip) 103 | 104 | - Sigma=0.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_025.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_gaussian_025.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_laplace_025.pt) 105 | - Sigma=0.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_050.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_gaussian_050.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_laplace_050.pt) 106 | - Sigma=0.75: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_075.pt)
[[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_gaussian_075.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_laplace_075.pt) 107 | - Sigma=1.0: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_100.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_gaussian_100.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_laplace_100.pt) 108 | - Sigma=1.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_125.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_gaussian_125.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_laplace_125.pt) 109 | - Sigma=1.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_150.pt) 110 | - Sigma=1.75: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_175.pt) 111 | - Sigma=2.0: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_200.pt) 112 | - Sigma=2.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_225.pt) 113 | - Sigma=2.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_250.pt) 114 | - Sigma=2.75: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_275.pt) 115 | - Sigma=3.0: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_300.pt) 116 | - Sigma=3.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_325.pt) 117 | - Sigma=3.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_350.pt) 118 | 119 | CIFAR-10 (Wide ResNet 40-2): [[All Models, 226 MB]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_all.zip) 120 | 121 | - Sigma=0.15: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_015.pt) 
[[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_gaussian_015.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_laplace_015.pt) 122 | - Sigma=0.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_025.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_gaussian_025.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_laplace_025.pt) 123 | - Sigma=0.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_050.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_gaussian_050.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_laplace_050.pt) 124 | - Sigma=0.75: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_075.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_gaussian_075.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_laplace_075.pt) 125 | - Sigma=1.0: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_100.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_gaussian_100.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_laplace_100.pt) 126 | - Sigma=1.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_125.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_gaussian_125.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_laplace_125.pt) 127 | - Sigma=1.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_150.pt) 128 | - Sigma=1.75: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_175.pt) 129 | - Sigma=2.0: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_200.pt) 130 | - Sigma=2.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_225.pt) 131 | - Sigma=2.5: 
[[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_250.pt) 132 | - Sigma=2.75: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_275.pt) 133 | - Sigma=3.0: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_300.pt) 134 | - Sigma=3.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_325.pt) 135 | - Sigma=3.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_350.pt) 136 | 137 | By default the models above were trained with noise augmentation. We further improve upon our state-of-the-art certified accuracies using recent advances in training smoothed classifiers: (1) by using stability training (Li et al. NeurIPS 2019), and (2) by leveraging additional data via (a) pre-training on downsampled ImageNet (Hendrycks et al. NeurIPS 2019) and (b) semi-supervised self-training with data from 80 Million Tiny Images (Carmon et al. 2019). Our improved models trained with these methods are released below. 138 | 139 | ImageNet (ResNet-50): 140 | 141 | - Stability training: [[All Models, 2.3 GB]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_stability.zip) 142 | 143 | CIFAR-10 (Wide ResNet 40-2): 144 | 145 | - Stability training: [[All Models, 234 MB]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_stability.zip) 146 | - Stability training + pre-training: [[All Models, 236 MB]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_pretrained.zip) 147 | - Stability training + semi-supervised learning: [[All Models, 235 MB]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_semisup.zip) 148 | 149 | An example of pre-trained model usage is below. For a more in-depth example, see our `tutorial.ipynb` notebook.
150 | 151 | ```python 152 | import torch 153 | from src.models import WideResNet 154 | from src.noises import Uniform 155 | from src.smooth import * 156 | 157 | # load the model 158 | model = WideResNet(dataset="cifar", device="cuda") 159 | saved_dict = torch.load("pretrain/cifar_uniform_050.pt") 160 | model.load_state_dict(saved_dict) 161 | model.eval() 162 | 163 | # instantiation of noise 164 | noise = Uniform(device="cpu", dim=3072, sigma=0.5) 165 | 166 | # training code, to generate noisy samples of a batch of images x (x is not defined in this snippet) 167 | noisy_x = noise.sample(x) 168 | 169 | # testing code, certify for L1 adversary 170 | preds = smooth_predict_hard(model, x, noise, 64) 171 | top_cats = preds.probs.argmax(dim=1) 172 | prob_lb = certify_prob_lb(model, x, top_cats, 0.001, noise, 100000) 173 | radius = noise.certify(prob_lb, adv=1) 174 | ``` 175 | 176 | ## Repository 177 | 178 | 1. `ckpts/` is used to store experiment checkpoints and results. 179 | 2. `data/` is used to store image datasets. 180 | 3. `tables/` contains caches of pre-calculated tables of certified radii. 181 | 4. `src/` contains the main source code. 182 | 5. `scripts/` contains the analysis and plotting code. 183 | 184 | Within the `src/` directory, the most salient files are: 185 | 186 | 1. `train.py` is used to train models and save them to `ckpts/`. 187 | 2. `test.py` is used to test and compute robust certificates for $\ell_1,\ell_2,\ell_\infty$ adversaries. 188 | 3. `noises/test_noises.py` is a unit test for the noises we include. Run the test with 189 | 190 | ```python -m unittest src/noises/test_noises.py``` 191 | 192 | Note that some tests are probabilistic and can fail occasionally. 193 | If so, rerun them a few times to make sure the failure is not persistent. 194 | 195 | 4. `noises/noises.py` is a library of noises derived for randomized smoothing.
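The usage example passes $\alpha=0.001$ and $100{,}000$ samples to `certify_prob_lb`. We have not checked how `src/smooth.py` computes this bound; assuming it follows Cohen et al. (2019), it is a one-sided Clopper-Pearson lower confidence bound on the top-class probability given $k$ top-class hits among $n$ noisy samples, i.e. the $p$ solving $P(\mathrm{Bin}(n, p) \ge k) = \alpha$. A dependency-free sketch, with hypothetical helper names `binom_sf` and `clopper_pearson_lower`, finds it by bisection:

```python
import math


def binom_sf(k: int, n: int, p: float) -> float:
    """P(X >= k) for X ~ Binomial(n, p), with 0 <= k <= n, summed in log space for stability."""
    if k <= 0:
        return 1.0
    if p <= 0.0:
        return 0.0
    if p >= 1.0:
        return 1.0
    log_p, log_q = math.log(p), math.log1p(-p)
    log_n_fact = math.lgamma(n + 1)
    log_terms = [
        log_n_fact - math.lgamma(i + 1) - math.lgamma(n - i + 1) + i * log_p + (n - i) * log_q
        for i in range(k, n + 1)
    ]
    m = max(log_terms)
    return min(1.0, math.exp(m) * math.fsum(math.exp(t - m) for t in log_terms))


def clopper_pearson_lower(k: int, n: int, alpha: float) -> float:
    """One-sided (1 - alpha) Clopper-Pearson lower bound on p from k successes in n trials.

    Solves P(X >= k | p) = alpha by bisection; the left-hand side is increasing in p.
    """
    if k == 0:
        return 0.0
    lo, hi = 0.0, 1.0
    for _ in range(100):
        mid = 0.5 * (lo + hi)
        if binom_sf(k, n, mid) < alpha:
            lo = mid  # tail probability still below alpha: the bound lies above mid
        else:
            hi = mid
    return lo  # lower end of the bracket, so the bound errs on the conservative side
```

For instance, `clopper_pearson_lower(100, 100, 0.001)` is $0.001^{1/100} \approx 0.9333$, because $P(X \ge n) = p^n$. The pure-Python sum gets slow when $n$ is large and $k$ is far from $n$; `statsmodels` (already a dependency) provides `proportion_confint(k, n, alpha=2 * alpha, method="beta")[0]` in `statsmodels.stats.proportion`, which should give the same one-sided bound.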
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Randomized Smoothing of All Shapes and Sizes 2 | 3 | Last update: July 2020. 4 | 5 | --- 6 | 7 | Code to accompany our paper: 8 | 9 | **Randomized Smoothing of All Shapes and Sizes** 10 | *Greg Yang\*, Tony Duan\*, J. Edward Hu, Hadi Salman, Ilya Razenshteyn, Jerry Li.* 11 | International Conference on Machine Learning (ICML), 2020 [[Paper]](https://arxiv.org/abs/2002.08118) [[Blog Post]](http://decentdescent.org/rs4a1.html) 12 | 13 | Notably, we outperform existing provably $\ell_1$-robust classifiers on ImageNet and CIFAR-10. 14 | 15 | ![Table of SOTA results.](svgs/table.png) 16 | 17 | ![Figure of SOTA results.](svgs/envelopes.png) 18 | 19 | This library implements the algorithms in our paper for computing robust radii for different smoothing distributions against different adversaries; for example, distributions of the form $e^{-\|x\|_\infty^k}$ against the $\ell_1$ adversary. 20 | 21 | The following summarizes the (distribution, adversary) pairs covered here. 22 | 23 | ![Venn Diagram of Distributions and Adversaries.](svgs/DistributionVenn.png) 24 | 25 | We can compare the certified robust radius each of these distributions implies at a fixed level of $\hat\rho_\mathrm{lower}$, the lower bound on the probability that the classifier returns the top class under noise. Here all noises are instantiated for CIFAR-10 dimensionality ($d=3072$) and normalized to variance $\sigma^2 \triangleq \mathbb{E}[\|x\|_2^2]=1$. Note that the first two rows below certify for the $\ell_1$ adversary, while the last row certifies for the $\ell_2$ and $\ell_\infty$ adversaries. For more details, see our `tutorial.ipynb` notebook.
26 | 27 | ![Certified Robust Radii of Distributions](svgs/robust-radii.png) 28 | 29 | ## Getting Started 30 | 31 | Clone our repository and install dependencies: 32 | 33 | ```shell 34 | git clone https://github.com/tonyduan/rs4a.git 35 | conda create --name rs4a python=3.6 36 | conda activate rs4a 37 | conda install numpy matplotlib pandas seaborn 38 | conda install pytorch torchvision cudatoolkit=10.0 -c pytorch 39 | pip install torchnet tqdm statsmodels dfply 40 | ``` 41 | 42 | ## Experiments 43 | 44 | To reproduce our SOTA $\ell_1$ results on CIFAR-10, we need to train models over 45 |

$$
\sigma \in \{0.15, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 2.75, 3.0, 3.25, 3.5\}.
$$

46 | For each value, run the following: 47 | 48 | ```shell 49 | python3 -m src.train \ 50 | --model=WideResNet \ 51 | --noise=Uniform \ 52 | --sigma={sigma} \ 53 | --experiment-name=cifar_uniform_{sigma} 54 | 55 | python3 -m src.test \ 56 | --model=WideResNet \ 57 | --noise=Uniform \ 58 | --sigma={sigma} \ 59 | --experiment-name=cifar_uniform_{sigma} \ 60 | --sample-size-cert=100000 \ 61 | --sample-size-pred=64 \ 62 | --noise-batch-size=512 63 | ``` 64 | 65 | The training script will train the model via data augmentation with the specified noise and level of sigma, and save the model checkpoint to the directory `ckpts/experiment_name`. 66 | 67 | The testing script will load the model checkpoint from the `ckpts/experiment_name` directory, make predictions over the entire test set using the smoothed classifier, and certify the $\ell_1, \ell_2,$ and $\ell_\infty$ robust radii of these predictions. By default, we make predictions with $64$ samples and certify with $100,000$ samples at a failure probability of $\alpha=0.001$. 68 | 69 | To draw a comparison to the benchmark noises, re-run the above replacing `Uniform` with `Gaussian` and `Laplace`. Then, to plot the figures and print the table of results (for the $\ell_1$ adversary), run our analysis script: 70 | 71 | ```shell 72 | python3 -m scripts.analyze --dir=ckpts --show --adv=1 73 | ``` 74 | 75 | Other noises must be instantiated with their own arguments when invoking the training and testing scripts. For example, to sample noise with density $\propto \|x\|_\infty^{-100}e^{-\|x\|_\infty^{10}}$, we would run: 76 | 77 | ```shell 78 | python3 -m src.train \ 79 | --noise=ExpInf \ 80 | --k=10 \ 81 | --j=100 \ 82 | --sigma=0.5 \ 83 | --experiment-name=cifar_expinf_0.5 84 | ``` 85 | 86 | ## Trained Models 87 | 88 | Our pre-trained models are available below. 89 | 90 | The following commands download all models into the `pretrain/` directory.
91 | 92 | ```shell 93 | mkdir -p pretrain 94 | wget --directory-prefix=pretrain http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_all.zip 95 | unzip -d pretrain pretrain/cifar_all.zip 96 | wget --directory-prefix=pretrain http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_all.zip 97 | unzip -d pretrain pretrain/imagenet_all.zip 98 | ``` 99 | 100 | ImageNet (ResNet-50): [[All Models, 2.3 GB]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_all.zip) 101 | 102 | - Sigma=0.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_025.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_gaussian_025.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_laplace_025.pt) 103 | - Sigma=0.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_050.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_gaussian_050.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_laplace_050.pt) 104 | - Sigma=0.75: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_075.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_gaussian_075.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_laplace_075.pt) 105 | - Sigma=1.0: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_100.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_gaussian_100.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_laplace_100.pt) 106 | - Sigma=1.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_125.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_gaussian_125.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_laplace_125.pt) 107 | - Sigma=1.5: 
[[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_150.pt) 108 | - Sigma=1.75: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_175.pt) 109 | - Sigma=2.0: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_200.pt) 110 | - Sigma=2.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_225.pt) 111 | - Sigma=2.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_250.pt) 112 | - Sigma=2.75: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_275.pt) 113 | - Sigma=3.0: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_300.pt) 114 | - Sigma=3.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_325.pt) 115 | - Sigma=3.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_350.pt) 116 | 117 | CIFAR-10 (Wide ResNet 40-2): [[All Models, 226 MB]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_all.zip) 118 | 119 | - Sigma=0.15: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_015.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_gaussian_015.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_laplace_015.pt) 120 | - Sigma=0.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_025.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_gaussian_025.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_laplace_025.pt) 121 | - Sigma=0.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_050.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_gaussian_050.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_laplace_050.pt) 122 | - Sigma=0.75: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_075.pt) 
[[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_gaussian_075.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_laplace_075.pt) 123 | - Sigma=1.0: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_100.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_gaussian_100.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_laplace_100.pt) 124 | - Sigma=1.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_125.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_gaussian_125.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_laplace_125.pt) 125 | - Sigma=1.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_150.pt) 126 | - Sigma=1.75: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_175.pt) 127 | - Sigma=2.0: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_200.pt) 128 | - Sigma=2.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_225.pt) 129 | - Sigma=2.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_250.pt) 130 | - Sigma=2.75: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_275.pt) 131 | - Sigma=3.0: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_300.pt) 132 | - Sigma=3.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_325.pt) 133 | - Sigma=3.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_350.pt) 134 | 135 | By default the models above were trained with noise augmentation. We further improve upon our state-of-the-art certified accuracies using recent advances in training smoothed classifiers: (1) by using stability training (Li et al. 
NeurIPS 2019), and (2) by leveraging additional data via (a) pre-training on downsampled ImageNet (Hendrycks et al. NeurIPS 2019) and (b) semi-supervised self-training with data from 80 Million Tiny Images (Carmon et al. 2019). Our improved models trained with these methods are released below. 136 | 137 | ImageNet (ResNet-50): 138 | 139 | - Stability training: [[All Models, 2.3 GB]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_stability.zip) 140 | 141 | CIFAR-10 (Wide ResNet 40-2): 142 | 143 | - Stability training: [[All Models, 234 MB]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_stability.zip) 144 | - Stability training + pre-training: [[All Models, 236 MB]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_pretrained.zip) 145 | - Stability training + semi-supervised learning: [[All Models, 235 MB]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_semisup.zip) 146 | 147 | An example of pre-trained model usage is below. For a more in-depth example, see our `tutorial.ipynb` notebook. 148 | 149 | ```python 150 | import torch 151 | from src.models import WideResNet 152 | from src.noises import Uniform 153 | from src.smooth import * 154 | 155 | # load the model 156 | model = WideResNet(dataset="cifar", device="cuda") 157 | saved_dict = torch.load("pretrain/cifar_uniform_050.pt") 158 | model.load_state_dict(saved_dict) 159 | model.eval() 160 | 161 | # instantiation of noise 162 | noise = Uniform(device="cpu", dim=3072, sigma=0.5) 163 | 164 | # training code, to generate noisy samples of a batch of images x (x is not defined in this snippet) 165 | noisy_x = noise.sample(x) 166 | 167 | # testing code, certify for L1 adversary 168 | preds = smooth_predict_hard(model, x, noise, 64) 169 | top_cats = preds.probs.argmax(dim=1) 170 | prob_lb = certify_prob_lb(model, x, top_cats, 0.001, noise, 100000) 171 | radius = noise.certify(prob_lb, adv=1) 172 | ``` 173 | 174 | ## Repository 175 | 176 | 1. `ckpts/` is used to store experiment checkpoints and results. 177 | 2. `data/` is used to store image datasets.
178 | 3. `tables/` contains caches of pre-calculated tables of certified radii. 179 | 4. `src/` contains the main source code. 180 | 5. `scripts/` contains the analysis and plotting code. 181 | 182 | Within the `src/` directory, the most salient files are: 183 | 184 | 1. `train.py` is used to train models and save them to `ckpts/`. 185 | 2. `test.py` is used to test and compute robust certificates for $\ell_1,\ell_2,\ell_\infty$ adversaries. 186 | 3. `noises/test_noises.py` is a unit test for the noises we include. Run the test with 187 | 188 | ```python -m unittest src/noises/test_noises.py``` 189 | 190 | Note that some tests are probabilistic and can fail occasionally. 191 | If so, rerun them a few times to make sure the failure is not persistent. 192 | 193 | 4. `noises/noises.py` is a library of noises derived for randomized smoothing. -------------------------------------------------------------------------------- /ckpts/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonyduan/rs4a/a02831806ee4a117ef82ce5738fe8c5e8f180b06/ckpts/.gitkeep -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonyduan/rs4a/a02831806ee4a117ef82ce5738fe8c5e8f180b06/data/.gitkeep -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonyduan/rs4a/a02831806ee4a117ef82ce5738fe8c5e8f180b06/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/analyze.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import os 4 | import pickle 5 | import matplotlib as mpl 6 | import
seaborn as sns 7 | from argparse import ArgumentParser 8 | from collections import defaultdict 9 | from dfply import * 10 | from matplotlib import pyplot as plt 11 | from src.noises import * 12 | from src.datasets import get_dim 13 | 14 | 15 | if __name__ == "__main__": 16 | 17 | argparser = ArgumentParser() 18 | argparser.add_argument("--dir", default="./ckpts", type=str) 19 | argparser.add_argument("--debug", action="store_true") 20 | argparser.add_argument("--show", action="store_true") 21 | argparser.add_argument("--adv", default=1, type=float) 22 | argparser.add_argument("--eps-max", default=5.0, type=float) 23 | argparser.add_argument("--fancy-markers", action="store_true") 24 | args = argparser.parse_args() 25 | args.adv = round(args.adv) 26 | 27 | markers = ["o", "D", "s"] if args.fancy_markers else True 28 | 29 | sns.set_context("notebook", rc={"lines.linewidth": 2}) 30 | sns.set_style("whitegrid") 31 | sns.set_palette("husl") 32 | 33 | dataset = os.path.basename(os.path.normpath(args.dir)).split("_")[0] 34 | experiment_names = list(filter(lambda s: os.path.isdir(args.dir + "/" + s), 35 | os.listdir(args.dir))) 36 | 37 | df = defaultdict(list) 38 | eps_range = np.linspace(0, args.eps_max, 81) 39 | 40 | for experiment_name in experiment_names: 41 | 42 | save_path = f"{args.dir}/{experiment_name}" 43 | results = {} 44 | experiment_args = pickle.load(open(f"{args.dir}/{experiment_name}/args.pkl", "rb")) 45 | 46 | for k in ("preds", "labels", f"radius_l{str(args.adv)}", "acc_train"): 47 | results[k] = np.load(f"{save_path}/{k}.npy") 48 | 49 | noise = parse_noise_from_args(experiment_args, device="cpu", 50 | dim=get_dim(experiment_args.dataset)) 51 | 52 | top_1_preds = np.argmax(results["preds"], axis=1) 53 | top_1_acc_pred = (top_1_preds == results["labels"]).mean() 54 | 55 | if experiment_args.adversarial: 56 | noise_str = noise.plotstr() + f",$\\epsilon={experiment_args.eps}$" 57 | else: 58 | noise_str = noise.plotstr() 59 | 60 | for eps in eps_range: 61 | 62 | top_1_acc_cert = 
((results[f"radius_l{str(args.adv)}"] >= eps) & \ 63 | (top_1_preds == results["labels"])).mean() 64 | df["experiment_name"].append(experiment_name) 65 | df["sigma"].append(noise.sigma) 66 | df["noise"].append(noise_str) 67 | df["eps"].append(eps) 68 | df["top_1_acc_train"].append(results["acc_train"][0]) 69 | df["top_1_acc_cert"].append(top_1_acc_cert) 70 | df["top_1_acc_pred"].append(top_1_acc_pred) 71 | 72 | # save the experiment results 73 | df = pd.DataFrame(df) >> arrange(X.noise) 74 | df.to_csv(f"{args.dir}/results_{dataset}_l{str(args.adv)}.csv", index=False) 75 | 76 | if args.debug: 77 | breakpoint() 78 | 79 | # print top-1 certified accuracies for table in paper 80 | print(df >> mask(X.eps.isin((0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0))) \ 81 | >> group_by(X.eps, X.noise) >> arrange(X.top_1_acc_cert, ascending=False) >> head(1)) 82 | 83 | # plot clean training accuracy against certified accuracy at eps 84 | # tmp = df >> mask(X.eps == 0.25) >> arrange(X.noise) 85 | # plt.figure(figsize=(3, 2.8)) 86 | # ax = sns.scatterplot(x="top_1_acc_train", y="top_1_acc_cert", hue="noise", style="noise", 87 | # markers=markers, size="sigma", data=tmp, legend="full") 88 | # handles, labels = ax.get_legend_handles_labels() 89 | # i = [i for i, t in enumerate(ax.legend_.texts) if t.get_text() == "sigma"][0] 90 | # ax.legend(handles[:i], labels[:i]) 91 | # plt.plot(np.linspace(0.0, 1.0), np.linspace(0.0, 1.0), "--", color="gray") 92 | # plt.ylim((0.2, 1.0)) 93 | # plt.xlim((0.2, 1.0)) 94 | # plt.xlabel("Top-1 training accuracy") 95 | # plt.ylabel("Top-1 certified accuracy, $\epsilon$ = 0.25") 96 | # plt.tight_layout() 97 | # plt.savefig(f"{args.dir}/train_vs_certified.pdf") 98 | # 99 | # tmp = df >> mask(X.eps.isin((0.25, 0.5, 0.75, 1.0))) >> \ 100 | # mutate(tr=X.top_1_acc_train, cert=X.top_1_acc_cert) 101 | # fig = sns.relplot(data=tmp, kind="scatter", x="tr", y="cert", 102 | # hue="noise", col="eps", col_wrap=2, aspect=1, height=3, size="sigma") 103 | # 
fig.map_dataframe(plt.plot, (plt.xlim()[0], plt.xlim()[1]), (plt.xlim()[0], plt.xlim()[1]), 'k--').set_axis_labels("tr", "cert").add_legend() 104 | # 105 | # plot clean training and testing accuracy 106 | grouped = df >> group_by(X.experiment_name) \ 107 | >> mask(X.sigma <= 1.25) \ 108 | >> summarize(experiment_name=first(X.experiment_name), 109 | noise=first(X.noise), 110 | sigma=first(X.sigma), 111 | top_1_acc_train=first(X.top_1_acc_train), 112 | top_1_acc_pred=first(X.top_1_acc_pred)) 113 | 114 | plt.figure(figsize=(6.5, 2.5)) 115 | plt.subplot(1, 2, 1) 116 | sns.lineplot(x="sigma", y="top_1_acc_train", hue="noise", markers=markers, 117 | style="noise", data=grouped, alpha=1) 118 | plt.xlabel(r"$\sigma$") 119 | plt.ylabel("Top-1 training accuracy") 120 | plt.ylim((0, 1)) 121 | plt.subplot(1, 2, 2) 122 | sns.lineplot(x="sigma", y="top_1_acc_pred", hue="noise", markers=markers, 123 | style="noise", data=grouped, alpha=1, legend=False) 124 | plt.xlabel(r"$\sigma$") 125 | plt.ylabel("Top-1 testing accuracy") 126 | plt.ylim((0, 1)) 127 | plt.tight_layout() 128 | plt.savefig(f"{args.dir}/train_test_accuracies.pdf") 129 | 130 | # plot certified accuracies 131 | selected = df >> mutate(certacc=X.top_1_acc_cert) 132 | sns.relplot(x="eps", y="certacc", hue="noise", kind="line", col="sigma", 133 | data=selected, height=2, aspect=1.5, col_wrap=2) 134 | plt.ylim((0, 1)) 135 | plt.tight_layout() 136 | plt.savefig(f"{args.dir}/per_sigma_l{str(args.adv)}.pdf") 137 | 138 | # plot top certified accuracy per epsilon, per type of noise 139 | grouped = df >> mask(X.noise != "Clean") \ 140 | >> group_by(X.eps, X.noise) \ 141 | >> arrange(X.top_1_acc_cert, ascending=False) \ 142 | >> summarize(top_1_acc_cert=first(X.top_1_acc_cert), 143 | noise=first(X.noise)) 144 | 145 | plt.figure(figsize=(3.0, 2.8)) 146 | sns.lineplot(x="eps", y="top_1_acc_cert", data=grouped, hue="noise", style="noise") 147 | plt.ylim((0, 1)) 148 | plt.xlabel(f"$\\ell_{str(args.adv)}$ radius") 149 |
plt.ylabel("Top-1 certified accuracy") 150 | plt.tight_layout() 151 | plt.savefig(f"{args.dir}/certified_accuracies_l{str(args.adv)}.pdf", bbox_inches="tight") 152 | 153 | if args.show: 154 | plt.show() 155 | 156 | -------------------------------------------------------------------------------- /scripts/analyze_verify.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import os 4 | import pickle 5 | import matplotlib as mpl 6 | import seaborn as sns 7 | from argparse import ArgumentParser 8 | from collections import defaultdict 9 | from dfply import * 10 | from matplotlib import pyplot as plt 11 | 12 | 13 | if __name__ == "__main__": 14 | 15 | argparser = ArgumentParser() 16 | argparser.add_argument("--experiment-name", default="cifar_uniform_05", type=str) 17 | argparser.add_argument("--dir", default="./ckpts", type=str) 18 | args = argparser.parse_args() 19 | 20 | sns.set_style("white") 21 | sns.set_palette("husl") 22 | 23 | df = defaultdict(list) 24 | eps_range = (3.0, 2.0, 1.0, 0.5, 0.25) 25 | 26 | save_path = f"{args.dir}/{args.experiment_name}" 27 | experiment_args = pickle.load(open(f"{args.dir}/{args.experiment_name}/args.pkl", "rb")) 28 | results = {} 29 | 30 | for k in ["preds_smooth", "radius_smooth", "labels"] + \ 31 | [f"preds_adv_{eps}" for eps in eps_range]: 32 | results[k] = np.load(f"{save_path}/{k}.npy") 33 | 34 | top_1_preds_smooth = np.argmax(results["preds_smooth"], axis=1) 35 | 36 | for eps in eps_range: 37 | 38 | top_1_preds_adv = np.argmax(results[f"preds_adv_{eps}"], axis=1) 39 | top_1_acc_cert = ((results["radius_smooth"] >= eps) & \ 40 | (top_1_preds_smooth == results["labels"])).mean() 41 | top_1_acc_adv = (top_1_preds_adv == results["labels"]).mean() 42 | df["eps"].append(eps) 43 | df["top_1_acc_cert"].append(top_1_acc_cert) 44 | df["top_1_acc_adv"].append(top_1_acc_adv) 45 | 46 | # breakpoint()  # debug leftover; uncomment to inspect the results interactively 47 | df = pd.DataFrame(df) >> gather("type", "top_1_acc",
["top_1_acc_cert", "top_1_acc_adv"]) 48 | 49 | plt.figure(figsize=(5, 3)) 50 | sns.lineplot(x="eps", y="top_1_acc", hue="type", data=df) 51 | plt.title(args.experiment_name) 52 | plt.ylim((0, 1)) 53 | plt.tight_layout() 54 | plt.show() 55 | 56 | -------------------------------------------------------------------------------- /scripts/check_noise_propagation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import seaborn as sns 4 | import torch 5 | import torch.nn.functional as F 6 | import os 7 | import itertools 8 | from argparse import ArgumentParser 9 | from collections import defaultdict 10 | from torch.utils.data import DataLoader, Subset 11 | from matplotlib import pyplot as plt 12 | from tqdm import tqdm 13 | from src.attacks import * 14 | from src.noises import * 15 | from src.models import * 16 | from src.datasets import get_dataset, get_dim 17 | 18 | 19 | def get_final_layer_mlp(model, x): 20 | out = model.model[0](x.reshape(x.shape[0], -1)) 21 | out = model.model[1](out) 22 | out = model.model[2](out) 23 | out = model.model[3](out) 24 | return out 25 | 26 | def get_final_layer(model, x): 27 | out = model.model.conv1(x) 28 | out = model.model.block1(out) 29 | out = model.model.block2(out) 30 | out = model.model.block3(out) 31 | out = model.model.relu(model.model.bn1(out)) 32 | out = F.avg_pool2d(out, 8) 33 | return out.view(-1, model.model.nChannels) 34 | 35 | if __name__ == "__main__": 36 | 37 | argparser = ArgumentParser() 38 | argparser.add_argument("--device", default="cuda:0", type=str) 39 | argparser.add_argument("--batch-size", default=4, type=int) 40 | argparser.add_argument("--num-workers", default=os.cpu_count(), type=int) 41 | argparser.add_argument("--sample-size", default=64, type=int) 42 | argparser.add_argument("--dataset", default="cifar", type=str) 43 | argparser.add_argument("--dataset-skip", default=20, type=int) 44 | argparser.add_argument("--model",
default="ResNet", type=str) 45 | argparser.add_argument("--dir", type=str, default="cifar_snapshots") 46 | argparser.add_argument("--load", action="store_true") 47 | args = argparser.parse_args() 48 | 49 | sns.set_style("whitegrid") 50 | sns.set_palette("husl") 51 | 52 | noises = ["Uniform", "Gaussian", "Laplace"] 53 | epochs = np.arange(1, 30, 1) 54 | 55 | test_dataset = get_dataset(args.dataset, "test") 56 | test_dataset = Subset(test_dataset, list(range(0, len(test_dataset), args.dataset_skip))) 57 | test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size, 58 | num_workers=args.num_workers) 59 | 60 | results = defaultdict(list) 61 | 62 | for noise_str, epoch in itertools.product(noises, epochs): 63 | 64 | if args.load: 65 | break 66 | 67 | sigma = 1.0 68 | 69 | save_path = f"{args.dir}/cifar_{noise_str}_{sigma}/{epoch-1}/model_ckpt.torch" 70 | model = eval(args.model)(dataset=args.dataset, device=args.device) 71 | model.load_state_dict(torch.load(save_path)) 72 | model.eval() 73 | 74 | noise = eval(noise_str)(sigma=sigma, device=args.device, p=2, dim=get_dim(args.dataset)) 75 | 76 | for i, (x, y) in tqdm(enumerate(test_loader), total=len(test_loader)): 77 | 78 | x, y = x.to(args.device), y.to(args.device) 79 | 80 | v = x.unsqueeze(1).expand((args.batch_size, args.sample_size, 3, 32, 32)) 81 | v = v.reshape((-1, 3, 32, 32)) 82 | noised = noise.sample(v) 83 | if args.model == "ResNet": 84 | rep_noisy = get_final_layer(model, noised) 85 | elif args.model == "MLP": 86 | rep_noisy = get_final_layer_mlp(model, noised) 87 | else: 88 | raise ValueError 89 | rep_noisy = rep_noisy.reshape(args.batch_size, -1, rep_noisy.shape[-1]) 90 | 91 | top_cats = model(noised).reshape(args.batch_size, -1, 10).argmax(dim=2).mode(dim=1) 92 | top_cats = top_cats.values 93 | 94 | l2 = torch.stack([F.pdist(rep_i, p=2) for rep_i in rep_noisy]).mean(dim=1).data 95 | l1 = torch.stack([F.pdist(rep_i, p=1) for rep_i in rep_noisy]).mean(dim=1).data 96 | linf = 
torch.stack([F.pdist(rep_i, p=float("inf")) for rep_i in rep_noisy]).mean(dim=1).data 97 | 98 | results["acc"] += (y == top_cats).float().cpu().numpy().tolist() 99 | results["l1"] += l1.cpu().numpy().tolist() 100 | results["l2"] += l2.cpu().numpy().tolist() 101 | results["linf"] += linf.cpu().numpy().tolist() 102 | results["noise"] += args.batch_size * [noise_str] 103 | results["epoch"] += args.batch_size * [epoch] 104 | 105 | if args.load: 106 | results = pd.read_csv(f"{args.dir}/snapshots.csv") 107 | else: 108 | results = pd.DataFrame(results) 109 | results.to_csv(f"{args.dir}/snapshots.csv") 110 | 111 | plt.figure(figsize=(10, 6)) 112 | plt.subplot(2, 2, 1) 113 | sns.lineplot(x="epoch", y="l2", hue="noise", data=results) 114 | plt.xlabel("Epoch") 115 | plt.ylabel("L2") 116 | plt.legend() 117 | plt.subplot(2, 2, 2) 118 | sns.lineplot(x="epoch", y="l1", hue="noise", data=results) 119 | plt.xlabel("Epoch") 120 | plt.ylabel("L1") 121 | plt.legend() 122 | plt.subplot(2, 2, 3) 123 | sns.lineplot(x="epoch", y="linf", hue="noise", data=results) 124 | plt.xlabel("Epoch") 125 | plt.ylabel("Linf") 126 | plt.legend() 127 | plt.subplot(2, 2, 4) 128 | sns.lineplot(x="epoch", y="acc", hue="noise", data=results) 129 | plt.xlabel("Epoch") 130 | plt.ylabel("Acc") 131 | plt.ylim((0, 1)) 132 | plt.legend() 133 | plt.show() 134 | 135 | -------------------------------------------------------------------------------- /scripts/plot_losses.py: -------------------------------------------------------------------------------- 1 | # 2 | # Plot the training loss for each model trained. 
3 | # 4 | import numpy as np 5 | import pandas as pd 6 | import os 7 | import pickle 8 | import matplotlib as mpl 9 | import seaborn as sns 10 | from argparse import ArgumentParser 11 | from dfply import * 12 | from matplotlib import pyplot as plt 13 | from src.noises import * 14 | from src.datasets import get_dim 15 | 16 | 17 | if __name__ == "__main__": 18 | 19 | argparser = ArgumentParser() 20 | argparser.add_argument("--dir", default="./ckpts", type=str) 21 | args = argparser.parse_args() 22 | 23 | dataset = args.dir.split("_")[0] 24 | experiment_names = list(filter(lambda s: os.path.isdir(args.dir + "/" + s), 25 | os.listdir(args.dir))) 26 | 27 | sns.set_style("white") 28 | sns.set_palette("husl") 29 | 30 | losses_df = pd.DataFrame({"noise": [], "sigma": [], "losses_train": [], "iter": []}) 31 | 32 | for experiment_name in experiment_names: 33 | 34 | save_path = f"{args.dir}/{experiment_name}" 35 | experiment_args = pickle.load(open(f"{args.dir}/{experiment_name}/args.pkl", "rb")) 36 | results = {} 37 | 38 | for k in ("losses_train",): 39 | results[k] = np.load(f"{save_path}/{k}.npy") 40 | 41 | noise = parse_noise_from_args(experiment_args, device="cpu", 42 | dim=get_dim(experiment_args.dataset)) 43 | 44 | losses_df >>= bind_rows(pd.DataFrame({ 45 | "experiment_name": experiment_name, 46 | "noise": noise.plotstr(), 47 | "sigma": experiment_args.sigma, 48 | "losses": results["losses_train"], 49 | "iter": np.arange(len(results["losses_train"]))})) 50 | 51 | # show training curves 52 | sns.relplot(x="iter", y="losses", hue="noise", data=losses_df, col="sigma", 53 | col_wrap=2, kind="line", height=1.5, aspect=3.5, alpha=0.5) 54 | plt.tight_layout() 55 | plt.show() 56 | 57 | -------------------------------------------------------------------------------- /scripts/plot_preds_distn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy as sp 3 | import scipy.stats 4 | import pandas as pd 5 | import 
os 6 | import pickle 7 | import matplotlib as mpl 8 | import seaborn as sns 9 | from argparse import ArgumentParser 10 | from dfply import * 11 | from matplotlib import pyplot as plt 12 | from src.noises import parse_noise_from_args 13 | from src.datasets import get_dim 14 | 15 | if __name__ == "__main__": 16 | 17 | argparser = ArgumentParser() 18 | argparser.add_argument("--dir", default="./ckpts", type=str) 19 | argparser.add_argument("--target", default="prob_correct", type=str) 20 | argparser.add_argument("--adv", default=1, type=float) 21 | argparser.add_argument("--use-pdf", action="store_true") 22 | args = argparser.parse_args() 23 | 24 | dataset = args.dir.split("_")[0] 25 | experiment_names = list(filter(lambda x: x.startswith(dataset), os.listdir(args.dir))) 26 | 27 | sns.set_style("white") 28 | sns.set_palette("husl") 29 | 30 | losses_df = pd.DataFrame({"noise": [], "sigma": [], "losses_train": [], "iter": []}) 31 | 32 | for experiment_name in experiment_names: 33 | 34 | save_path = f"{args.dir}/{experiment_name}" 35 | experiment_args = pickle.load(open(f"{args.dir}/{experiment_name}/args.pkl", "rb")) 36 | results = {} 37 | 38 | noise = parse_noise_from_args(experiment_args, device="cpu", 39 | dim=get_dim(experiment_args.dataset)) 40 | 41 | for k in ("prob_lb", "preds", "labels", f"radius_l{str(args.adv)}"): 42 | results[k] = np.load(f"{save_path}/{k}.npy") 43 | 44 | p_correct = results["preds"][np.arange(len(results["preds"])), 45 | results["labels"].astype(int)] 46 | p_top = results["prob_lb"] 47 | 48 | if args.target == "prob_lower_bound": 49 | tgt = p_top 50 | axis = np.linspace(0, 1, 500) 51 | elif args.target == "prob_correct": 52 | tgt = p_correct 53 | axis = np.linspace(0, 1, 500) 54 | elif args.target == "radius": 55 | tgt = results[f"radius_l{str(args.adv)}"] 56 | tgt = tgt[~np.isnan(tgt)] 57 | axis = np.linspace(0, 4.0, 500) 58 | else: 59 | raise ValueError 60 | 61 | if args.use_pdf: 62 | cdf = sp.stats.gaussian_kde(tgt)(axis) 63 | else: 64 | 
cdf = (tgt < axis[:, np.newaxis]).mean(axis=1) 65 | 66 | losses_df >>= bind_rows(pd.DataFrame({ 67 | "experiment_name": experiment_name, 68 | "noise": noise.plotstr(), 69 | "sigma": experiment_args.sigma, 70 | "cdf": cdf, 71 | "axis": axis})) 72 | 73 | # plot the empirical CDF (or, with --use-pdf, a KDE density estimate) of the target 74 | sns.relplot(x="axis", y="cdf", hue="noise", data=losses_df, col="sigma", 75 | col_wrap=2, kind="line", height=1.5, aspect=2.5, alpha=0.5) 76 | plt.show() 77 | 78 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonyduan/rs4a/a02831806ee4a117ef82ce5738fe8c5e8f180b06/src/__init__.py -------------------------------------------------------------------------------- /src/attacks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import grad 4 | from src.smooth import * 5 | 6 | 7 | def project_onto_ball(x, eps, p="inf"): 8 | """ 9 | Note that projection onto inf-norm and 2-norm take O(d) time, and projection onto 1-norm 10 | takes O(d log d) using the sorting-based algorithm given in [Duchi et al. 2008].
11 | """ 12 | original_shape = x.shape 13 | x = x.view(x.shape[0], -1) 14 | assert not torch.isnan(x).any() 15 | if p == "inf": 16 | x = x.clamp(-eps, eps) 17 | elif p == 2: 18 | x = x.renorm(p=2, dim=0, maxnorm=eps) 19 | elif p == 1: 20 | mask = (torch.norm(x, p=1, dim=1) < eps).float().unsqueeze(1) 21 | mu, _ = torch.sort(torch.abs(x), dim=1, descending=True) 22 | cumsum = torch.cumsum(mu, dim=1) 23 | arange = torch.arange(1, x.shape[1] + 1, device=x.device) 24 | rho, _ = torch.max((mu * arange > (cumsum - eps)) * arange, dim=1) 25 | theta = (cumsum[torch.arange(x.shape[0]), rho.cpu() - 1] - eps) / rho 26 | proj = (torch.abs(x) - theta.unsqueeze(1)).clamp(min=0) 27 | x = mask * x + (1 - mask) * proj * torch.sign(x) 28 | else: 29 | raise ValueError("Can only project onto 1,2,inf norm balls.") 30 | return x.view(original_shape) 31 | 32 | def pgd_attack(model, x, y, eps, steps=20, adv="inf", clamp=(0, 1)): 33 | """ 34 | Attack a model with PGD. 35 | """ 36 | step_size = 2 * eps / steps 37 | x.requires_grad = True 38 | x_orig = x.clone().detach() 39 | 40 | for _ in range(steps): 41 | loss = model.loss(x, y).mean() 42 | grads = grad(loss, x)[0].reshape(x.shape[0], -1) 43 | if adv == 1: 44 | keep_vals = torch.kthvalue(grads.abs(), k=grads.shape[1] * 15 // 16, dim=1).values 45 | grads[torch.abs(grads) < keep_vals.unsqueeze(1)] = 0 46 | grads = torch.sign(grads) 47 | grads_norm = torch.norm(grads, dim=1, p=1) 48 | grads = grads / (grads_norm.unsqueeze(1) + 1e-8) 49 | elif adv == 2: 50 | grads_norm = torch.norm(grads, dim=1, p=2) 51 | grads = grads / (grads_norm.unsqueeze(1) + 1e-8) 52 | elif adv == "inf": 53 | grads = torch.sign(grads) 54 | else: 55 | raise ValueError 56 | diff = x + step_size * grads.reshape(x.shape) - x_orig 57 | diff = project_onto_ball(diff, eps, adv) 58 | x = (x_orig + diff).clamp(*clamp) 59 | 60 | loss = model.loss(x, y).mean() 61 | x = x.detach() 62 | x.requires_grad = False 63 | return x, loss 64 | 65 | def pgd_attack_smooth(model, x, y, eps, 
noise, sample_size, steps=20, adv="inf", clamp=(0, 1)): 66 | """ 67 | Attack a smoothed model with PGD. 68 | """ 69 | step_size = 2 * eps / steps 70 | x.requires_grad = True 71 | x_orig = x.clone().detach() 72 | rng = torch.cuda.get_rng_state_all() 73 | 74 | for _ in range(steps): 75 | torch.cuda.set_rng_state_all(rng) 76 | forecast = smooth_predict_soft(model, x, noise, sample_size) 77 | loss = -forecast.log_prob(y).mean() 78 | grads = grad(loss, x)[0].reshape(x.shape[0], -1) 79 | if adv == 1: 80 | keep_vals = torch.kthvalue(grads.abs(), k=grads.shape[1] * 15 // 16, dim=1).values 81 | grads[torch.abs(grads) < keep_vals.unsqueeze(1)] = 0 82 | grads = torch.sign(grads) 83 | grads_norm = torch.norm(grads, dim=1, p=1) 84 | grads = grads / (grads_norm.unsqueeze(1) + 1e-8) 85 | elif adv == 2: 86 | grads_norm = torch.norm(grads, dim=1, p=2) 87 | grads = grads / (grads_norm.unsqueeze(1) + 1e-8) 88 | elif adv == "inf": 89 | grads = torch.sign(grads) 90 | else: 91 | raise ValueError 92 | diff = x + step_size * grads.reshape(x.shape) - x_orig 93 | diff = project_onto_ball(diff, eps, adv) 94 | x = (x_orig + diff).clamp(*clamp) 95 | # forecast = smooth_predict_hard(model, x, noise, sample_size).probs 96 | # print(_, (torch.argmax(forecast, dim=1) == y).sum() / float(x.shape[0]), 97 | # diff.reshape(x.shape[0], -1).norm(dim=1, p=1).mean(), 98 | # diff.reshape(x.shape[0], -1).norm(dim=1, p=2).mean()) 99 | 100 | torch.cuda.set_rng_state_all(rng) 101 | forecast = smooth_predict_soft(model, x, noise, sample_size) 102 | loss = -forecast.log_prob(y).mean() 103 | 104 | x = x.detach() 105 | x.requires_grad = False 106 | return x, loss 107 | 108 | -------------------------------------------------------------------------------- /src/datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | from torchvision import datasets, transforms 4 | from torch.utils.data import ConcatDataset 5 | from src.lib.zipdata import 
ZipData 6 | from src.lib.ds_imagenet import DSImageNet 7 | from src.lib.cifar10_selftrained import CIFAR10SelfTrained 8 | 9 | 10 | def get_dim(name): 11 | if name == "cifar": 12 | return 3 * 32 * 32 13 | if name == "svhn": 14 | return 3 * 32 * 32 15 | if name == "ds-imagenet": 16 | return 3 * 32 * 32 17 | if name == "mnist": 18 | return 28 * 28 19 | if name == "imagenet": 20 | return 3 * 224 * 224 21 | if name == "fashion": 22 | return 28 * 28 23 | if name == "cifar10selftrained": 24 | return 3 * 32 * 32 25 | 26 | def get_num_labels(name): 27 | return 1000 if "imagenet" in name else 10 28 | 29 | def get_normalization_shape(name): 30 | if name == "cifar": 31 | return (3, 1, 1) 32 | if name == "imagenet": 33 | return (3, 1, 1) 34 | if name == "ds-imagenet": 35 | return (3, 1, 1) 36 | if name == "svhn": 37 | return (3, 1, 1) 38 | if name == "mnist": 39 | return (1, 1, 1) 40 | if name == "fashion": 41 | return (1, 1, 1) 42 | if name == "cifar10selftrained": 43 | return (3, 1, 1) 44 | 45 | def get_normalization_stats(name): 46 | if name == "cifar" or name == "cifar10selftrained": 47 | return {"mu": [0.4914, 0.4822, 0.4465], "sigma": [0.2023, 0.1994, 0.2010]} 48 | if name == "imagenet" or name == "ds-imagenet": 49 | return {"mu": [0.485, 0.456, 0.406], "sigma": [0.229, 0.224, 0.225]} 50 | if name == "svhn": 51 | return {"mu": [0.436, 0.442, 0.471], "sigma": [0.197, 0.200, 0.196]} 52 | if name == "mnist": 53 | return {"mu": [0.1307,], "sigma": [0.3081,]} 54 | if name == "fashion": 55 | return {"mu": [0.2849,], "sigma": [0.3516,]} 56 | 57 | def get_dataset(name, split): 58 | 59 | if name == "cifar" and split == "train": 60 | return datasets.CIFAR10("./data/cifar_10", train=True, download=True, 61 | transform=transforms.Compose([transforms.RandomCrop(32, padding=4), 62 | transforms.RandomHorizontalFlip(), 63 | transforms.ToTensor()])) 64 | if name == "cifar" and split == "test": 65 | return datasets.CIFAR10("./data/cifar_10", train=False, download=True, 66 | 
transform=transforms.ToTensor()) 67 | 68 | if name == "imagenet" and split == "train": 69 | return ZipData("/mnt/bucket/imagenet/train.zip", 70 | "/mnt/bucket/imagenet/train_map.txt", 71 | transforms.Compose([transforms.RandomResizedCrop(224), 72 | transforms.RandomHorizontalFlip(), 73 | transforms.ToTensor()])) 74 | 75 | if name == "imagenet" and split == "test": 76 | return ZipData("/mnt/bucket/imagenet/val.zip", 77 | "/mnt/bucket/imagenet/val_map.txt", 78 | transforms.Compose([transforms.Resize(256), 79 | transforms.CenterCrop(224), 80 | transforms.ToTensor()])) 81 | 82 | if name == "ds-imagenet" and split == "train": 83 | return DSImageNet("/mnt/bucket/downsampled_imagenet/", split="train", 84 | transform=transforms.Compose([transforms.RandomHorizontalFlip(), 85 | transforms.ToTensor()])) 86 | 87 | if name == "ds-imagenet" and split == "test": 88 | return DSImageNet("/mnt/bucket/downsampled_imagenet/", split="test", 89 | transform=transforms.ToTensor()) 90 | 91 | if name == "mnist": 92 | return datasets.MNIST("./data/mnist", train=(split == "train"), download=True, 93 | transform=transforms.ToTensor()) 94 | if name == "fashion": 95 | return datasets.FashionMNIST("./data/fashion", train=(split == "train"), download=True, 96 | transform=transforms.ToTensor()) 97 | if name == "svhn": 98 | return datasets.SVHN("./data/svhn", split=split, download=True, 99 | transform=transforms.ToTensor()) 100 | 101 | if name == "cifar10selftrained" and split == "train": 102 | return ConcatDataset([CIFAR10SelfTrained("/mnt/bucket/cifar10_selftrained/ti_500K_pseudo_labeled.pickle", 103 | transform=transforms.Compose([transforms.RandomHorizontalFlip(), 104 | transforms.ToTensor()])), 105 | datasets.CIFAR10("./data/cifar_10", train=True, download=True, 106 | transform=transforms.Compose([transforms.RandomCrop(32, padding=4), 107 | transforms.RandomHorizontalFlip(), 108 | transforms.ToTensor()])), 109 | ]) 110 | 111 | raise ValueError 112 | 
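The `get_normalization_stats` and `get_normalization_shape` helpers in `datasets.py` above return one mean and one standard deviation per channel, plus a `(C, 1, 1)` shape. The code that consumes them is not part of this excerpt, so the following is only a sketch under the assumption that they implement the usual per-channel standardization `(x[c, h, w] - mu[c]) / sigma[c]`, with the `(C, 1, 1)` shape letting each channel's statistics broadcast over height and width. It uses plain nested lists so it runs without PyTorch; `normalize_chw` is a hypothetical helper, not a function in this repository.

```python
# Hypothetical helper (not in the repo): per-channel standardization of a
# C x H x W image stored as nested lists. This mirrors what broadcasting a
# (C, 1, 1) mean/std tensor over the image would compute.
CIFAR_STATS = {"mu": [0.4914, 0.4822, 0.4465], "sigma": [0.2023, 0.1994, 0.2010]}


def normalize_chw(image, mu, sigma):
    """Return (image[c][h][w] - mu[c]) / sigma[c] for every pixel."""
    if not (len(image) == len(mu) == len(sigma)):
        raise ValueError("need one mean and one std per channel")
    return [
        [[(px - m) / s for px in row] for row in channel]
        for channel, m, s in zip(image, mu, sigma)
    ]


if __name__ == "__main__":
    # A 3 x 2 x 2 "image" whose pixels equal their channel mean maps to zeros.
    img = [[[m] * 2 for _ in range(2)] for m in CIFAR_STATS["mu"]]
    out = normalize_chw(img, CIFAR_STATS["mu"], CIFAR_STATS["sigma"])
    print(out[0])  # [[0.0, 0.0], [0.0, 0.0]]
```

With PyTorch tensors the same effect would come from reshaping `mu` and `sigma` to the `(C, 1, 1)` shape and relying on broadcasting; the nested-list version only makes the per-channel loop explicit.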
-------------------------------------------------------------------------------- /src/lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonyduan/rs4a/a02831806ee4a117ef82ce5738fe8c5e8f180b06/src/lib/__init__.py -------------------------------------------------------------------------------- /src/lib/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class AlexNet(nn.Module): 7 | 8 | def __init__(self, num_classes, drop_rate=0.5): 9 | super().__init__() 10 | self.features = nn.Sequential( 11 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), 12 | nn.ReLU(), 13 | nn.MaxPool2d(kernel_size=2), 14 | nn.Conv2d(64, 192, kernel_size=3, padding=1), 15 | nn.ReLU(), 16 | nn.MaxPool2d(kernel_size=2), 17 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 18 | nn.ReLU(), 19 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 20 | nn.ReLU(), 21 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 22 | nn.ReLU(), 23 | nn.MaxPool2d(kernel_size=2), 24 | ) 25 | self.classifier = nn.Sequential( 26 | nn.Dropout(p=drop_rate), 27 | nn.Linear(256 * 2 * 2, 4096), 28 | nn.ReLU(), 29 | nn.Dropout(p=drop_rate), 30 | nn.Linear(4096, 4096), 31 | nn.ReLU(), 32 | nn.Linear(4096, num_classes), 33 | ) 34 | 35 | def forward(self, x): 36 | x = self.features(x) 37 | x = x.view(x.size(0), 256 * 2 * 2) 38 | x = self.classifier(x) 39 | return x 40 | 41 | -------------------------------------------------------------------------------- /src/lib/cifar10_selftrained.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import os 4 | import torch 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class CIFAR10SelfTrained(Dataset): 10 | 11 | def __init__(self, path, transform=None, 
target_transform=None): 12 | with open(path, "rb") as fd: 13 | self.dataset = pickle.load(fd) 14 | self.transform = transform 15 | self.target_transform = target_transform 16 | 17 | def __getitem__(self, index): 18 | 19 | img = self.dataset["data"][index] 20 | target = self.dataset["extrapolated_targets"][index] 21 | 22 | # doing this so that it is consistent with all other datasets 23 | # to return a PIL Image 24 | img = Image.fromarray(img) 25 | 26 | if self.transform is not None: 27 | img = self.transform(img) 28 | 29 | if self.target_transform is not None: 30 | target = self.target_transform(target) 31 | 32 | return img, target 33 | 34 | def __len__(self): 35 | return len(self.dataset["extrapolated_targets"]) 36 | 37 | -------------------------------------------------------------------------------- /src/lib/classic_resnet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Properly implemented ResNet-s for CIFAR10 as described in paper [1]. 3 | 4 | The implementation and structure of this file are heavily influenced by [2] 5 | which is implemented for ImageNet and doesn't have option A for identity. 6 | Moreover, most of the implementations on the web are copy-pasted from 7 | torchvision's resnet and have the wrong number of params. 8 | 9 | Proper ResNet-s for CIFAR10 (for fair comparison etc.) have the following 10 | number of layers and parameters: 11 | 12 | name | layers | params 13 | ResNet20 | 20 | 0.27M 14 | ResNet32 | 32 | 0.46M 15 | ResNet44 | 44 | 0.66M 16 | ResNet56 | 56 | 0.85M 17 | ResNet110 | 110 | 1.7M 18 | ResNet1202| 1202 | 19.4M 19 | 20 | which this implementation indeed has. 21 | 22 | References: 23 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 24 | Deep Residual Learning for Image Recognition.
arXiv:1512.03385 25 | [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 26 | 27 | If you use this implementation in your work, please don't forget to mention the 28 | author, Yerlan Idelbayev. 29 | ''' 30 | import torch 31 | import torch.nn as nn 32 | import torch.nn.functional as F 33 | import torch.nn.init as init 34 | 35 | from torch.autograd import Variable 36 | 37 | __all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202'] 38 | 39 | def _weights_init(m): 40 | classname = m.__class__.__name__ 41 | #print(classname) 42 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 43 | init.kaiming_normal_(m.weight) 44 | 45 | class LambdaLayer(nn.Module): 46 | def __init__(self, lambd): 47 | super(LambdaLayer, self).__init__() 48 | self.lambd = lambd 49 | 50 | def forward(self, x): 51 | return self.lambd(x) 52 | 53 | 54 | class BasicBlock(nn.Module): 55 | expansion = 1 56 | 57 | def __init__(self, in_planes, planes, stride=1, option='A'): 58 | super(BasicBlock, self).__init__() 59 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 60 | self.bn1 = nn.BatchNorm2d(planes) 61 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 62 | self.bn2 = nn.BatchNorm2d(planes) 63 | 64 | self.shortcut = nn.Sequential() 65 | if stride != 1 or in_planes != planes: 66 | if option == 'A': 67 | """ 68 | For CIFAR10, the ResNet paper uses option A.
69 | """ 70 | self.shortcut = LambdaLayer(lambda x: 71 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 72 | elif option == 'B': 73 | self.shortcut = nn.Sequential( 74 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 75 | nn.BatchNorm2d(self.expansion * planes) 76 | ) 77 | 78 | def forward(self, x): 79 | out = F.relu(self.bn1(self.conv1(x))) 80 | out = self.bn2(self.conv2(out)) 81 | out += self.shortcut(x) 82 | out = F.relu(out) 83 | return out 84 | 85 | 86 | class ResNet(nn.Module): 87 | def __init__(self, block, num_blocks, num_classes=10): 88 | super(ResNet, self).__init__() 89 | self.in_planes = 16 90 | 91 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 92 | self.bn1 = nn.BatchNorm2d(16) 93 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 94 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 95 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 96 | self.linear = nn.Linear(64, num_classes) 97 | 98 | self.apply(_weights_init) 99 | 100 | def _make_layer(self, block, planes, num_blocks, stride): 101 | strides = [stride] + [1]*(num_blocks-1) 102 | layers = [] 103 | for stride in strides: 104 | layers.append(block(self.in_planes, planes, stride)) 105 | self.in_planes = planes * block.expansion 106 | 107 | return nn.Sequential(*layers) 108 | 109 | def forward(self, x): 110 | out = F.relu(self.bn1(self.conv1(x))) 111 | out = self.layer1(out) 112 | out = self.layer2(out) 113 | out = self.layer3(out) 114 | out = F.avg_pool2d(out, out.size()[3]) 115 | out = out.view(out.size(0), -1) 116 | out = self.linear(out) 117 | return out 118 | 119 | 120 | def resnet20(): 121 | return ResNet(BasicBlock, [3, 3, 3]) 122 | 123 | 124 | def resnet32(): 125 | return ResNet(BasicBlock, [5, 5, 5]) 126 | 127 | 128 | def resnet44(): 129 | return ResNet(BasicBlock, [7, 7, 7]) 130 | 131 | 132 | def resnet56(): 133 | return 
ResNet(BasicBlock, [9, 9, 9]) 134 | 135 | 136 | def resnet110(): 137 | return ResNet(BasicBlock, [18, 18, 18]) 138 | 139 | 140 | def resnet1202(): 141 | return ResNet(BasicBlock, [200, 200, 200]) 142 | 143 | 144 | def test(net): 145 | import numpy as np 146 | total_params = 0 147 | 148 | for x in filter(lambda p: p.requires_grad, net.parameters()): 149 | total_params += np.prod(x.data.numpy().shape) 150 | print("Total number of params", total_params) 151 | print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters())))) 152 | 153 | 154 | if __name__ == "__main__": 155 | for net_name in __all__: 156 | if net_name.startswith('resnet'): 157 | print(net_name) 158 | test(globals()[net_name]()) 159 | print() 160 | 161 | -------------------------------------------------------------------------------- /src/lib/ds_imagenet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import os 4 | import torch 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | from torchvision.datasets.utils import check_integrity 8 | 9 | # from https://github.com/hendrycks/pre-training 10 | class DSImageNet(Dataset): 11 | """`Downsampled ImageNet `_ Datasets. 12 | Args: 13 | root (string): Root directory of dataset where directory 14 | ``ImagenetXX_train`` exists. 15 | split (string, optional): If ``"train"``, creates dataset from training set, otherwise 16 | creates from test set. 17 | transform (callable, optional): A function/transform that takes in a PIL image 18 | and returns a transformed version. E.g., ``transforms.RandomCrop`` 19 | target_transform (callable, optional): A function/transform that takes in the 20 | target and transforms it. 21 |
""" 22 | train_list = [ 23 | ['train_data_batch_1', ''], 24 | ['train_data_batch_2', ''], 25 | ['train_data_batch_3', ''], 26 | ['train_data_batch_4', ''], 27 | ['train_data_batch_5', ''], 28 | ['train_data_batch_6', ''], 29 | ['train_data_batch_7', ''], 30 | ['train_data_batch_8', ''], 31 | ['train_data_batch_9', ''], 32 | ['train_data_batch_10', ''] 33 | ] 34 | 35 | test_list = [ 36 | ['val_data', ''], 37 | ] 38 | 39 | def __init__(self, root, split="train", transform=None, target_transform=None): 40 | self.root = os.path.expanduser(root) 41 | self.transform = transform 42 | self.target_transform = target_transform 43 | self.split = split # training set or test set 44 | self.base_folder = "" 45 | 46 | # if not self._check_integrity(): 47 | # raise RuntimeError('Dataset not found or corrupted.') # TODO 48 | 49 | # now load the pickled numpy arrays 50 | if split == "train": 51 | self.train_data = [] 52 | self.train_labels = [] 53 | for fentry in self.train_list: 54 | f = fentry[0] 55 | file = os.path.join(self.root, f) 56 | with open(file, 'rb') as fo: 57 | entry = pickle.load(fo) 58 | self.train_data.append(entry['data']) 59 | self.train_labels += [label - 1 for label in entry['labels']] 60 | self.mean = entry['mean'] 61 | 62 | self.train_data = np.concatenate(self.train_data) 63 | self.train_data = self.train_data.reshape((self.train_data.shape[0], 3, 32, 32)) 64 | self.train_data = self.train_data.transpose((0, 2, 3, 1)) # convert to HWC 65 | else: 66 | f = self.test_list[0][0] 67 | file = os.path.join(self.root, f) 68 | fo = open(file, 'rb') 69 | entry = pickle.load(fo) 70 | self.test_data = entry['data'] 71 | self.test_labels = [label - 1 for label in entry['labels']] 72 | fo.close() 73 | self.test_data = self.test_data.reshape((self.test_data.shape[0], 3, 32, 32)) 74 | self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC 75 | 76 | def __getitem__(self, index): 77 | """ 78 | Args: 79 | index (int): Index 80 | Returns: 81 | tuple:
(image, target) where target is index of the target class. 82 | """ 83 | if self.split == "train": 84 | img, target = self.train_data[index], self.train_labels[index] 85 | else: 86 | img, target = self.test_data[index], self.test_labels[index] 87 | 88 | # doing this so that it is consistent with all other datasets 89 | # to return a PIL Image 90 | img = Image.fromarray(img) 91 | 92 | if self.transform is not None: 93 | img = self.transform(img) 94 | 95 | if self.target_transform is not None: 96 | target = self.target_transform(target) 97 | 98 | return img, target 99 | 100 | def __len__(self): 101 | if self.split == "train": 102 | return len(self.train_data) 103 | else: 104 | return len(self.test_data) 105 | 106 | def _check_integrity(self): 107 | root = self.root 108 | for fentry in (self.train_list + self.test_list): 109 | filename, md5 = fentry[0], fentry[1] 110 | fpath = os.path.join(root, self.base_folder, filename) 111 | if not check_integrity(fpath, md5): 112 | return False 113 | return True 114 | 115 | 116 | -------------------------------------------------------------------------------- /src/lib/lenet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class LeNet(nn.Module): 6 | 7 | def __init__(self, in_channels, out_dim): 8 | super(LeNet, self).__init__() 9 | self.conv1 = nn.Conv2d(in_channels, 6, 5, padding=2) 10 | self.conv2 = nn.Conv2d(6, 16, 5) 11 | self.fc1 = nn.Linear(16*5*5, 120) 12 | self.fc2 = nn.Linear(120, 84) 13 | self.fc3 = nn.Linear(84, out_dim) 14 | 15 | def forward(self, x): 16 | out = F.relu(self.conv1(x)) 17 | out = F.max_pool2d(out, 2) 18 | out = F.relu(self.conv2(out)) 19 | out = F.max_pool2d(out, 2) 20 | out = out.view(out.size(0), -1) 21 | out = F.relu(self.fc1(out)) 22 | out = F.relu(self.fc2(out)) 23 | out = self.fc3(out) 24 | return out 25 | 26 | 
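The `fc1` input size in `LeNet` follows from standard convolution arithmetic, `out = (n + 2*padding - kernel) // stride + 1`. The sketch below traces a square input through the conv/pool stack; `conv_out` and `lenet_flat_features` are hypothetical helpers for illustration, not part of the repository, and they assume the layer settings in `lenet.py` (5x5 kernels, `padding=2` on `conv1` only, 2x2 max pooling with stride 2, 16 output channels).

```python
def conv_out(n, kernel, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer (no dilation, floor mode)."""
    return (n + 2 * padding - kernel) // stride + 1


def lenet_flat_features(n, channels=16):
    """Flattened feature count after LeNet's conv/pool stack for an n x n input."""
    n = conv_out(n, kernel=5, padding=2)  # conv1: 5x5 kernel, padding=2
    n = conv_out(n, kernel=2, stride=2)   # F.max_pool2d(out, 2)
    n = conv_out(n, kernel=5)             # conv2: 5x5 kernel, no padding
    n = conv_out(n, kernel=2, stride=2)   # F.max_pool2d(out, 2)
    return channels * n * n


print(lenet_flat_features(28))  # 400 == 16*5*5
print(lenet_flat_features(32))  # 576 == 16*6*6
```

So `nn.Linear(16*5*5, 120)` matches 28x28 inputs; with 32x32 inputs (e.g. CIFAR-10) the flattened size would be 576 instead.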
--------------------------------------------------------------------------------
/src/lib/wide_resnet.py:
--------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):

    def __init__(self, in_planes, out_planes, stride, drop_rate=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.drop_rate = drop_rate
        self.equalInOut = (in_planes == out_planes)
        # 1x1 projection shortcut, needed only when the channel count changes
        self.convShortcut = None if self.equalInOut else \
            nn.Conv2d(in_planes, out_planes, kernel_size=1,
                      stride=stride, padding=0, bias=False)

    def forward(self, x):
        if not self.equalInOut:
            x = self.relu1(self.bn1(x))
        else:
            out = self.relu1(self.bn1(x))
        if self.equalInOut:
            out = self.relu2(self.bn2(self.conv1(out)))
        else:
            out = self.relu2(self.bn2(self.conv1(x)))
        if self.drop_rate > 0:
            out = F.dropout(out, p=self.drop_rate, training=self.training)
        out = self.conv2(out)
        if not self.equalInOut:
            return torch.add(self.convShortcut(x), out)
        else:
            return torch.add(x, out)


class NetworkBlock(nn.Module):

    def __init__(self, nb_layers, in_planes, out_planes, stride, drop_rate=0.0):
        super().__init__()
        layers = []
        for i in range(nb_layers):
            # only the first block changes the channel count and the stride
            layers.append(BasicBlock(in_planes if i == 0 else out_planes, out_planes,
                                     stride if i == 0 else 1, drop_rate))
        self.layer = nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)


class WideResNet(nn.Module):

    def __init__(self, depth, num_classes, widen_factor=1, drop_rate=0.0):

        super(WideResNet, self).__init__()
        nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        assert (depth - 4) % 6 == 0, "depth must be of the form 6n + 4"
        n = (depth - 4) // 6  # number of BasicBlocks per NetworkBlock
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False)
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], 1, drop_rate)
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], 2, drop_rate)
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], 2, drop_rate)
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        out = F.avg_pool2d(out, 8)  # assumes 8x8 feature maps, i.e. 32x32 inputs
        out = out.view(-1, self.nChannels)
        return self.fc(out)

--------------------------------------------------------------------------------
/src/lib/zipdata.py:
--------------------------------------------------------------------------------
import multiprocessing
import os.path as op
from threading import local
from zipfile import ZipFile, BadZipFile

from PIL import Image
from io import BytesIO
import torch.utils.data as data

_VALID_IMAGE_TYPES = ['.jpg', '.jpeg', '.tiff', '.bmp', '.png']

class ZipData(data.Dataset):
    _IGNORE_ATTRS = {'_zip_file'}

    def __init__(self, path, map_file,
                 transform=None, target_transform=None,
                 extensions=None):
        self._path = path
        if not extensions:
            extensions = _VALID_IMAGE_TYPES
        self._zip_file = ZipFile(path)
        self.zip_dict = {}
        self.samples = []
        self.transform = transform
        self.target_transform = target_transform
        self.class_to_idx = {}
        with open(map_file, 'r') as f:
            for line in iter(f.readline, ""):
                line = line.strip()
                if not line:
                    continue
                cls_idx = [l for l in line.split('\t') if l]
                if not cls_idx:
                    continue
                assert len(cls_idx) >= 2, "invalid line: {}".format(line)
                idx = int(cls_idx[1])
                cls = cls_idx[0]
                del cls_idx
                at_idx = cls.find('@')
                assert at_idx >= 0, "invalid class: {}".format(cls)
                cls = cls[at_idx + 1:]
                if cls.startswith('/'):
                    # Python ZipFile expects no root
                    cls = cls[1:]
                assert cls, "invalid class in line {}".format(line)
                prev_idx = self.class_to_idx.get(cls)
                assert prev_idx is None or prev_idx == idx, "class: {} idx: {} previously had idx: {}".format(
                    cls, idx, prev_idx
                )
                self.class_to_idx[cls] = idx

        for fst in self._zip_file.infolist():
            fname = fst.filename
            target = self.class_to_idx.get(fname)
            if target is None:
                continue
            if fname.endswith('/') or fname.startswith('.') or fst.file_size == 0:
                continue
            ext = op.splitext(fname)[1].lower()
            if ext in extensions:
                self.samples.append((fname, target))
        assert len(self), "No images found in: {} with map: {}".format(self._path, map_file)

    def __repr__(self):
        return 'ZipData({}, size={})'.format(self._path, len(self))

    def __getstate__(self):
        # drop the open ZipFile handle so the dataset can be pickled for worker processes
        return {
            key: val if key not in self._IGNORE_ATTRS else None
            for key, val in self.__dict__.items()
        }

    def __getitem__(self, index):
        proc = multiprocessing.current_process()
        pid = proc.pid  # get pid of this process.
        # each process opens its own ZipFile handle
        if pid not in self.zip_dict:
            self.zip_dict[pid] = ZipFile(self._path)
        zip_file = self.zip_dict[pid]

        if index >= len(self) or index < 0:
            raise KeyError("{} is invalid".format(index))
        path, target = self.samples[index]
        try:
            sample = Image.open(BytesIO(zip_file.read(path))).convert('RGB')
        except BadZipFile:
            print("bad zip file")
            return None, None
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        return len(self.samples)

--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, Categorical, Bernoulli
from torchvision import models as base_models
from src.datasets import *
from src.lib.wide_resnet import WideResNet as WideResNetBase
from src.lib.alexnet import AlexNet as AlexNetBase
from src.lib.lenet import LeNet as LeNetBase
from src.lib.classic_resnet import resnet110 as classic_resnet110


class Forecaster(nn.Module):

    def __init__(self, dataset, device):
        super().__init__()
        self.device = device
        self.norm = NormalizeLayer(get_normalization_shape(dataset), device,
                                   **get_normalization_stats(dataset))

    def forward(self, x):
        raise NotImplementedError

    def forecast(self, theta):
        return Categorical(logits=theta)

    def loss(self, x, y):
        forecast = self.forecast(self.forward(x))
        return -forecast.log_prob(y)


class ResNet(Forecaster):

    def __init__(self, dataset, device):
        super().__init__(dataset, device)
        if dataset == "imagenet":
            self.model = nn.DataParallel(base_models.resnet50(num_classes=1000))
        else:
            self.model = nn.DataParallel(classic_resnet110())
        self.norm = nn.DataParallel(self.norm)
        self.norm = self.norm.to(device)
        self.model = self.model.to(device)

    def forward(self, x):
        x = self.norm(x)
        return self.model(x)


class WideResNet(Forecaster):

    def __init__(self, dataset, device):
        super().__init__(dataset, device)
        self.model = nn.DataParallel(WideResNetBase(depth=40, widen_factor=2,
                                                    num_classes=get_num_labels(dataset)))
        self.norm = nn.DataParallel(self.norm)
        self.norm = self.norm.to(device)
        self.model = self.model.to(device)

    def forward(self, x):
        x = self.norm(x)
        return self.model(x)


class LinearModel(Forecaster):

    def __init__(self, dataset, device):
        super().__init__(dataset, device)
        self.model = nn.Linear(get_dim(dataset), get_num_labels(dataset))
        self.model = self.model.to(device)

    def forward(self, x):
        x = self.norm(x).view(x.shape[0], -1)
        return self.model(x)


class AlexNet(Forecaster):

    def __init__(self, dataset, device):
        super().__init__(dataset, device)
        self.model = AlexNetBase(get_num_labels(dataset), drop_rate=0.5)
        self.model = self.model.to(device)

    def forward(self, x):
        x = self.norm(x)
        return self.model(x)


class LeNet(Forecaster):

    def __init__(self, dataset, device):
        super().__init__(dataset, device)
        self.model = LeNetBase(get_normalization_shape(dataset)[0], get_num_labels(dataset))
        self.model = self.model.to(device)

    def forward(self, x):
        x = self.norm(x)
        return self.model(x)


class MLP(Forecaster):

    def __init__(self, dataset, device):
        super().__init__(dataset, device)
        self.model = nn.Sequential(
            nn.Linear(get_dim(dataset), 2048),
            nn.ReLU(),
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Linear(512, get_num_labels(dataset)))
        self.model = self.model.to(device)

    def forward(self, x):
        x = self.norm(x).view(x.shape[0], -1)
        return self.model(x)


class NormalizeLayer(nn.Module):
    """
    Normalizes each index of the first non-batch axis with its own mean and standard deviation.

    Examples:
        (64, 3, 32, 32) [CIFAR] => normalizes each channel
        (64, 8) [UCI] => normalizes each feature
    """
    def __init__(self, dim, device, mu=None, sigma=None):
        super().__init__()
        self.dim = dim
        if mu and sigma:
            self.mu = nn.Parameter(torch.tensor(mu, device=device).reshape(dim), requires_grad=False)
            self.log_sig = nn.Parameter(torch.log(torch.tensor(sigma, device=device)).reshape(dim), requires_grad=False)
            self.initialized = True
        else:
            raise ValueError("NormalizeLayer requires both mu and sigma")

    def forward(self, x):
        if not self.initialized:
            self.initialize_parameters(x)
            self.initialized = True
        return (x - self.mu) / torch.exp(self.log_sig)

    def initialize_parameters(self, x):
        with torch.no_grad():
            mu = x.view(x.shape[0], x.shape[1], -1).mean((0, 2))
            std = x.view(x.shape[0], x.shape[1], -1).std((0, 2))
            self.mu.copy_(mu.data.view(self.dim))
            self.log_sig.copy_(torch.log(std).data.view(self.dim))

--------------------------------------------------------------------------------
/src/noises/__init__.py:
--------------------------------------------------------------------------------
from .noises import *


def parse_noise_from_args(args, device, dim):
    """
    Given a Namespace of arguments, returns the constructed noise object.
    """
    kwargs = {
        "sigma": args.sigma,
        "lambd": args.lambd,
        "k": args.k,
        "j": args.j,
        "a": args.a
    }
    # only forward the parameters that were actually set
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    return eval(args.noise)(device=device, dim=dim, **kwargs)

--------------------------------------------------------------------------------
/src/noises/test_noises.py:
--------------------------------------------------------------------------------
import unittest
import numpy as np
import tqdm
import torch
import src.noises.noises as noises

class TestSigma(unittest.TestCase):

    def test_sigma(self):
        rel_tol = 1e-2
        nsamples = int(1e4)
        dim = 3 * 32 * 32
        dev = 'cpu'
        configs = [
            dict(noise=noises.Uniform),
            dict(noise=noises.Gaussian),
            dict(noise=noises.Laplace),
            dict(noise=noises.UniformBall),
        ]
        for a in [3, 10, 100, 1000]:
            configs.append(
                dict(noise=noises.Pareto, a=a)
            )
        for k in [1, 2, 10, 50]:
            for j in [0, 1, 10, 100, 1000]:
                configs.append(
                    dict(noise=noises.ExpInf, k=k, j=j)
                )
        for k in [1, 2, 10, 20]:
            configs.append(
                dict(noise=noises.Exp1, k=k)
            )
        for a in [10, 100, 1000]:
            configs.append(
                dict(noise=noises.PowInf, a=a+dim)
            )
        for k in [1, 2, 10, 100]:
            for j in [0, 1, 10, 100, 1000]:
                configs.append(
                    dict(noise=noises.Exp2, k=k, j=j)
                )
        for k in [1, 2, 10, 100]:
            for a in [10, 100, 1000]:
                a = (dim + a) / k
                configs.append(
                    dict(noise=noises.Pow2, k=k, a=a)
                )
        for p in [0.2, 0.5, 1, 2, 4, 8]:
            configs.append(dict(noise=noises.Expp, p=p))
        for c in tqdm.tqdm(configs):
            c['device'] = dev
            c['dim'] = dim
            c['sigma'] = 1
            with self.subTest(config=dict(c)):
                noisecls = c.pop('noise')
                noise = noisecls(**c)
                samples = noise.sample(torch.zeros(nsamples, dim))
                self.assertEqual(samples.shape, torch.Size((nsamples, dim)))
                emp_sigma = samples.std()
                self.assertAlmostEqual(emp_sigma, noise.sigma,
                                       delta=rel_tol * emp_sigma)

class TestRadii(unittest.TestCase):

    def test_laplace_linf_radii(self):
        '''Test that the "approx" and "integrate" modes of linf certification
        for Laplace agree with each other.'''
        noise = noises.Laplace(3*32*32, sigma=1)
        cert1 = noise.certify_linf(torch.arange(0.5, 1, 0.01))
        cert2 = noise.certify_linf(torch.arange(0.5, 1, 0.01), 'integrate')
        self.assertTrue(np.allclose(cert1, cert2, rtol=1e-3))

    def test_exp2_l2_radii(self):
        r'''Test that for exp(-\|x\|_2), the differential and level set methods
        obtain similar robust radii.'''
        rs = torch.arange(0.5, 1, 0.01)
        with self.subTest(name='Exp2 test, k=1, j=0'):
            noise = noises.Exp2(3*32*32, sigma=1)
            cert1 = noise.certify_l2(rs)
            cert2 = noise.certify_l2_table(rs)
            self.assertTrue(np.allclose(cert1, cert2, rtol=1e-2))
        with self.subTest(name='Exp2 test, k=2, j=0'):
            noise = noises.Exp2(3*32*32, sigma=1, k=2)
            cert1 = noise.certify_l2(rs)
            cert2 = noise.certify_l2_table(rs)
            self.assertTrue(np.allclose(cert1, cert2, rtol=1e-3))

    def test_exp1_l1_radii(self):
        r'''Test that for exp(-\|x\|_1), the Laplace and differential-method
        table certifications match.'''
        rs = torch.arange(0.5, 1, 0.01)
        noise1 = noises.Laplace(3*32*32, sigma=1)
        noise2 = noises.Exp1(3*32*32, sigma=1, k=1)
        cert1 = noise1.certify_l1(rs)
        cert2 = noise2.certify_l1_table(rs)
        self.assertTrue(np.allclose(cert1, cert2, rtol=1e-3))

    def test_expinf_linf_radii(self):
        r'''Test that ExpInf linf radii for k=1 match the known symbolic radii.'''
        rs = torch.arange(0.5, 1, 0.01)
        noise = noises.ExpInf(3*32*32, sigma=1, k=1)
        cert1 = noise.certify_linf(rs)
        cert2 = noise.certify_linf_table(rs)
        self.assertTrue(np.allclose(cert1, cert2, rtol=1e-3))

    def test_expp_largep_radii(self):
        r'''Test that Expp with p=1 and p=2 recovers Laplace and Gaussian, respectively.'''
        rs = torch.arange(0.5, 1, 0.01)
        with self.subTest(name='large p: Expp(p=1) vs Laplace'):
            noisep = noises.Expp(3*32*32, sigma=1, p=1)
            noise1 = noises.Laplace(3*32*32, sigma=1)
            cert1 = noisep.certify_l1(rs)
            cert2 = noise1.certify_l1(rs)
            self.assertTrue(np.allclose(cert1, cert2, rtol=1e-3))

        with self.subTest(name='large p: Expp(p=2) vs Gaussian'):
            noisep = noises.Expp(3*32*32, sigma=1, p=2)
            noise2 = noises.Gaussian(3*32*32, sigma=1)
            cert1 = noisep.certify_l1(rs)
            cert2 = noise2.certify_l1(rs)
            self.assertTrue(np.allclose(cert1, cert2, rtol=1e-3))

    def test_expp_smallp_radii(self):
        r'''Test that the log convex* table radii for Expp recover the Laplace
        radii for p = 1 and the analytic expression for p = 0.5.'''
        rs = torch.arange(0.5, 1, 0.01)
        with self.subTest(name='small p: Expp(p=1) vs Laplace'):
            noisep = noises.Expp(3*32*32, sigma=1, p=1)
            noise1 = noises.Laplace(3*32*32, sigma=1)
            cert1 = noisep.certify_l1_smallp_table(rs)
            cert2 = noise1.certify_l1(rs)
            self.assertTrue(np.allclose(cert1, cert2, rtol=1e-3))

        with self.subTest(name='small p: Expp(p=1/2), table vs closed form'):
            noise = noises.Expp(3*32*32, sigma=1, p=0.5)
            # analytic expression
            cert1 = noise.certify_l1(rs)
            # table
            cert2 = noise.certify_l1_smallp_table(rs)
            self.assertTrue(np.allclose(cert1, cert2, rtol=1e-3))

    def test_table_load(self):
        dim = 3 * 32 * 32
        dev = 'cpu'
        configs = []
        for k in [2, 3, 4]:
            configs.append(
                dict(noise=noises.Exp1, k=k, adv=1)
            )
        for k in [2]:
            for j in [2048, 3064, 3068, 3071]:
                configs.append(
                    dict(noise=noises.Exp2, k=k, j=j, adv=2)
                )
        for k in [2]:
            for a in [1538, 1540, 1544]:
                configs.append(
                    dict(noise=noises.Pow2, k=k, a=a, adv=2)
                )
        for k in [1, 2, 4, 8]:
            configs.append(
                dict(noise=noises.ExpInf, k=k, j=0, adv=np.inf)
            )
        for a in [4, 16, 32, 128]:
            configs.append(
                dict(noise=noises.PowInf, a=dim+a, adv=np.inf)
            )
        for p in [0.2, 0.5]:
            configs.append(dict(noise=noises.Expp, p=p, adv=1))
        rhos = torch.arange(0.5, 1, 0.01)
        for c in tqdm.tqdm(configs):
            c['device'] = dev
            c['dim'] = dim
            c['sigma'] = 1
            with self.subTest(config=dict(c)):
                noisecls = c.pop('noise')
                adv = c.pop('adv')
                noise = noisecls(**c)
                noise.certify(rhos, adv=adv)

if __name__ == '__main__':
    unittest.main()
--------------------------------------------------------------------------------
/src/noises/utils.py:
--------------------------------------------------------------------------------
import math
from tqdm import tqdm

import numpy as np
import scipy as sp
import scipy.integrate  # sp.integrate.quad is used in diffmethod_table
import scipy.optimize   # sp.optimize.root_scalar is used in lvsetmethod_table
import scipy.special
import scipy.stats
import torch
from scipy.stats import beta, binom, gamma, norm, laplace
from torch.distributions import (Beta, Dirichlet, Gamma, Laplace, Normal,
                                 Pareto, Uniform)


def atanh(x):
    return 0.5 * np.log((1 + x) / (1 - x))

def relu(x):
    if isinstance(x, np.ndarray):
        # note: the third argument is `out`, so the array x is modified in place
        return np.maximum(x, 0, x)
    else:
        return max(x, 0)

def wfun(r, s, e, d):
    '''W function in the paper.
    Calculates the probability that a point sampled uniformly from the surface
    of a ball of radius `r` centered at the origin lies outside a ball of
    radius `s` whose center is at distance `e` from the origin.
    '''
    t = ((r+e)**2 - s**2)/(4*e*r)
    return beta((d-1)/2, (d-1)/2).cdf(t)

def plexp(z, mode='lowerbound'):
    '''Computes LambertW(e^z) numerically safely.
    For small values of z, we use `scipy.special.lambertw`.
    For large values of z (z > 500), we instead return the lower or upper
    bound below, selected by `mode`:

        z - log(z) < W(e^z) < z - log(z) - log(1 - log(z)/z).
    '''
    if np.isscalar(z):
        if z > 500:
            if mode == 'lowerbound':
                return z - np.log(z)
            elif mode == 'upperbound':
                return z - np.log(z) - np.log(1 - np.log(z) / z)
            else:
                raise ValueError(f'Unknown mode: {mode}')
        else:
            return sp.special.lambertw(np.exp(z))
    else:
        if mode == 'lowerbound':
            u = z - np.log(z)
        elif mode == 'upperbound':
            u = z - np.log(z) - np.log(1 - np.log(z) / z)
        else:
            raise ValueError(f'Unknown mode: {mode}')
        w = sp.special.lambertw(np.exp(z))
        w[z > 500] = u[z > 500]
        return w


def sample_linf_sphere(device, shape):
    '''Sample uniformly from the unit linf sphere, i.e. the surface of the
    cube [-1, 1]^d: draw a uniform point in the cube, then set one uniformly
    chosen coordinate to a random sign.
    Outputs:
        matrix of shape (shape[0], prod(shape[1:])) such that each row is a sample.
    '''
    noise = (2 * torch.rand(shape, device=device) - 1
             ).reshape((shape[0], -1))
    sel_dims = torch.randint(noise.shape[1], size=(noise.shape[0],), device=device)
    idxs = torch.arange(0, noise.shape[0], dtype=torch.long, device=device)
    noise[idxs, sel_dims] = torch.sign(
        torch.rand(shape[0], device=device) - 0.5)
    return noise

def sample_l2_sphere(device, shape):
    '''Sample uniformly from the unit l2 sphere.
    Inputs:
        device: 'cpu' | 'cuda' | other torch devices
        shape: a pair (batchsize, dim)
    Outputs:
        matrix of shape `shape` such that each row is a sample.
    '''
    # a normalized standard Gaussian vector is uniform on the sphere
    noises = torch.randn(shape, device=device)
    noises /= noises.norm(dim=1, keepdim=True)
    return noises

def sample_l1_sphere(device, shape):
    '''Sample uniformly from the unit l1 sphere, i.e. the surface of the cross-polytope.
    Inputs:
        device: 'cpu' | 'cuda' | other torch devices
        shape: a pair (batchsize, dim)
    Outputs:
        matrix of shape `shape` such that each row is a sample.
    '''
    batchsize, dim = shape
    # Dirichlet(1, ..., 1) is uniform on the simplex; random signs spread it over all orthants
    dirdist = Dirichlet(concentration=torch.ones(dim, device=device))
    noises = dirdist.sample([batchsize])
    signs = torch.sign(torch.rand_like(noises) - 0.5)
    return noises * signs


def get_radii_from_table(table_rho, table_radii, prob_lb):
    '''Look up a robust radius for each entry of `prob_lb` by rounding it down
    to the nearest tabulated rho (conservative when the radius is nondecreasing
    in rho). Probabilities below table_rho[0] map to table_radii[0].
    '''
    prob_lb = prob_lb.numpy()
    idxs = np.searchsorted(table_rho, prob_lb, 'right') - 1
    idxs = np.maximum(idxs, 0)  # avoid wrapping around to the last entry
    return torch.tensor(table_radii[idxs], dtype=torch.float)


def get_radii_from_convex_table(table_rho, table_radii, prob_lb):
    '''
    Assumes that 1) the radius is a convex function of rho and
    2) table_rho[0] = 1/2, table_radii[0] = 0.
    Uses the basic fact that if f is convex and a < b, then

        f'(b) >= (f(b) - f(a)) / (b - a),

    so f(x) >= f(b) + (x - b) * (f(b) - f(a)) / (b - a) for x >= b; this
    linear extrapolation from the two nearest tabulated points below prob_lb
    is therefore a valid lower bound on the radius.
    '''
    prob_lb = prob_lb.numpy()
    idxs = np.searchsorted(table_rho, prob_lb, 'right') - 1
    idxs = np.maximum(idxs, 0)  # probabilities below 1/2 get radius 0 below
    slope = (table_radii[idxs] - table_radii[idxs-1]) / (
        table_rho[idxs] - table_rho[idxs-1]
    )
    rad = table_radii[idxs] + slope * (prob_lb - table_rho[idxs])
    rad[idxs == 0] = 0
    return torch.tensor(rad, dtype=torch.float)


def diffmethod_table(Phi, inc, grid_type, upper, f):
    r'''Calculates a table of robust radii using the differential method.
    Given function Phi and the probability rho of correctly classifying an input
    perturbed by smoothing noise, the differential method gives

        \int_{1 - \rho}^{1/2} dp/\Phi(p)

    for the robust radius.
    Inputs:
        Phi: Phi function
        inc: grid increment (e.g. 0.001, the default used by `make_or_load`)
        grid_type: 'radius' | 'prob'
            In a `radius` grid, the probabilities rho are calculated as

                f([inc, 2 * inc, ..., upper - inc, upper]),

            where `f` and `upper` are additional inputs to this function.
            In a `prob` grid, the probabilities rho are spaced out evenly
            in increments of `inc`

                [1/2, 1/2 + inc, 1/2 + 2 * inc, ..., 1 - inc]

        upper: if `grid_type == 'radius'`, then the upper limit to the
            radius grid.
        f: the function used to determine the grid if `grid_type == 'radius'`
    Outputs:
        table_rho, table_radii: numpy arrays such that table_radii[i] is the
            robust radius for probability table_rho[i]; the first entry is
            always (rho, radius) = (1/2, 0).
    '''
    table = {1/2: 0}
    lastrho = 1/2
    if grid_type == 'radius':
        rgrid = np.arange(inc, upper+inc, inc)
        grid = f(rgrid)
    elif grid_type == 'prob':
        grid = np.arange(1/2+inc, 1, inc)
    else:
        raise ValueError(f'Unknown grid_type {grid_type}')
    for rho in tqdm(grid):
        # accumulate the integral piecewise between consecutive grid points
        delta = sp.integrate.quad(lambda p: 1/Phi(p), 1 - rho, 1 - lastrho)[0]
        table[rho] = table[lastrho] + delta
        lastrho = rho
    return np.array(list(table.keys())), np.array(list(table.values()))


def lvsetmethod_table(get_pbig, get_psmall, sigma, inc=0.01, upper=3):
    r'''Calculates a table of robust radii using the level set method.
    Inputs:
        get_pbig: function for computing the big measure of a Neyman-Pearson set.
        get_psmall: same, for the small measure of a Neyman-Pearson set.
        sigma: sqrt(E[\|noise\|^2_2/d])
        inc: increment of the normalized radius (radius / sigma) in the table
        upper: upper limit of the normalized radius (radius / sigma) in the table.
    Outputs:
        table_rho, table_radii (entries whose root finding did not converge
        are omitted)
    '''
    def find_NP_log_ratio(u, x0=0, bracket=(-100, 100)):
        return sp.optimize.root_scalar(
            lambda t: get_pbig(t, u) - 0.5, x0=x0, bracket=bracket)
    table = {0: {'radius': 0, 'rho': 1/2}}
    prv_root = 0
    for eps in tqdm(np.arange(inc, upper + inc, inc)):
        e = eps * sigma
        t = find_NP_log_ratio(e, prv_root)
        table[eps] = {
            't': t.root,
            'radius': e,
            'normalized_radius': eps,
            'converged': t.converged,
            'info': t
        }
        if t.converged:
            table[eps]['rho'] = 1 - get_psmall(t.root, e)
            prv_root = t.root
    # only converged entries have a 'rho'; skip the others instead of raising KeyError
    rows = [x for x in table.values() if 'rho' in x]
    return np.array([x['rho'] for x in rows]), \
        np.array([x['radius'] for x in rows])

def make_or_load(basename, make, inc=0.001, grid_type='radius', upper=3,
                 save=True, loc='tables'):
    '''Calculate or load a table of robust radii.
    First try to load a table with the corresponding parameters from the
    directory `loc`. If none is found, calculate the table with `make` and,
    if `save` is True, save it there.
    Inputs:
        basename: file name prefix of the table; the grid parameters are
            appended to it.
        make: function called as `make(inc=inc, grid_type=grid_type, upper=upper)`
            that returns (table_rho, table_radii).
        inc: grid increment (default: 0.001)
        grid_type: 'radius' | 'prob' (default: 'radius')
            In a `radius` grid, the probabilities rho are calculated from the
            radius grid [inc, 2 * inc, ..., upper - inc, upper].
            In a `prob` grid, the probabilities rho are spaced out evenly
            in increments of `inc`

                [1/2, 1/2 + inc, 1/2 + 2 * inc, ..., 1 - inc]

        upper: if `grid_type == 'radius'`, then the upper limit to the
            radius grid (default: 3).
        save: whether to save a newly calculated table (default: True).
        loc: directory to load tables from and save them to (default: 'tables').
    Outputs:
        table_rho, table_radii
    '''
    from os.path import join
    if grid_type == 'radius':
        basename += f'_inc{inc}_grid{grid_type}_upper{upper}'
    else:
        basename += f'_inc{inc}_grid{grid_type}'
    rho_fname = join(loc, basename + '_rho.npy')
    radii_fname = join(loc, basename + '_radii.npy')
    try:
        table_rho = np.load(rho_fname)
        table_radii = np.load(radii_fname)
        print('Found and loaded saved table: ' + basename)
    except FileNotFoundError:
        print('Making robust radii table: ' + basename)
        table_rho, table_radii = make(inc=inc, grid_type=grid_type, upper=upper)
        if save:
            import os
            print('Saving robust radii table')
            os.makedirs(loc, exist_ok=True)
            np.save(rho_fname, table_rho)
            np.save(radii_fname, table_radii)
    return table_rho, table_radii

--------------------------------------------------------------------------------
/src/smooth.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn.functional as F
from torch.distributions import Categorical, Normal
from statsmodels.stats.proportion import proportion_confint


def direct_train_log_lik(model, x, y, noise, sample_size=16):
    """
    Log-likelihood for direct training: estimates log E_noise[p(y | x + noise)]
    from `sample_size` noise draws (numerically stable with the logsumexp trick).
    """
    samples_shape = torch.Size([x.shape[0], sample_size]) + x.shape[1:]
    samples = x.unsqueeze(1).expand(samples_shape)
    samples = samples.reshape(torch.Size([-1]) + samples.shape[2:])
    samples = noise.sample(samples)
    thetas = model.forward(samples).view(x.shape[0], sample_size, -1)
    # log-softmax of the true class for every noise draw, then log-mean-exp over draws
    return torch.logsumexp(thetas[torch.arange(x.shape[0], device=x.device), :, y] - \
                           torch.logsumexp(thetas, dim=2), dim=1) - \
        torch.log(torch.tensor(sample_size, dtype=torch.float, device=x.device))

def smooth_predict_soft(model, x, noise, sample_size=64, noise_batch_size=512):
    """
    Make soft predictions for a model smoothed by noise.

    Returns
    -------
    predictions: Categorical, probabilities for each class returned by soft smoothed classifier
    """
    counts = None
    num_samples_left = sample_size

    while num_samples_left > 0:

        shape = torch.Size([x.shape[0], min(num_samples_left, noise_batch_size)]) + x.shape[1:]
        samples = x.unsqueeze(1).expand(shape)
        samples = samples.reshape(torch.Size([-1]) + samples.shape[2:])
        samples = noise.sample(samples.view(len(samples), -1)).view(samples.shape)
        logits = model.forward(samples).view(shape[:2] + torch.Size([-1]))
        if counts is None:
            counts = torch.zeros(x.shape[0], logits.shape[-1], dtype=torch.float, device=x.device)
        # sum (not mean) so that a smaller final batch is not over-weighted;
        # Categorical normalizes the totals into probabilities
        counts += F.softmax(logits, dim=-1).sum(dim=1)
        num_samples_left -= noise_batch_size

    return Categorical(probs=counts)

def smooth_predict_hard(model, x, noise, sample_size=64, noise_batch_size=512):
    """
    Make hard predictions for a model smoothed by noise.

    Returns
    -------
    predictions: Categorical, probabilities for each class returned by hard smoothed classifier
    """
    counts = None
    num_samples_left = sample_size

    while num_samples_left > 0:

        shape = torch.Size([x.shape[0], min(num_samples_left, noise_batch_size)]) + x.shape[1:]
        samples = x.unsqueeze(1).expand(shape)
        samples = samples.reshape(torch.Size([-1]) + samples.shape[2:])
        samples = noise.sample(samples.view(len(samples), -1)).view(samples.shape)
        logits = model.forward(samples).view(shape[:2] + torch.Size([-1]))
        top_cats = torch.argmax(logits, dim=2)
        if counts is None:
            counts = torch.zeros(x.shape[0], logits.shape[-1], dtype=torch.float, device=x.device)
        # count how often each class is the top prediction under noise
        counts += F.one_hot(top_cats, logits.shape[-1]).float().sum(dim=1)
        num_samples_left -= noise_batch_size

    return Categorical(probs=counts)

def certify_prob_lb(model, x, top_cats, alpha, noise, sample_size=10**5, noise_batch_size=512):
    """
    Certify a probability lower bound (rho): the lower end of the two-sided
    Clopper-Pearson interval at level 1 - alpha for the probability that the
    hard smoothed classifier returns `top_cats`.

    Returns
    -------
    prob_lb: n-length tensor of floats
    """
    preds = smooth_predict_hard(model, x, noise, sample_size, noise_batch_size)
    top_probs = preds.probs.gather(dim=1, index=top_cats.unsqueeze(1)).detach().cpu()
    # recover integer counts (undo float rounding from the normalization in Categorical)
    top_counts = torch.round(top_probs * sample_size).numpy()
    lower, _ = proportion_confint(top_counts, sample_size, alpha=alpha, method="beta")
    lower = torch.tensor(lower.reshape(-1), dtype=torch.float)
    return lower

#def certify_smoothed(model, x, top_cats, alpha, noise, adv, sample_size=10**5, noise_batch_size=512):
#    """
#    Certify a smoothed model, given the top categories to certify for.
#
#    Returns
#    -------
#    lower: n-length tensor of floats, the probability lower bounds
#    radius: n-length tensor
#    """
#    preds = smooth_predict_hard(model, x, noise, sample_size, noise_batch_size)
#    top_probs = preds.probs.gather(dim=1, index=top_cats.unsqueeze(1)).detach().cpu()
#    lower, _ = proportion_confint(top_probs * sample_size, sample_size, alpha=alpha, method="beta")
#    lower = torch.tensor(lower.squeeze(), dtype=torch.float)
#    return lower, noise.certify(lower, adv=adv)
#

--------------------------------------------------------------------------------
/src/test.py:
--------------------------------------------------------------------------------
import numpy as np
import pathlib
import os
import sys
import torch
import torch.nn as nn
from argparse import ArgumentParser
from torchnet import meter
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
from src.models import *
from src.attacks import *
from src.smooth import *
from src.noises import *
from src.datasets import *


if __name__ == "__main__":

    argparser = ArgumentParser()
    argparser.add_argument("--device", default="cuda", type=str)
    argparser.add_argument("--batch-size", default=2, type=int)
    argparser.add_argument("--num-workers", default=min(os.cpu_count(), 8), type=int)
    argparser.add_argument("--sample-size-pred", default=64, type=int)
    argparser.add_argument("--sample-size-cert", default=100000, type=int)
    argparser.add_argument("--noise-batch-size", default=512, type=int)
    argparser.add_argument("--sigma", default=0.0, type=float)
    argparser.add_argument("--noise", default="Clean", type=str)
    argparser.add_argument("--k", default=None, type=int)
    argparser.add_argument("--j", default=None, type=int)
    argparser.add_argument("--a", default=None, type=int)
    argparser.add_argument("--lambd", default=None, type=float)
    argparser.add_argument("--dataset-skip", default=1, type=int)
    argparser.add_argument("--experiment-name", default="cifar", type=str)
    argparser.add_argument("--dataset", default="cifar", type=str)
    argparser.add_argument("--model", default="WideResNet", type=str)
    argparser.add_argument("--rotate", action="store_true")
    argparser.add_argument("--output-dir", type=str, default=os.getenv("PT_OUTPUT_DIR"))
    argparser.add_argument("--save-path", type=str, default=None)
    args = argparser.parse_args()

    test_dataset = get_dataset(args.dataset, "test")
    test_dataset = Subset(test_dataset, list(range(0, len(test_dataset), args.dataset_skip)))
    test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size,
                             num_workers=args.num_workers)

    if not args.save_path:
        save_path = f"{args.output_dir}/{args.experiment_name}/model_ckpt.torch"
    else:
        save_path = args.save_path

    model = eval(args.model)(dataset=args.dataset, device=args.device)
    saved_dict = torch.load(save_path)
    model.load_state_dict(saved_dict)
    model.eval()

    noise = parse_noise_from_args(args, device=args.device, dim=get_dim(args.dataset))

    results = {
        "preds": np.zeros((len(test_dataset), get_num_labels(args.dataset))),
        "labels": np.zeros(len(test_dataset)),
        "prob_lb": np.zeros(len(test_dataset)),
        "preds_nll": np.zeros(len(test_dataset)),
        "radius_l1": np.zeros(len(test_dataset)),
        "radius_l2": np.zeros(len(test_dataset)),
        "radius_linf": np.zeros(len(test_dataset)),
    }

    for i, (x, y) in tqdm(enumerate(test_loader), total=len(test_loader)):

        x, y = x.to(args.device), y.to(args.device)
        x = rotate_noise.sample(x) if args.rotate else x

        preds = smooth_predict_hard(model, x, noise, args.sample_size_pred,
                                    noise_batch_size=args.noise_batch_size)
        top_cats = preds.probs.argmax(dim=1)
        prob_lb = certify_prob_lb(model, x, top_cats, 0.001, noise,
                                  args.sample_size_cert, noise_batch_size=args.noise_batch_size)

        lower, upper = i * args.batch_size, (i + 1) * args.batch_size
        results["preds"][lower:upper, :] = preds.probs.data.cpu().numpy()
        results["labels"][lower:upper] = y.data.cpu().numpy()
        results["prob_lb"][lower:upper] = prob_lb.cpu().numpy()
        results["radius_l1"][lower:upper] = noise.certify_l1(prob_lb).cpu().numpy()
        results["radius_l2"][lower:upper] = noise.certify_l2(prob_lb).cpu().numpy()
        results["radius_linf"][lower:upper] = noise.certify_linf(prob_lb).cpu().numpy()
        results["preds_nll"][lower:upper] = -preds.log_prob(y).cpu().numpy()

    save_path = f"{args.output_dir}/{args.experiment_name}"
    pathlib.Path(save_path).mkdir(parents=True, exist_ok=True)
    for k, v in results.items():
        np.save(f"{save_path}/{k}.npy", v)

    train_dataset = get_dataset(args.dataset, "train")
    train_loader = DataLoader(train_dataset, shuffle=False,
                              batch_size=args.batch_size,
                              num_workers=args.num_workers)
    acc_meter = meter.AverageValueMeter()

    for x, y in tqdm(train_loader):

        x, y = x.to(args.device), y.to(args.device)
        x = rotate_noise.sample(x) if args.rotate else x

        preds = smooth_predict_hard(model, x, noise, args.sample_size_pred,
                                    args.noise_batch_size)
        top_cats = preds.probs.argmax(dim=1)
        acc_meter.add(torch.sum(top_cats == y).cpu().data.numpy(), n=len(x))

    print("Training accuracy: ", acc_meter.value())
    save_path = f"{args.output_dir}/{args.experiment_name}"
    pathlib.Path(save_path).mkdir(parents=True, exist_ok=True)
    np.save(f"{save_path}/acc_train.npy", acc_meter.value())

--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
import logging
import pathlib
import pickle
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from argparse import ArgumentParser
from torchnet import meter
from torch.utils.data import DataLoader
from tqdm import tqdm
from src.models import *
from src.noises import *
from src.smooth import *
from src.attacks import pgd_attack_smooth
from src.datasets import get_dataset, get_dim


if __name__ == "__main__":

    argparser = ArgumentParser()
    argparser.add_argument("--device", default="cuda", type=str)
    argparser.add_argument("--lr", default=0.1, type=float)
    argparser.add_argument("--batch-size", default=64, type=int)
    argparser.add_argument("--num-workers", default=min(os.cpu_count(), 8), type=int)
    argparser.add_argument("--num-epochs", default=120, type=int)
    argparser.add_argument("--print-every", default=20, type=int)
    argparser.add_argument("--save-every", default=50, type=int)
    argparser.add_argument("--experiment-name", default="cifar", type=str)
    argparser.add_argument("--noise", default="Clean", type=str)
    argparser.add_argument("--sigma", default=None, type=float)
    argparser.add_argument("--adv", default=2, type=int)
    argparser.add_argument("--eps", default=0.0, type=float)
    argparser.add_argument("--k", default=None, type=int)
    argparser.add_argument("--j", default=None, type=int)
    argparser.add_argument("--a", default=None, type=int)
    argparser.add_argument("--lambd", default=None, type=float)
    argparser.add_argument("--model", default="WideResNet", type=str)
    argparser.add_argument("--dataset", default="cifar", type=str)
    argparser.add_argument("--adversarial", action="store_true")
    argparser.add_argument("--stability", action="store_true")
    argparser.add_argument("--direct", action="store_true")
    argparser.add_argument("--save-path", type=str, default=None)
    argparser.add_argument("--output-dir", type=str, default=os.getenv("PT_OUTPUT_DIR"))
    args = argparser.parse_args()

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    model = eval(args.model)(dataset=args.dataset, device=args.device)

    # for fine-tuning a pre-trained model, we strip out the last fc layer
    if args.save_path:
        saved_dict = torch.load(args.save_path)
        del saved_dict["model.module.fc.weight"]
        del saved_dict["model.module.fc.bias"]
        model.load_state_dict(saved_dict, strict=False)

    model.train()

    train_loader = DataLoader(get_dataset(args.dataset, "train"),
                              shuffle=True,
                              batch_size=args.batch_size,
                              num_workers=args.num_workers,
                              pin_memory=False)

    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=1e-4,
                          nesterov=True)
    annealer = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.num_epochs)

    loss_meter = meter.AverageValueMeter()
    time_meter = meter.TimeMeter(unit=False)

    noise = parse_noise_from_args(args, device=args.device, dim=get_dim(args.dataset))

    train_losses = []

    for epoch in range(args.num_epochs):

        for i, (x, y) in enumerate(train_loader):

            x, y = x.to(args.device), y.to(args.device)

            if args.adversarial:
                eps = min(args.eps, epoch * args.eps / (args.num_epochs // 2))
                x, loss = pgd_attack_smooth(model, x, y, eps, noise, sample_size=4, adv=args.adv)
            elif args.stability:
                x_tilde = noise.sample(x.view(len(x), -1)).view(x.shape)
                x = noise.sample(x.view(len(x), -1)).view(x.shape)
            elif not args.direct:
                x = noise.sample(x.view(len(x), -1)).view(x.shape)

            if args.direct:
                loss = -direct_train_log_lik(model, x, y, noise,
                                             sample_size=16).mean()
            elif args.stability:
                pred_x = model.forecast(model.forward(x))
                pred_x_tilde = model.forecast(model.forward(x_tilde))
                loss = -pred_x.log_prob(y) + 6.0 * torch.distributions.kl_divergence(pred_x, pred_x_tilde)
                loss = loss.mean()
            elif not args.adversarial:
                loss = model.loss(x, y).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_meter.add(loss.cpu().data.numpy(), n=1)

            if i % args.print_every == 0:
                logger.info(f"Epoch: {epoch}\t"
                            f"Itr: {i} / {len(train_loader)}\t"
                            f"Loss: {loss_meter.value()[0]:.2f}\t"
                            f"Mins: {(time_meter.value() / 60):.2f}\t"
                            f"Experiment: {args.experiment_name}")
                train_losses.append(loss_meter.value()[0])
                loss_meter.reset()

        if (epoch + 1) % args.save_every == 0:
            save_path = f"{args.output_dir}/{args.experiment_name}/{epoch}/"
            pathlib.Path(save_path).mkdir(parents=True, exist_ok=True)
            torch.save(model.state_dict(), f"{save_path}/model_ckpt.torch")

        annealer.step()

    pathlib.Path(f"{args.output_dir}/{args.experiment_name}").mkdir(parents=True, exist_ok=True)
    save_path = f"{args.output_dir}/{args.experiment_name}/model_ckpt.torch"
    torch.save(model.state_dict(), save_path)
    args_path = f"{args.output_dir}/{args.experiment_name}/args.pkl"
    with open(args_path, "wb") as f:  # close the file handle after pickling
        pickle.dump(args, f)
    save_path = f"{args.output_dir}/{args.experiment_name}/losses_train.npy"
    np.save(save_path, np.array(train_losses))

--------------------------------------------------------------------------------
/src/verify.py:
--------------------------------------------------------------------------------
import torch
import os
import numpy as np
from argparse import ArgumentParser
from torch.utils.data import DataLoader
from tqdm import tqdm
from src.attacks import *
from src.noises import *
from src.models import *
from src.smooth import *
from src.datasets import get_dataset, get_dim, get_num_labels


if __name__ == "__main__":

    argparser = ArgumentParser()
    argparser.add_argument("--device", default="cuda", type=str)
    argparser.add_argument("--batch-size", default=4, type=int)
    argparser.add_argument("--num-workers", default=os.cpu_count(), type=int)
    argparser.add_argument("--sample-size-pred", default=64, type=int)
    argparser.add_argument("--noise-batch-size", default=512, type=int)
    argparser.add_argument("--sigma", default=0.0, type=float)
    argparser.add_argument("--noise", default="Clean", type=str)
    argparser.add_argument("--k", default=None, type=int)
    argparser.add_argument("--j", default=None, type=int)
    argparser.add_argument("--a", default=None, type=int)
    argparser.add_argument("--lambd", default=None, type=float)
    argparser.add_argument("--adv", default=2, type=int)
    argparser.add_argument("--experiment-name", default="cifar", type=str)
    argparser.add_argument("--dataset", default="cifar", type=str)
    argparser.add_argument("--model", default="WideResNet", type=str)
    argparser.add_argument("--output-dir", type=str, default=os.getenv("PT_OUTPUT_DIR"))
    args = argparser.parse_args()

    test_dataset = get_dataset(args.dataset, "test")
    test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size,  # todo: fix
                             num_workers=args.num_workers)

    save_path = f"{args.output_dir}/{args.experiment_name}/model_ckpt.torch"
    model = eval(args.model)(dataset=args.dataset, device=args.device)
    model.load_state_dict(torch.load(save_path))
    model.eval()

    noise = parse_noise_from_args(args, device=args.device, dim=get_dim(args.dataset))

    eps_range = (3.0, 2.0, 1.0, 0.5, 0.25)

    # size the prediction arrays by the dataset's label count instead of a hard-coded 10
    results = {f"preds_adv_{eps}": np.zeros((len(test_dataset), get_num_labels(args.dataset)))
               for eps in eps_range}

    for i, (x, y) in tqdm(enumerate(test_loader), total=len(test_loader)):

        x, y = x.to(args.device), y.to(args.device)
        lower, upper = i * args.batch_size, (i + 1) * args.batch_size

        for eps in eps_range:
            x_adv, _ = pgd_attack_smooth(model, x, y, eps=eps, noise=noise, sample_size=128,
                                         steps=20, p=args.adv, clamp=(0, 1))
            preds_adv = smooth_predict_hard(model, x_adv, noise, args.sample_size_pred,
                                            args.noise_batch_size)
            results[f"preds_adv_{eps}"][lower:upper, :] = preds_adv.probs.data.cpu().numpy()
            assert ((x - x_adv).reshape(x.shape[0], -1).norm(dim=1, p=args.adv) <= eps + 1e-2).all()

    save_path = f"{args.output_dir}/{args.experiment_name}"
    for k, v in results.items():
        np.save(f"{save_path}/{k}.npy", v)

--------------------------------------------------------------------------------
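`smooth_predict_soft` and `smooth_predict_hard` in `src/smooth.py` process the `sample_size` noisy copies in chunks of at most `noise_batch_size` and pass the accumulated totals to `Categorical(probs=...)`, which rescales them to sum to 1. The plain-Python sketch below (no PyTorch; `combine_chunks` is a hypothetical helper, not part of the repository) illustrates why the accumulator should hold per-chunk sums rather than per-chunk means: with sums every noisy draw carries equal weight, whereas adding means gives a short final chunk as much weight as a full one whenever `sample_size` is not a multiple of `noise_batch_size`.

```python
from typing import List, Sequence


def combine_chunks(per_draw: Sequence[Sequence[float]], chunk_size: int,
                   use_means: bool) -> List[float]:
    """Accumulate per-draw class scores chunk by chunk, then normalize.

    per_draw: one score vector per noisy draw (softmax outputs or one-hot votes).
    use_means=False adds each chunk's column sums, so every draw is weighted equally.
    use_means=True adds each chunk's column means, so every chunk is weighted equally.
    """
    num_classes = len(per_draw[0])
    acc = [0.0] * num_classes
    for start in range(0, len(per_draw), chunk_size):
        chunk = per_draw[start:start + chunk_size]
        for c in range(num_classes):
            col_sum = sum(row[c] for row in chunk)
            acc[c] += col_sum / len(chunk) if use_means else col_sum
    total = sum(acc)
    # Normalizing here plays the role of Categorical(probs=...), which rescales probs to sum to 1.
    return [a / total for a in acc]


# Four draws vote for class 0, then a short final chunk of one draw votes for class 1.
draws = [[1.0, 0.0]] * 4 + [[0.0, 1.0]]
print(combine_chunks(draws, chunk_size=4, use_means=False))  # [0.8, 0.2]
print(combine_chunks(draws, chunk_size=4, use_means=True))   # [0.5, 0.5]
```

With a single chunk, or with equally sized chunks, the two accumulators agree after normalization, so the difference only appears when the final chunk is ragged.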
/svgs/DistributionVenn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyduan/rs4a/a02831806ee4a117ef82ce5738fe8c5e8f180b06/svgs/DistributionVenn.png
--------------------------------------------------------------------------------
/svgs/envelopes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyduan/rs4a/a02831806ee4a117ef82ce5738fe8c5e8f180b06/svgs/envelopes.png
--------------------------------------------------------------------------------
/svgs/robust-radii.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyduan/rs4a/a02831806ee4a117ef82ce5738fe8c5e8f180b06/svgs/robust-radii.png
--------------------------------------------------------------------------------
/svgs/table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyduan/rs4a/a02831806ee4a117ef82ce5738fe8c5e8f180b06/svgs/table.png
--------------------------------------------------------------------------------
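`certify_prob_lb` in `src/smooth.py` turns the hard-vote count for the predicted class into a lower confidence bound on rho via statsmodels' `proportion_confint(..., method="beta")`, i.e. the Clopper-Pearson interval, and `test.py` calls it with `alpha=0.001`. As a dependency-free sketch of what that lower end is, the hypothetical helpers `_binom_tail` and `clopper_pearson_lower` below (not part of the repository) bisect for the p at which observing at least k of n votes has probability alpha/2. This assumes the two-sided convention of `proportion_confint`, and it is equivalent to the Beta(k, n - k + 1) quantile at alpha/2 that statsmodels evaluates.

```python
import math


def _binom_tail(k: int, n: int, p: float) -> float:
    """P(X >= k) for X ~ Binomial(n, p), summed in log space to avoid overflow."""
    if k <= 0 or p >= 1.0:
        return 1.0
    if p <= 0.0:
        return 0.0
    log_p, log_q = math.log(p), math.log1p(-p)
    log_terms = [
        math.lgamma(n + 1) - math.lgamma(i + 1) - math.lgamma(n - i + 1)
        + i * log_p + (n - i) * log_q
        for i in range(k, n + 1)
    ]
    top = max(log_terms)
    return math.exp(top) * sum(math.exp(t - top) for t in log_terms)


def clopper_pearson_lower(k: int, n: int, alpha: float) -> float:
    """Lower end of the two-sided (1 - alpha) Clopper-Pearson interval, for 0 <= k <= n."""
    if k == 0:
        return 0.0
    lo, hi = 0.0, 1.0
    # P(X >= k) increases with p, so bisect for the p where it equals alpha / 2.
    for _ in range(100):
        mid = (lo + hi) / 2
        if _binom_tail(k, n, mid) < alpha / 2:
            lo = mid
        else:
            hi = mid
    # lo always has tail probability below alpha / 2, so it errs on the conservative side.
    return lo


# Mirrors the settings in test.py: alpha = 0.001 and 100,000 certification draws.
rho_lower = clopper_pearson_lower(99_900, 100_000, alpha=0.001)
print(f"rho_lower = {rho_lower:.5f}")
```

For k = n the bound has the closed form (alpha/2)**(1/n), which makes a convenient sanity check. Each bisection step sums n - k + 1 binomial terms, so this sketch is cheap when the top class wins most votes but slow for small k and large n, where the Beta quantile is the better route.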